Skip to content
Snippets Groups Projects
Commit 7f8db191 authored by tomrink's avatar tomrink
Browse files

minor

parent f8b1f0f9
Branches
No related tags found
No related merge requests found
......@@ -408,27 +408,11 @@ class SRCNN:
self.initial_learning_rate = initial_learning_rate
def build_evaluation(self):
#self.train_loss = tf.keras.metrics.Mean(name='train_loss')
#self.test_loss = tf.keras.metrics.Mean(name='test_loss')
self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
self.train_loss = tf.keras.metrics.Mean(name='train_loss')
self.test_loss = tf.keras.metrics.Mean(name='test_loss')
# if NumClasses == 2:
# self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
# self.test_auc = tf.keras.metrics.AUC(name='test_auc')
# self.test_recall = tf.keras.metrics.Recall(name='test_recall')
# self.test_precision = tf.keras.metrics.Precision(name='test_precision')
# self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
# self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
# self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
# self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
# else:
# self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(self, mini_batch):
inputs = [mini_batch[0]]
......@@ -459,14 +443,6 @@ class SRCNN:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def predict(self, mini_batch):
inputs = [mini_batch[0]]
......@@ -483,14 +459,6 @@ class SRCNN:
def reset_test_metrics(self):
self.test_loss.reset_states()
self.test_accuracy.reset_states()
# if NumClasses == 2:
# self.test_auc.reset_states()
# self.test_recall.reset_states()
# self.test_precision.reset_states()
# self.test_true_neg.reset_states()
# self.test_true_pos.reset_states()
# self.test_false_neg.reset_states()
# self.test_false_pos.reset_states()
def get_metrics(self):
recall = self.test_recall.result()
......@@ -570,14 +538,6 @@ class SRCNN:
with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
# if NumClasses == 2:
# tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
# tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
# tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
# tf.summary.scalar('f1_val', f1, step=step)
# tf.summary.scalar('mcc_val', mcc, step=step)
# tf.summary.scalar('num_train_steps', step, step=step)
# tf.summary.scalar('num_epochs', epoch, step=step)
with self.writer_train_valid_loss.as_default():
tf.summary.scalar('loss_trn', loss.numpy(), step=step)
......@@ -605,25 +565,11 @@ class SRCNN:
self.test_step(mini_batch)
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
# print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
# self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
# else:
# print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss:
best_test_loss = tst_loss
# if NumClasses == 2:
# best_test_acc = self.test_accuracy.result().numpy()
# best_test_recall = self.test_recall.result().numpy()
# best_test_precision = self.test_precision.result().numpy()
# best_test_auc = self.test_auc.result().numpy()
# best_test_f1 = f1.numpy()
# best_test_mcc = mcc.numpy()
ckpt_manager.save()
if EARLY_STOP and es.check_stop(tst_loss):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment