diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py
index c8f4d4b4493d916d7ca0ff21465e4b6873324f42..db2fdc69c06af3febf63aaff616074a239ba6d5f 100644
--- a/modules/deeplearning/icing_fcn.py
+++ b/modules/deeplearning/icing_fcn.py
@@ -256,6 +256,8 @@ class IcingIntensityFCN:
 
         tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
 
+        self.ema_trainable_variables = None
+
         # Doesn't seem to play well with SLURM
         # gpus = tf.config.experimental.list_physical_devices('GPU')
         # if gpus:
@@ -716,6 +718,8 @@ class IcingIntensityFCN:
             # Not really sure this works properly (from tfa)
             # optimizer = tfa.optimizers.MovingAverage(optimizer)
             self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
+            self.ema.apply(self.model.trainable_variables)
+            self.ema_trainable_variables = self.ema.average(self.model.trainable_variables)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -752,6 +756,7 @@ class IcingIntensityFCN:
                 total_loss = loss + reg_loss
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+
         if TRACK_MOVING_AVERAGE:
             self.ema.apply(self.model.trainable_variables)
 
@@ -827,10 +832,16 @@ class IcingIntensityFCN:
 
     def do_training(self, ckpt_dir=None):
 
+        model_weights = self.model.trainable_variables
+        ema_model_weights = None
+        if TRACK_MOVING_AVERAGE:
+            model_weights = self.model.trainable_variables
+            ema_model_weights = self.ema_trainable_variables
+
         if ckpt_dir is None:
             if not os.path.exists(modeldir):
                 os.mkdir(modeldir)
-            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
+            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
             ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
         else:
             ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)