From e26d0d0c0762a5a6e305fce0a938a70f511f497c Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 23 Jan 2023 18:57:10 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/srcnn_cld_frac.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py
index 9ae109ca..527a607d 100644
--- a/modules/deeplearning/srcnn_cld_frac.py
+++ b/modules/deeplearning/srcnn_cld_frac.py
@@ -559,11 +559,23 @@ class SRCNN:
         self.initial_learning_rate = initial_learning_rate
 
     def build_evaluation(self):
-        self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
-        self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
         self.train_loss = tf.keras.metrics.Mean(name='train_loss')
         self.test_loss = tf.keras.metrics.Mean(name='test_loss')
 
+        if NumClasses == 2:
+            self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
+            self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
+            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
+            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
+            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
+            self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
+            self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
+            self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
+            self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
+        else:
+            self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+            self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
+
     @tf.function
     def train_step(self, mini_batch):
         inputs = [mini_batch[0]]
-- 
GitLab