diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py index 0905a2e89ac82f2141b0b97f5c26cba45bf6e57a..5052bdf74d14b59807c105be8555d9e1975c9b43 100644 --- a/modules/deeplearning/icing_fcn.py +++ b/modules/deeplearning/icing_fcn.py @@ -841,11 +841,14 @@ class IcingIntensityFCN: if ckpt_dir is None: if not os.path.exists(modeldir): os.mkdir(modeldir) - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights) + # ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights) + # ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3) + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3) else: ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) self.writer_train = tf.summary.create_file_writer(os.path.join(logdir, 'plot_train')) self.writer_valid = tf.summary.create_file_writer(os.path.join(logdir, 'plot_valid'))