Skip to content
Snippets Groups Projects
Commit 0580e39b authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e2ce003b
Branches
No related tags found
No related merge requests found
......@@ -24,8 +24,13 @@ NumFlightLevels = 5
# --- Training hyperparameters (module-level constants) ---
BATCH_SIZE = 128
NUM_EPOCHS = 60
# Legacy flag guarding the hand-rolled tf.train.ExponentialMovingAverage
# path used elsewhere in this file; this commit appears to replace that
# path with the optimizer-level EMA options below — verify no callers
# still rely on it before removing.
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True
# EMA settings forwarded to tf.keras.optimizers.Adam (use_ema /
# ema_momentum / ema_overwrite_frequency).
USE_EMA = False
EMA_OVERWRITE_FREQUENCY = 5
EMA_MOMENTUM = 0.99
# Adam first/second-moment decay rates (these match the Keras defaults).
BETA_1 = 0.9
BETA_2 = 0.999
# Model-variant switches; their use is outside this view — confirm against
# the rest of the file.
TRIPLET = False
CONV3D = False
......@@ -739,14 +744,11 @@ class IcingIntensityFCN:
# NOTE(review): this span is a scraped diff hunk, not contiguous source —
# it interleaves the REMOVED ExponentialMovingAverage code with the ADDED
# Adam call, and original indentation was lost in extraction.
# Exponentially-decayed LR schedule fed to Adam below.
self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
# Old (pre-commit) optimizer construction and manual EMA tracking:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE:
# Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.ema.apply(self.model.trainable_variables)
self.ema_trainable_variables = [self.ema.average(var) for var in self.model.trainable_variables]
# New (post-commit) optimizer: EMA handled by the optimizer itself.
# NOTE(review): `ma_overwrite_frequency` is almost certainly a typo for
# `ema_overwrite_frequency` — tf.keras.optimizers.Adam has no
# `ma_overwrite_frequency` keyword, so this call will raise TypeError
# at construction. Fix the keyword name.
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule,
beta_1=BETA_1, beta_2=BETA_2,
use_ema=USE_EMA,
ema_momentum=EMA_MOMENTUM,
ma_overwrite_frequency=EMA_OVERWRITE_FREQUENCY)
self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate
......@@ -784,9 +786,6 @@ class IcingIntensityFCN:
# NOTE(review): diff-hunk fragment of the training step (enclosing method
# header is outside this view; indentation lost in extraction).
# Backprop and apply one optimizer step.
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
# REMOVED by this commit: manual EMA update after each step (now expected
# to be handled by the optimizer's use_ema option).
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
# Accumulate running loss/accuracy metrics for this epoch.
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -859,17 +858,9 @@ class IcingIntensityFCN:
# NOTE(review): diff-hunk fragment — the body of do_training continues past
# this view, and the TRACK_MOVING_AVERAGE lines below are the REMOVED side
# of the diff; indentation was lost in extraction.
def do_training(self, ckpt_dir=None):
"""Run training, checkpointing to `modeldir` (or resuming from ckpt_dir)."""
model_weights = self.model.trainable_variables
ema_model_weights = None
# REMOVED by this commit: snapshotting of the hand-tracked EMA weights.
if TRACK_MOVING_AVERAGE:
model_weights = self.model.trainable_variables
ema_model_weights = self.ema_trainable_variables
if ckpt_dir is None:
# Fresh run: create the module-level output directory if needed.
# (modeldir is defined elsewhere in this file — not visible here.)
if not os.path.exists(modeldir):
os.mkdir(modeldir)
# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=model_weights, averaged_weights=ema_model_weights)
# ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
# Checkpoint only the model itself; keep the 3 most recent.
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
else:
......@@ -1022,19 +1013,12 @@ class IcingIntensityFCN:
# NOTE(review): diff-hunk fragment — restore() continues past this view.
# Lines 2-5 below (the if/else) are the REMOVED side of the diff and the
# unconditional Checkpoint line is the ADDED replacement, so the same
# statement appears twice; indentation was lost in extraction.
def restore(self, ckpt_dir):
"""Restore the latest checkpoint from ckpt_dir and evaluate on the test set."""
if TRACK_MOVING_AVERAGE:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model, model_weights=self.model.trainable_variables, averaged_weights=self.ema_trainable_variables)
else:
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
# REMOVED by this commit: overwrite live weights with the tracked EMA
# averages after restoring.
if TRACK_MOVING_AVERAGE:
for idx, var in enumerate(self.model.trainable_variables):
var.assign(self.ema_trainable_variables[idx])
self.reset_test_metrics()
for data0, data1, label in self.test_dataset:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment