Skip to content
Snippets Groups Projects
Commit e26d0d0c authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 172f42e6
Branches
No related tags found
No related merge requests found
...@@ -559,11 +559,23 @@ class SRCNN: ...@@ -559,11 +559,23 @@ class SRCNN:
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
def build_evaluation(self): def build_evaluation(self):
self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.train_loss = tf.keras.metrics.Mean(name='train_loss')
self.test_loss = tf.keras.metrics.Mean(name='test_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss')
if NumClasses == 2:
self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
self.test_auc = tf.keras.metrics.AUC(name='test_auc')
self.test_recall = tf.keras.metrics.Recall(name='test_recall')
self.test_precision = tf.keras.metrics.Precision(name='test_precision')
self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
else:
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function @tf.function
def train_step(self, mini_batch): def train_step(self, mini_batch):
inputs = [mini_batch[0]] inputs = [mini_batch[0]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment