Skip to content
Snippets Groups Projects
Commit 13d8bfd3 authored by tomrink's avatar tomrink
Browse files

snapshot... changes to more easily switch between binary and multiclass

parent 6c5f12ed
No related branches found
No related tags found
No related merge requests found
......@@ -495,13 +495,6 @@ class IcingIntensityNN:
else:
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
self.test_auc = tf.keras.metrics.AUC(name='test_auc')
self.test_recall = tf.keras.metrics.Recall(name='test_recall')
self.test_precision = tf.keras.metrics.Precision(name='test_precision')
self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
@tf.function
def train_step(self, mini_batch):
......@@ -562,13 +555,14 @@ class IcingIntensityNN:
def reset_test_metrics(self):
self.test_loss.reset_states()
self.test_accuracy.reset_states()
self.test_auc.reset_states()
self.test_recall.reset_states()
self.test_precision.reset_states()
self.test_true_neg.reset_states()
self.test_true_pos.reset_states()
self.test_false_neg.reset_states()
self.test_false_pos.reset_states()
if NumClasses == 2:
self.test_auc.reset_states()
self.test_recall.reset_states()
self.test_precision.reset_states()
self.test_true_neg.reset_states()
self.test_true_pos.reset_states()
self.test_false_neg.reset_states()
self.test_false_pos.reset_states()
def get_metrics(self):
recall = self.test_recall.result()
......@@ -634,18 +628,20 @@ class IcingIntensityNN:
for mini_batch_test in tst_ds:
self.test_step(mini_batch_test)
f1, mcc = self.get_metrics()
if NumClasses == 2:
f1, mcc = self.get_metrics()
with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
tf.summary.scalar('f1_val', f1, step=step)
tf.summary.scalar('mcc_val', mcc, step=step)
tf.summary.scalar('num_train_steps', step, step=step)
tf.summary.scalar('num_epochs', epoch, step=step)
if NumClasses == 2:
tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
tf.summary.scalar('f1_val', f1, step=step)
tf.summary.scalar('mcc_val', mcc, step=step)
tf.summary.scalar('num_train_steps', step, step=step)
tf.summary.scalar('num_epochs', epoch, step=step)
print('****** test loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
......@@ -667,10 +663,12 @@ class IcingIntensityNN:
for mini_batch in ds:
self.test_step(mini_batch)
f1, mcc = self.get_metrics()
print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
if NumClasses == 2:
f1, mcc = self.get_metrics()
print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
else:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE: # This may not really work properly
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment