From 383f473b6543ff773347390a3d25f7a0a6e502cb Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 24 May 2021 14:57:35 -0500
Subject: [PATCH] minor...

---
 modules/deeplearning/icing_cnn.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index a5ce63e6..e20b7527 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -19,9 +19,9 @@ PROC_BATCH_BUFFER_SIZE = 50000
 NumClasses = 2
 NumLogits = 1
 BATCH_SIZE = 256
-NUM_EPOCHS = 200
+NUM_EPOCHS = 100
 
-TRACK_MOVING_AVERAGE = True
+TRACK_MOVING_AVERAGE = False
 EARLY_STOP = True
 
 TRIPLET = False
@@ -464,7 +464,7 @@ class IcingIntensityNN:
 
         optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
 
-        if TRACK_MOVING_AVERAGE:
+        if TRACK_MOVING_AVERAGE:  # Not really sure this works properly
             optimizer = tfa.optimizers.MovingAverage(optimizer)
 
         self.optimizer = optimizer
@@ -663,6 +663,9 @@ class IcingIntensityNN:
             print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
                   self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
             print('------------------------------------------------------')
+
+            if TRACK_MOVING_AVERAGE:  # This may not really work properly
+                self.optimizer.assign_average_vars(self.model.trainable_variables)
             ckpt_manager.save()
 
             if self.DISK_CACHE and epoch == 0:
-- 
GitLab