Skip to content
Snippets Groups Projects
Commit e22a3abb authored by tomrink's avatar tomrink
Browse files

employ ExponentialWeightedAverage in training

parent 2f80cb49
No related branches found
No related tags found
No related merge requests found
......@@ -178,6 +178,7 @@ class IcingIntensityNN:
self.model = None
self.optimizer = None
self.ema = None
self.train_loss = None
self.train_accuracy = None
self.test_loss = None
......@@ -640,8 +641,10 @@ class IcingIntensityNN:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE: # Not really sure this works properly
optimizer = tfa.optimizers.MovingAverage(optimizer)
if TRACK_MOVING_AVERAGE:
# Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate
......@@ -677,6 +680,10 @@ class IcingIntensityNN:
total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
# for var in self.model.trainable_variables:
# var.assign(self.ema.average(var))
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -846,8 +853,11 @@ class IcingIntensityNN:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE: # This may not really work properly
self.optimizer.assign_average_vars(self.model.trainable_variables)
if TRACK_MOVING_AVERAGE:
# This may not really work properly (from tfa)
# self.optimizer.assign_average_vars(self.model.trainable_variables)
for var in self.model.trainable_variables:
var.assign(self.ema.average(var))
tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment