Skip to content
Snippets Groups Projects
Commit 0580e39b authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e2ce003b
No related branches found
No related tags found
No related merge requests found
......@@ -24,8 +24,13 @@ NumFlightLevels = 5
BATCH_SIZE = 128
NUM_EPOCHS = 60
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True
USE_EMA = False
EMA_OVERWRITE_FREQUENCY = 5
EMA_MOMENTUM = 0.99
BETA_1 = 0.9
BETA_2 = 0.999
TRIPLET = False
CONV3D = False
......@@ -739,14 +744,11 @@ class IcingIntensityFCN:
self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE:
# Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.ema.apply(self.model.trainable_variables)
self.ema_trainable_variables = [self.ema.average(var) for var in self.model.trainable_variables]
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule,
beta_1=BETA_1, beta_2=BETA_2,
use_ema=USE_EMA,
ema_momentum=EMA_MOMENTUM,
ma_overwrite_frequency=EMA_OVERWRITE_FREQUENCY)
self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate
......@@ -784,9 +786,6 @@ class IcingIntensityFCN:
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -859,17 +858,9 @@ class IcingIntensityFCN:
def do_training(self, ckpt_dir=None):
model_weights = self.model.trainable_variables
ema_model_weights = None
if TRACK_MOVING_AVERAGE:
model_weights = self.model.trainable_variables
ema_model_weights = self.ema_trainable_variables
if ckpt_dir is None:
if not os.path.exists(modeldir):
os.mkdir(modeldir)
# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
# ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
else:
......@@ -1022,19 +1013,12 @@ class IcingIntensityFCN:
def restore(self, ckpt_dir):
if TRACK_MOVING_AVERAGE:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=self.model.trainable_variables, averaged_weights=self.ema_trainable_variables)
else:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
if TRACK_MOVING_AVERAGE:
for idx, var in enumerate(self.model.trainable_variables):
var.assign(self.ema_trainable_variables[idx])
self.reset_test_metrics()
for data0, data1, label in self.test_dataset:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment