From 08060ba0b7bee8536809c96a8fcab0e0f16c2626 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 28 Mar 2023 12:48:28 -0500
Subject: [PATCH] snapshot...

---
 modules/util/plot_cm.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/modules/util/plot_cm.py b/modules/util/plot_cm.py
index 67f09d1b..619d00cd 100644
--- a/modules/util/plot_cm.py
+++ b/modules/util/plot_cm.py
@@ -12,7 +12,7 @@ def confusion_matrix_values(correct_labels, predict_labels):
     return cm
 
 
-def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False):
+def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False, axis=1):
     '''
     Parameters:
         correct_labels                  : These are your true classification categories.
@@ -30,7 +30,10 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'conf
     '''
 
     if normalize:
-        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+        if axis == 1:
+            cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis]
+        elif axis == 0:
+            cm = cm.astype('float') / cm.sum(axis=axis)[np.newaxis, :]
         cm *= 100
         cm = np.nan_to_num(cm, copy=True)
         cm = cm.astype('int')
-- 
GitLab