diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py
index 502dbabdf75be559ce7e5dba105bd54fb691c4c9..508c88ecf9a6d9f4d348e30d38d1ee77fdd9593c 100644
--- a/modules/deeplearning/icing_fcn.py
+++ b/modules/deeplearning/icing_fcn.py
@@ -24,8 +24,13 @@ NumFlightLevels = 5
 BATCH_SIZE = 128
 NUM_EPOCHS = 60
 
-TRACK_MOVING_AVERAGE = False
+
 EARLY_STOP = True
+USE_EMA = False
+EMA_OVERWRITE_FREQUENCY = 5
+EMA_MOMENTUM = 0.99
+BETA_1 = 0.9
+BETA_2 = 0.999
 
 TRIPLET = False
 CONV3D = False
@@ -739,14 +744,11 @@ class IcingIntensityFCN:
 
         self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
 
-        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
-
-        if TRACK_MOVING_AVERAGE:
-            # Not really sure this works properly (from tfa)
-            # optimizer = tfa.optimizers.MovingAverage(optimizer)
-            self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
-            self.ema.apply(self.model.trainable_variables)
-            self.ema_trainable_variables = [self.ema.average(var) for var in self.model.trainable_variables]
+        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule,
+                                             beta_1=BETA_1, beta_2=BETA_2,
+                                             use_ema=USE_EMA,
+                                             ema_momentum=EMA_MOMENTUM,
+                                             ma_overwrite_frequency=EMA_OVERWRITE_FREQUENCY)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -784,9 +786,6 @@ class IcingIntensityFCN:
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
 
-        if TRACK_MOVING_AVERAGE:
-            self.ema.apply(self.model.trainable_variables)
-
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
 
@@ -859,17 +858,9 @@ class IcingIntensityFCN:
 
     def do_training(self, ckpt_dir=None):
 
-        model_weights = self.model.trainable_variables
-        ema_model_weights = None
-        if TRACK_MOVING_AVERAGE:
-            model_weights = self.model.trainable_variables
-            ema_model_weights = self.ema_trainable_variables
-
         if ckpt_dir is None:
             if not os.path.exists(modeldir):
                 os.mkdir(modeldir)
-            # ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
-            # ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
             ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
             ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
         else:
@@ -1022,19 +1013,12 @@ class IcingIntensityFCN:
 
     def restore(self, ckpt_dir):
 
-        if TRACK_MOVING_AVERAGE:
-            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=self.model.trainable_variables, averaged_weights=self.ema_trainable_variables)
-        else:
-            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
+        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
 
         ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
 
         ckpt.restore(ckpt_manager.latest_checkpoint)
 
-        if TRACK_MOVING_AVERAGE:
-            for idx, var in enumerate(self.model.trainable_variables):
-                var.assign(self.ema_trainable_variables[idx])
-
         self.reset_test_metrics()
 
         for data0, data1, label in self.test_dataset: