diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 606f881ecfcf39919040cf54f1b3a50935b0de29..0cb69040534128416889088a087f0e7b23461bf0 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -915,13 +915,17 @@ class IcingIntensityNN: self.build_dnn(flat) self.model = tf.keras.Model(self.inputs, self.logits) - def restore(self, ckpt_dir): + def restore(self, ckpt_dir, varsdir=None): ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt.restore(ckpt_manager.latest_checkpoint) + if TRACK_MOVING_AVERAGE: + savr = tf.compat.v1.train.Saver(self.model.trainable_variables) + savr.restore(None, varsdir) + self.test_loss.reset_states() self.test_accuracy.reset_states()