diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 0cb69040534128416889088a087f0e7b23461bf0..d8cc5523ff3cd46b9f6162929da23c8b90208a2a 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -690,6 +690,8 @@ class IcingIntensityNN: self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) if TRACK_MOVING_AVERAGE: self.ema.apply(self.model.trainable_variables) + for var in self.model.trainable_variables: + var.assign(self.ema.average(var)) self.train_loss(loss) self.train_accuracy(labels, pred) @@ -871,14 +873,6 @@ class IcingIntensityNN: ckpt_manager.save() - if TRACK_MOVING_AVERAGE: - vars_ema = [] - for var in self.model.trainable_variables: - vars_ema.append(self.ema.average(var)) - - saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema) - saver_ewa.save(None, ewa_varsdir) - if self.DISK_CACHE and epoch == 0: f = open(cachepath, 'wb') pickle.dump(self.in_mem_data_cache, f) @@ -915,17 +909,13 @@ class IcingIntensityNN: self.build_dnn(flat) self.model = tf.keras.Model(self.inputs, self.logits) - def restore(self, ckpt_dir, varsdir=None): + def restore(self, ckpt_dir): ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt.restore(ckpt_manager.latest_checkpoint) - if TRACK_MOVING_AVERAGE: - savr = tf.compat.v1.train.Saver(self.model.trainable_variables) - savr.restore(None, varsdir) - self.test_loss.reset_states() self.test_accuracy.reset_states()