From 7f8db191fb15e637b87210ef8ee4fc0397b67ad2 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Thu, 18 Aug 2022 15:44:58 -0500 Subject: [PATCH] minor --- modules/deeplearning/srcnn.py | 54 ----------------------------------- 1 file changed, 54 deletions(-) diff --git a/modules/deeplearning/srcnn.py b/modules/deeplearning/srcnn.py index 6b56b7af..f4985fb9 100644 --- a/modules/deeplearning/srcnn.py +++ b/modules/deeplearning/srcnn.py @@ -408,27 +408,11 @@ class SRCNN: self.initial_learning_rate = initial_learning_rate def build_evaluation(self): - #self.train_loss = tf.keras.metrics.Mean(name='train_loss') - #self.test_loss = tf.keras.metrics.Mean(name='test_loss') self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy') self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy') self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss') - # if NumClasses == 2: - # self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy') - # self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy') - # self.test_auc = tf.keras.metrics.AUC(name='test_auc') - # self.test_recall = tf.keras.metrics.Recall(name='test_recall') - # self.test_precision = tf.keras.metrics.Precision(name='test_precision') - # self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg') - # self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos') - # self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg') - # self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos') - # else: - # self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') - # self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') - @tf.function def train_step(self, mini_batch): inputs = [mini_batch[0]] @@ -459,14 +443,6 @@ class SRCNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) - # if NumClasses == 2: - # self.test_auc(labels, pred) - # self.test_recall(labels, pred) - # self.test_precision(labels, pred) - # self.test_true_neg(labels, pred) - # self.test_true_pos(labels, pred) - # self.test_false_neg(labels, pred) - # self.test_false_pos(labels, pred) def predict(self, mini_batch): inputs = [mini_batch[0]] @@ -483,14 +459,6 @@ class SRCNN: def reset_test_metrics(self): self.test_loss.reset_states() self.test_accuracy.reset_states() - # if NumClasses == 2: - # self.test_auc.reset_states() - # self.test_recall.reset_states() - # self.test_precision.reset_states() - # self.test_true_neg.reset_states() - # self.test_true_pos.reset_states() - # self.test_false_neg.reset_states() - # self.test_false_pos.reset_states() def get_metrics(self): recall = self.test_recall.result() @@ -570,14 +538,6 @@ class SRCNN: with self.writer_valid.as_default(): tf.summary.scalar('loss_val', self.test_loss.result(), step=step) tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step) - # if NumClasses == 2: - # tf.summary.scalar('auc_val', self.test_auc.result(), step=step) - # tf.summary.scalar('recall_val', self.test_recall.result(), step=step) - # tf.summary.scalar('prec_val', self.test_precision.result(), step=step) - # tf.summary.scalar('f1_val', f1, step=step) - # tf.summary.scalar('mcc_val', mcc, step=step) - # tf.summary.scalar('num_train_steps', step, step=step) - # tf.summary.scalar('num_epochs', epoch, step=step) with self.writer_train_valid_loss.as_default(): tf.summary.scalar('loss_trn', loss.numpy(), step=step) @@ -605,25 +565,11 @@ class SRCNN: self.test_step(mini_batch) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) - # if NumClasses == 2: - # f1, mcc = self.get_metrics() - # print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), - # self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy()) - # else: - # print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('------------------------------------------------------') tst_loss = self.test_loss.result().numpy() if tst_loss < best_test_loss: best_test_loss = tst_loss - # if NumClasses == 2: - # best_test_acc = self.test_accuracy.result().numpy() - # best_test_recall = self.test_recall.result().numpy() - # best_test_precision = self.test_precision.result().numpy() - # best_test_auc = self.test_auc.result().numpy() - # best_test_f1 = f1.numpy() - # best_test_mcc = mcc.numpy() - ckpt_manager.save() if EARLY_STOP and es.check_stop(tst_loss): -- GitLab