From 08060ba0b7bee8536809c96a8fcab0e0f16c2626 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 28 Mar 2023 12:48:28 -0500 Subject: [PATCH] snapshot... --- modules/util/plot_cm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/util/plot_cm.py b/modules/util/plot_cm.py index 67f09d1b..619d00cd 100644 --- a/modules/util/plot_cm.py +++ b/modules/util/plot_cm.py @@ -12,7 +12,7 @@ def confusion_matrix_values(correct_labels, predict_labels): return cm -def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False): +def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False, axis=1): ''' Parameters: correct_labels : These are your true classification categories. @@ -30,7 +30,10 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'conf ''' if normalize: - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + if axis == 1: + cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis] + elif axis == 0: + cm = cm.astype('float') / cm.sum(axis=axis)[np.newaxis, :] cm *= 100 cm = np.nan_to_num(cm, copy=True) cm = cm.astype('int') -- GitLab