diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 0cb69040534128416889088a087f0e7b23461bf0..d8cc5523ff3cd46b9f6162929da23c8b90208a2a 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -690,6 +690,10 @@ class IcingIntensityNN:
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
         if TRACK_MOVING_AVERAGE:
             self.ema.apply(self.model.trainable_variables)
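+            # Copy the updated moving averages back into the trainable
+            # variables so the model and its checkpoints carry the EMA weights.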
+            for var in self.model.trainable_variables:
+                var.assign(self.ema.average(var))
 
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -871,14 +873,6 @@ class IcingIntensityNN:
 
                 ckpt_manager.save()
 
-                if TRACK_MOVING_AVERAGE:
-                    vars_ema = []
-                    for var in self.model.trainable_variables:
-                        vars_ema.append(self.ema.average(var))
-
-                    saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema)
-                    saver_ewa.save(None, ewa_varsdir)
-
             if self.DISK_CACHE and epoch == 0:
                 f = open(cachepath, 'wb')
                 pickle.dump(self.in_mem_data_cache, f)
@@ -915,17 +909,15 @@ class IcingIntensityNN:
         self.build_dnn(flat)
         self.model = tf.keras.Model(self.inputs, self.logits)
 
-    def restore(self, ckpt_dir, varsdir=None):
+    def restore(self, ckpt_dir):
 
         ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
         ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
 
         ckpt.restore(ckpt_manager.latest_checkpoint)
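+        # EMA-averaged weights are assigned to the trainable variables during
+        # training, so the checkpoint restored above already contains them.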
 
-        if TRACK_MOVING_AVERAGE:
-            savr = tf.compat.v1.train.Saver(self.model.trainable_variables)
-            savr.restore(None, varsdir)
-
         self.test_loss.reset_states()
         self.test_accuracy.reset_states()