Skip to content
Snippets Groups Projects
Commit 13d8bfd3 authored by tomrink's avatar tomrink
Browse files

snapshot... changes to more easily switch between binary and multiclass

parent 6c5f12ed
No related branches found
No related tags found
No related merge requests found
...@@ -495,13 +495,6 @@ class IcingIntensityNN: ...@@ -495,13 +495,6 @@ class IcingIntensityNN:
else: else:
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
self.test_auc = tf.keras.metrics.AUC(name='test_auc')
self.test_recall = tf.keras.metrics.Recall(name='test_recall')
self.test_precision = tf.keras.metrics.Precision(name='test_precision')
self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
@tf.function @tf.function
def train_step(self, mini_batch): def train_step(self, mini_batch):
...@@ -562,13 +555,14 @@ class IcingIntensityNN: ...@@ -562,13 +555,14 @@ class IcingIntensityNN:
def reset_test_metrics(self): def reset_test_metrics(self):
self.test_loss.reset_states() self.test_loss.reset_states()
self.test_accuracy.reset_states() self.test_accuracy.reset_states()
self.test_auc.reset_states() if NumClasses == 2:
self.test_recall.reset_states() self.test_auc.reset_states()
self.test_precision.reset_states() self.test_recall.reset_states()
self.test_true_neg.reset_states() self.test_precision.reset_states()
self.test_true_pos.reset_states() self.test_true_neg.reset_states()
self.test_false_neg.reset_states() self.test_true_pos.reset_states()
self.test_false_pos.reset_states() self.test_false_neg.reset_states()
self.test_false_pos.reset_states()
def get_metrics(self): def get_metrics(self):
recall = self.test_recall.result() recall = self.test_recall.result()
...@@ -634,18 +628,20 @@ class IcingIntensityNN: ...@@ -634,18 +628,20 @@ class IcingIntensityNN:
for mini_batch_test in tst_ds: for mini_batch_test in tst_ds:
self.test_step(mini_batch_test) self.test_step(mini_batch_test)
f1, mcc = self.get_metrics() if NumClasses == 2:
f1, mcc = self.get_metrics()
with self.writer_valid.as_default(): with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step) tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step) tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
tf.summary.scalar('auc_val', self.test_auc.result(), step=step) if NumClasses == 2:
tf.summary.scalar('recall_val', self.test_recall.result(), step=step) tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
tf.summary.scalar('prec_val', self.test_precision.result(), step=step) tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
tf.summary.scalar('f1_val', f1, step=step) tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
tf.summary.scalar('mcc_val', mcc, step=step) tf.summary.scalar('f1_val', f1, step=step)
tf.summary.scalar('num_train_steps', step, step=step) tf.summary.scalar('mcc_val', mcc, step=step)
tf.summary.scalar('num_epochs', epoch, step=step) tf.summary.scalar('num_train_steps', step, step=step)
tf.summary.scalar('num_epochs', epoch, step=step)
print('****** test loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('****** test loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
...@@ -667,10 +663,12 @@ class IcingIntensityNN: ...@@ -667,10 +663,12 @@ class IcingIntensityNN:
for mini_batch in ds: for mini_batch in ds:
self.test_step(mini_batch) self.test_step(mini_batch)
f1, mcc = self.get_metrics() if NumClasses == 2:
f1, mcc = self.get_metrics()
print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy()) self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
else:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------') print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE: # This may not really work properly if TRACK_MOVING_AVERAGE: # This may not really work properly
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment