From 383f473b6543ff773347390a3d25f7a0a6e502cb Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 24 May 2021 14:57:35 -0500 Subject: [PATCH] minor... --- modules/deeplearning/icing_cnn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index a5ce63e6..e20b7527 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -19,9 +19,9 @@ PROC_BATCH_BUFFER_SIZE = 50000 NumClasses = 2 NumLogits = 1 BATCH_SIZE = 256 -NUM_EPOCHS = 200 +NUM_EPOCHS = 100 -TRACK_MOVING_AVERAGE = True +TRACK_MOVING_AVERAGE = False EARLY_STOP = True TRIPLET = False @@ -464,7 +464,7 @@ class IcingIntensityNN: optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule) - if TRACK_MOVING_AVERAGE: + if TRACK_MOVING_AVERAGE: # Not really sure this works properly optimizer = tfa.optimizers.MovingAverage(optimizer) self.optimizer = optimizer @@ -663,6 +663,9 @@ class IcingIntensityNN: print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy()) print('------------------------------------------------------') + + if TRACK_MOVING_AVERAGE: # This may not really work properly + self.optimizer.assign_average_vars(self.model.trainable_variables) ckpt_manager.save() if self.DISK_CACHE and epoch == 0: -- GitLab