From c10b01291c57dc05d1b62013c676672b242db285 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 24 Apr 2024 15:39:18 -0500
Subject: [PATCH] snapshot...

---
 modules/machine_learning/classification.py | 36 ++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index dcbfa2de..9b532038 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -10,10 +10,46 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.ensemble import RandomForestClassifier
+import itertools
 import sklearn.tree as tree
 from sklearn.tree import export_graphviz
 
 
+def plot_confusion_matrix(cm, classes,
+                          normalize=False,
+                          title='Confusion matrix',
+                          cmap=plt.cm.Blues):
+    """
+    This function prints and plots the confusion matrix.
+    Normalization can be applied by setting `normalize=True`.
+    """
+    if normalize:
+        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+        print("Normalized confusion matrix")
+    else:
+        print('Confusion matrix, without normalization')
+
+    print(cm)
+
+    plt.imshow(cm, interpolation='nearest', cmap=cmap)
+    plt.title(title)
+    plt.colorbar()
+    tick_marks = np.arange(len(classes))
+    plt.xticks(tick_marks, classes, rotation=45)
+    plt.yticks(tick_marks, classes)
+
+    fmt = '.2f' if normalize else 'd'
+    thresh = cm.max() / 2.
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        plt.text(j, i, format(cm[i, j], fmt),
+                 horizontalalignment="center",
+                 color="white" if cm[i, j] > thresh else "black")
+
+    plt.tight_layout()
+    plt.ylabel('True label')
+    plt.xlabel('Predicted label')
+
+
 def get_csv_as_dataframe(csv_file, reduce_frac=None):
     icing_df = pd.read_csv(csv_file)
     # Random selection of reduce_frac of the rows
-- 
GitLab