diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py index ca9792d7ea8d4dc8a1a9766a2c4feab306460233..1b6d7e06018c5c4130bab909263d7e4e5716250e 100644 --- a/modules/deeplearning/icing_fcn.py +++ b/modules/deeplearning/icing_fcn.py @@ -993,7 +993,7 @@ class IcingIntensityFCN: def restore(self, ckpt_dir): if TRACK_MOVING_AVERAGE: - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, averaged_weights=self.model.trainable_variables) + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=self.model.trainable_variables, averaged_weights=self.ema_trainable_variables) else: ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) @@ -1001,6 +1001,9 @@ class IcingIntensityFCN: ckpt.restore(ckpt_manager.latest_checkpoint) + for idx, var in enumerate(self.model.trainable_variables): + var.assign(self.ema_trainable_variables[idx]) + self.reset_test_metrics() for data0, data1, label in self.test_dataset: