diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 2ee9f47d8fa1a74e578ecbb73c25576644a612c7..78e551d7ddb6722174ee06b9f0bce8dc1a851f8f 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -178,6 +178,7 @@ class IcingIntensityNN:
 
         self.model = None
         self.optimizer = None
+        self.ema = None
         self.train_loss = None
         self.train_accuracy = None
         self.test_loss = None
@@ -647,8 +648,10 @@ class IcingIntensityNN:
 
         optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
 
-        if TRACK_MOVING_AVERAGE:  # Not really sure this works properly
-            optimizer = tfa.optimizers.MovingAverage(optimizer)
+        if TRACK_MOVING_AVERAGE:
+            # tfa MovingAverage optimizer wrapper was unreliable; use tf.train.ExponentialMovingAverage instead
+            # optimizer = tfa.optimizers.MovingAverage(optimizer)
+            self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -684,6 +687,11 @@ class IcingIntensityNN:
                 total_loss = loss + reg_loss
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        if TRACK_MOVING_AVERAGE:
+            self.ema.apply(self.model.trainable_variables)
+            # TODO: do NOT enable below — assigning averages back every step would clobber the raw weights mid-training
+            # for var in self.model.trainable_variables:
+            #     var.assign(self.ema.average(var))
 
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -853,8 +861,11 @@ class IcingIntensityNN:
                 print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
 
-            if TRACK_MOVING_AVERAGE:  # This may not really work properly
-                self.optimizer.assign_average_vars(self.model.trainable_variables)
+            if TRACK_MOVING_AVERAGE:
+                # NOTE(review): this permanently overwrites the raw weights with EMA averages (no backup/restore) — confirm training does not continue afterward
+                # self.optimizer.assign_average_vars(self.model.trainable_variables)
+                for var in self.model.trainable_variables:
+                    var.assign(self.ema.average(var))
 
             tst_loss = self.test_loss.result().numpy()
             if tst_loss < best_test_loss: