From e22a3abbe79f0a9adb9bf1965b5a58212bc9d3f1 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 21 Oct 2021 14:56:55 -0500
Subject: [PATCH] employ ExponentialMovingAverage in training

---
 modules/deeplearning/icing_cnn.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index c1858ab6..e13d3cf4 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -178,6 +178,7 @@ class IcingIntensityNN:
 
         self.model = None
         self.optimizer = None
+        self.ema = None
         self.train_loss = None
         self.train_accuracy = None
         self.test_loss = None
@@ -640,8 +641,10 @@ class IcingIntensityNN:
 
         optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
 
-        if TRACK_MOVING_AVERAGE:  # Not really sure this works properly
-            optimizer = tfa.optimizers.MovingAverage(optimizer)
+        if TRACK_MOVING_AVERAGE:
+            # Not really sure this works properly (from tfa)
+            # optimizer = tfa.optimizers.MovingAverage(optimizer)
+            self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -677,6 +680,10 @@ class IcingIntensityNN:
                 total_loss = loss + reg_loss
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        if TRACK_MOVING_AVERAGE:
+            self.ema.apply(self.model.trainable_variables)
+            # for var in self.model.trainable_variables:
+            #     var.assign(self.ema.average(var))
 
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -846,8 +853,11 @@ class IcingIntensityNN:
                 print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
 
-            if TRACK_MOVING_AVERAGE:  # This may not really work properly
-                self.optimizer.assign_average_vars(self.model.trainable_variables)
+            if TRACK_MOVING_AVERAGE:
+                # This may not really work properly (from tfa)
+                # self.optimizer.assign_average_vars(self.model.trainable_variables)
+                for var in self.model.trainable_variables:
+                    var.assign(self.ema.average(var))
 
             tst_loss = self.test_loss.result().numpy()
             if tst_loss < best_test_loss:
-- 
GitLab