Skip to content
Snippets Groups Projects
Commit 67b23074 authored by tomrink's avatar tomrink
Browse files

more work on saving/restoring ewa trained weights

parent 07716af6
No related branches found
No related tags found
No related merge requests found
......@@ -690,6 +690,8 @@ class IcingIntensityNN:
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
for var in self.model.trainable_variables:
var.assign(self.ema.average(var))
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -871,14 +873,6 @@ class IcingIntensityNN:
ckpt_manager.save()
if TRACK_MOVING_AVERAGE:
vars_ema = []
for var in self.model.trainable_variables:
vars_ema.append(self.ema.average(var))
saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema)
saver_ewa.save(None, ewa_varsdir)
if self.DISK_CACHE and epoch == 0:
f = open(cachepath, 'wb')
pickle.dump(self.in_mem_data_cache, f)
......@@ -915,17 +909,13 @@ class IcingIntensityNN:
self.build_dnn(flat)
self.model = tf.keras.Model(self.inputs, self.logits)
def restore(self, ckpt_dir, varsdir=None):
def restore(self, ckpt_dir):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
if TRACK_MOVING_AVERAGE:
savr = tf.compat.v1.train.Saver(self.model.trainable_variables)
savr.restore(None, varsdir)
self.test_loss.reset_states()
self.test_accuracy.reset_states()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment