Skip to content
Snippets Groups Projects
Commit eebafafb authored by tomrink's avatar tomrink
Browse files

minor...

parent 97955bd5
Branches
No related tags found
No related merge requests found
......@@ -148,6 +148,7 @@ class IcingIntensityNN:
self.test_loss = None
self.test_accuracy = None
self.test_auc = None
self.test_f1 = None
self.test_recall = None
self.test_precision = None
self.test_confusion_matrix = None
......@@ -462,12 +463,14 @@ class IcingIntensityNN:
self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
self.test_auc = tf.keras.metrics.AUC(name='test_auc')
self.test_f1 = tfa.metrics.F1Score(name='test_f1')
self.test_recall = tf.keras.metrics.Recall(name='test_recall')
self.test_precision = tf.keras.metrics.Precision(name='test_precision')
else:
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
self.test_auc = tf.keras.metrics.AUC(name='test_auc')
self.test_f1 = tfa.metrics.F1Score(name='f1_score')
self.test_recall = tf.keras.metrics.Recall(name='test_recall')
self.test_precision = tf.keras.metrics.Precision(name='test_precision')
......@@ -509,6 +512,7 @@ class IcingIntensityNN:
self.test_accuracy(labels, pred)
if NumClasses == 2:
self.test_auc(labels, pred)
self.test_f1(labels, pred)
self.test_recall(labels, pred)
self.test_precision(labels, pred)
......@@ -524,6 +528,7 @@ class IcingIntensityNN:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
self.test_auc(labels, pred)
self.test_f1(labels, pred)
self.test_recall(labels, pred)
self.test_precision(labels, pred)
......@@ -580,6 +585,7 @@ class IcingIntensityNN:
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
tf.summary.scalar('f1_val', self.test_f1.result(), step=step)
tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
tf.summary.scalar('num_train_steps', step, step=step)
......@@ -601,6 +607,7 @@ class IcingIntensityNN:
self.test_loss.reset_states()
self.test_accuracy.reset_states()
self.test_auc.reset_states()
self.test_f1.reset_states()
self.test_recall.reset_states()
self.test_precision.reset_states()
......@@ -611,7 +618,7 @@ class IcingIntensityNN:
self.test_step(mini_batch)
print('loss, acc, auc, recall, precision: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
self.test_auc.result().numpy(), self.test_recall.result().numpy(), self.test_precision.result().numpy())
self.test_auc.result().numpy(), self.test_f1.result().numpy(), self.test_recall.result().numpy(), self.test_precision.result().numpy())
print('--------------------------------------------------')
ckpt_manager.save()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment