From 6df07956db0ab3626a2884a5b7f77f21afcd5b76 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 8 Feb 2023 12:16:19 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/srcnn_cld_frac.py | 55 +++++++++++++++++++++++---
 1 file changed, 50 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py
index a9895659..c45a69d8 100644
--- a/modules/deeplearning/srcnn_cld_frac.py
+++ b/modules/deeplearning/srcnn_cld_frac.py
@@ -904,28 +904,73 @@ def analyze(file='/Users/tomrink/cld_opd_frac.npy'):
     lbls_1_2 = np.where(lbls_1_2 == 2, 1, lbls_1_2)
 
     pred_1_2 = pred[msk_1_2]
-    pred_1_2 = np.where(pred_1_2 == 0, 1, pred_1_2)
+    pred_1_2 = np.where(pred_1_2 == 0, -9, pred_1_2)
     pred_1_2 = np.where(pred_1_2 == 1, 0, pred_1_2)
     pred_1_2 = np.where(pred_1_2 == 2, 1, pred_1_2)
+    pred_1_2 = np.where(pred_1_2 == -9, 1, pred_1_2)
 
     # ----
     lbls_0_2 = lbls[msk_0_2]
-    lbls_0_2 = np.where(lbls_0_2 == 0, 0, lbls_0_2)
     lbls_0_2 = np.where(lbls_0_2 == 2, 1, lbls_0_2)
 
     pred_0_2 = pred[msk_0_2]
-    pred_0_2 = np.where(pred_0_2 == 0, 0, pred_0_2)
-    pred_0_2 = np.where(pred_0_2 == 1, 1, pred_0_2)
     pred_0_2 = np.where(pred_0_2 == 2, 1, pred_0_2)
 
     cm_0_1 = confusion_matrix_values(lbls_0_1, pred_0_1)
     cm_1_2 = confusion_matrix_values(lbls_1_2, pred_1_2)
     cm_0_2 = confusion_matrix_values(lbls_0_2, pred_0_2)
 
-    return cm_0_1, cm_1_2,  cm_0_2
+    true_0_1 = (lbls_0_1 == 0) & (pred_0_1 == 0)
+    false_0_1 = (lbls_0_1 == 1) & (pred_0_1 == 0)
 
+    true_no_0_1 = (lbls_0_1 == 1) & (pred_0_1 == 1)
+    false_no_0_1 = (lbls_0_1 == 0) & (pred_0_1 == 1)
 
+    true_0_2 = (lbls_0_2 == 0) & (pred_0_2 == 0)
+    false_0_2 = (lbls_0_2 == 1) & (pred_0_2 == 0)
 
+    true_no_0_2 = (lbls_0_2 == 1) & (pred_0_2 == 1)
+    false_no_0_2 = (lbls_0_2 == 0) & (pred_0_2 == 1)
+
+    true_1_2 = (lbls_1_2 == 0) & (pred_1_2 == 0)
+    false_1_2 = (lbls_1_2 == 1) & (pred_1_2 == 0)
+
+    true_no_1_2 = (lbls_1_2 == 1) & (pred_1_2 == 1)
+    false_no_1_2 = (lbls_1_2 == 0) & (pred_1_2 == 1)
+
+    tp_0 = np.sum(true_0_1)
+    tp_1 = np.sum(true_1_2)
+    tp_2 = np.sum(true_0_2)
+
+    tn_0 = np.sum(true_no_0_1)
+    tn_1 = np.sum(true_no_1_2)
+    tn_2 = np.sum(true_no_0_2)
+
+    fp_0 = np.sum(false_0_1)
+    fp_1 = np.sum(false_1_2)
+    fp_2 = np.sum(false_0_2)
+
+    fn_0 = np.sum(false_no_0_1)
+    fn_1 = np.sum(false_no_1_2)
+    fn_2 = np.sum(false_no_0_2)
+
+    recall_0 = tp_0 / (tp_0 + fn_0)
+    recall_1 = tp_1 / (tp_1 + fn_1)
+    recall_2 = tp_2 / (tp_2 + fn_2)
+
+    precision_0 = tp_0 / (tp_0 + fp_0)
+    precision_1 = tp_1 / (tp_1 + fp_1)
+    precision_2 = tp_2 / (tp_2 + fp_2)
+
+    mcc_0 = ((tp_0 * tn_0) - (fp_0 * fn_0)) / np.sqrt((tp_0 + fp_0) * (tp_0 + fn_0) * (tn_0 + fp_0) * (tn_0 + fn_0))
+    mcc_1 = ((tp_1 * tn_1) - (fp_1 * fn_1)) / np.sqrt((tp_1 + fp_1) * (tp_1 + fn_1) * (tn_1 + fp_1) * (tn_1 + fn_1))
+    mcc_2 = ((tp_2 * tn_2) - (fp_2 * fn_2)) / np.sqrt((tp_2 + fp_2) * (tp_2 + fn_2) * (tn_2 + fp_2) * (tn_2 + fn_2))
+
+    print(recall_0, precision_0, mcc_0)
+    print(recall_1, precision_1, mcc_1)
+    print(recall_2, precision_2, mcc_2)
+
+    return cm_0_1, cm_1_2, cm_0_2
 
 
 if __name__ == "__main__":
-- 
GitLab