From 868e06cd0ea57c0d16a54213435eeafe7c374822 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 31 Aug 2023 08:40:46 -0500
Subject: [PATCH] Use the Keras optimizer's built-in EMA; fix removed np.float alias
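
Replace the hand-rolled moving-average bookkeeping (TRACK_MOVING_AVERAGE
plus tf.train.ExponentialMovingAverage and the ema.apply() call in the
train step) with the use_ema flag built into Keras optimizers since
TF 2.11, add a CACHE_FILE location, and swap the np.float alias, removed
in NumPy 1.24, for np.float64.

Note that the old code used decay=0.9999 while use_ema defaults to
ema_momentum=0.99; pass ema_momentum explicitly if the old decay is
wanted. A minimal sketch of the built-in EMA flow, assuming TF >= 2.11
(the model and training loop below are placeholders, not code from this
repo):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # stand-in model
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=1e-3,
        use_ema=True,         # optimizer keeps EMA shadow copies of the weights
        ema_momentum=0.9999,  # mirrors the old ExponentialMovingAverage decay
    )
    # ... run the usual train steps via optimizer.apply_gradients(...) ...
    # then, in a custom loop, copy the averaged values back into the model:
    optimizer.finalize_variable_values(model.trainable_variables)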

---
 modules/deeplearning/cloud_opd_fcn_abi.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_fcn_abi.py b/modules/deeplearning/cloud_opd_fcn_abi.py
index bc49fe6e..fa1a5c6e 100644
--- a/modules/deeplearning/cloud_opd_fcn_abi.py
+++ b/modules/deeplearning/cloud_opd_fcn_abi.py
@@ -31,7 +31,6 @@ else:
 BATCH_SIZE = 128
 NUM_EPOCHS = 80
 
-TRACK_MOVING_AVERAGE = False
 EARLY_STOP = True
 
 NOISE_TRAINING = False
@@ -41,6 +40,8 @@ DO_AUGMENT = True
 DO_SMOOTH = False
 SIGMA = 1.0
 DO_ZERO_OUT = False
+CACHE_FILE = '/scratch/long/rink/cld_opd_abi_128x128_cache'
+USE_EMA = True  # optimizer maintains an exponential moving average of the weights
 
 # setup scaling parameters dictionary
 mean_std_dct = {}
@@ -481,12 +482,7 @@ class SRCNN:
 
         self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
 
-        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
-
-        if TRACK_MOVING_AVERAGE:
-            # Not sure that this works properly (from tfa)
-            # optimizer = tfa.optimizers.MovingAverage(optimizer)
-            self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
+        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule, use_ema=USE_EMA)  # built-in EMA (TF >= 2.11)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -509,8 +505,6 @@ class SRCNN:
                 total_loss = loss + reg_loss
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
-        if TRACK_MOVING_AVERAGE:
-            self.ema.apply(self.model.trainable_variables)
 
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -575,7 +569,7 @@ class SRCNN:
 
         step = 0
         total_time = 0
-        best_test_loss = np.finfo(dtype=np.float).max
+        best_test_loss = np.finfo(dtype=np.float64).max
 
         if EARLY_STOP:
             es = EarlyStop()
-- 
GitLab