Skip to content
Snippets Groups Projects
Commit 0580e39b authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e2ce003b
No related branches found
No related tags found
No related merge requests found
...@@ -24,8 +24,13 @@ NumFlightLevels = 5 ...@@ -24,8 +24,13 @@ NumFlightLevels = 5
BATCH_SIZE = 128 BATCH_SIZE = 128
NUM_EPOCHS = 60 NUM_EPOCHS = 60
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True EARLY_STOP = True
USE_EMA = False
EMA_OVERWRITE_FREQUENCY = 5
EMA_MOMENTUM = 0.99
BETA_1 = 0.9
BETA_2 = 0.999
TRIPLET = False TRIPLET = False
CONV3D = False CONV3D = False
...@@ -739,14 +744,11 @@ class IcingIntensityFCN: ...@@ -739,14 +744,11 @@ class IcingIntensityFCN:
self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate) self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule) optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule,
beta_1=BETA_1, beta_2=BETA_2,
if TRACK_MOVING_AVERAGE: use_ema=USE_EMA,
# Not really sure this works properly (from tfa) ema_momentum=EMA_MOMENTUM,
# optimizer = tfa.optimizers.MovingAverage(optimizer) ema_overwrite_frequency=EMA_OVERWRITE_FREQUENCY)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.ema.apply(self.model.trainable_variables)
self.ema_trainable_variables = [self.ema.average(var) for var in self.model.trainable_variables]
self.optimizer = optimizer self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
...@@ -784,9 +786,6 @@ class IcingIntensityFCN: ...@@ -784,9 +786,6 @@ class IcingIntensityFCN:
gradients = tape.gradient(total_loss, self.model.trainable_variables) gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
self.train_loss(loss) self.train_loss(loss)
self.train_accuracy(labels, pred) self.train_accuracy(labels, pred)
...@@ -859,17 +858,9 @@ class IcingIntensityFCN: ...@@ -859,17 +858,9 @@ class IcingIntensityFCN:
def do_training(self, ckpt_dir=None): def do_training(self, ckpt_dir=None):
model_weights = self.model.trainable_variables
ema_model_weights = None
if TRACK_MOVING_AVERAGE:
model_weights = self.model.trainable_variables
ema_model_weights = self.ema_trainable_variables
if ckpt_dir is None: if ckpt_dir is None:
if not os.path.exists(modeldir): if not os.path.exists(modeldir):
os.mkdir(modeldir) os.mkdir(modeldir)
# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
# ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3) ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
else: else:
...@@ -1022,19 +1013,12 @@ class IcingIntensityFCN: ...@@ -1022,19 +1013,12 @@ class IcingIntensityFCN:
def restore(self, ckpt_dir): def restore(self, ckpt_dir):
if TRACK_MOVING_AVERAGE: ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=self.model.trainable_variables, averaged_weights=self.ema_trainable_variables)
else:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint) ckpt.restore(ckpt_manager.latest_checkpoint)
if TRACK_MOVING_AVERAGE:
for idx, var in enumerate(self.model.trainable_variables):
var.assign(self.ema_trainable_variables[idx])
self.reset_test_metrics() self.reset_test_metrics()
for data0, data1, label in self.test_dataset: for data0, data1, label in self.test_dataset:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment