Skip to content
Snippets Groups Projects
Commit 48c567e2 authored by tomrink's avatar tomrink
Browse files

Merge remote-tracking branch 'origin/master' into use_flight_altitude

parents 2588ad43 cb090084
Branches
No related tags found
No related merge requests found
......@@ -178,6 +178,7 @@ class IcingIntensityNN:
self.model = None
self.optimizer = None
self.ema = None
self.train_loss = None
self.train_accuracy = None
self.test_loss = None
......@@ -647,8 +648,10 @@ class IcingIntensityNN:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE: # Not really sure this works properly
optimizer = tfa.optimizers.MovingAverage(optimizer)
if TRACK_MOVING_AVERAGE:
# Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate
......@@ -684,6 +687,11 @@ class IcingIntensityNN:
total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
# TODO: This doesn't seem to work
# for var in self.model.trainable_variables:
# var.assign(self.ema.average(var))
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -853,8 +861,11 @@ class IcingIntensityNN:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE: # This may not really work properly
self.optimizer.assign_average_vars(self.model.trainable_variables)
if TRACK_MOVING_AVERAGE:
# This may not really work properly (from tfa)
# self.optimizer.assign_average_vars(self.model.trainable_variables)
for var in self.model.trainable_variables:
var.assign(self.ema.average(var))
tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment