From 5f9cfdc1217ffd14af6d850a3819417b0ef1b710 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Fri, 4 Nov 2022 11:31:35 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cnn_cld_frac.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index 3d98de98..8c316773 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -499,11 +499,23 @@ class CNN: self.initial_learning_rate = initial_learning_rate def build_evaluation(self): - self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy') - self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy') self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss') + if NumClasses == 2: + self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy') + self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy') + self.test_auc = tf.keras.metrics.AUC(name='test_auc') + self.test_recall = tf.keras.metrics.Recall(name='test_recall') + self.test_precision = tf.keras.metrics.Precision(name='test_precision') + self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg') + self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos') + self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg') + self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos') + else: + self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') + self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') + @tf.function def train_step(self, mini_batch): inputs = [mini_batch[0]] -- GitLab