Skip to content
Snippets Groups Projects
Commit f3f7f63e authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 3c2f1b33
Branches
No related tags found
No related merge requests found
...@@ -256,6 +256,8 @@ class IcingIntensityFCN: ...@@ -256,6 +256,8 @@ class IcingIntensityFCN:
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT) tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
self.ema_trainable_variables = None
# Doesn't seem to play well with SLURM # Doesn't seem to play well with SLURM
# gpus = tf.config.experimental.list_physical_devices('GPU') # gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus: # if gpus:
...@@ -716,6 +718,8 @@ class IcingIntensityFCN: ...@@ -716,6 +718,8 @@ class IcingIntensityFCN:
# Not really sure this works properly (from tfa) # Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer) # optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999) self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.ema.apply(self.model.trainable_variables)
self.ema_trainable_variables = self.ema.average(self.model.trainable_variables)
self.optimizer = optimizer self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
...@@ -752,6 +756,7 @@ class IcingIntensityFCN: ...@@ -752,6 +756,7 @@ class IcingIntensityFCN:
total_loss = loss + reg_loss total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables) gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE: if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables) self.ema.apply(self.model.trainable_variables)
...@@ -827,10 +832,16 @@ class IcingIntensityFCN: ...@@ -827,10 +832,16 @@ class IcingIntensityFCN:
def do_training(self, ckpt_dir=None): def do_training(self, ckpt_dir=None):
model_weights = self.model.trainable_variables
ema_model_weights = None
if TRACK_MOVING_AVERAGE:
model_weights = self.model.trainable_variables
ema_model_weights = self.ema_trainable_variables
if ckpt_dir is None: if ckpt_dir is None:
if not os.path.exists(modeldir): if not os.path.exists(modeldir):
os.mkdir(modeldir) os.mkdir(modeldir)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3) ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
else: else:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment