Skip to content
Snippets Groups Projects
Commit 48c567e2 authored by tomrink's avatar tomrink
Browse files

Merge remote-tracking branch 'origin/master' into use_flight_altitude

parents 2588ad43 cb090084
No related branches found
No related tags found
No related merge requests found
...@@ -178,6 +178,7 @@ class IcingIntensityNN: ...@@ -178,6 +178,7 @@ class IcingIntensityNN:
self.model = None self.model = None
self.optimizer = None self.optimizer = None
self.ema = None
self.train_loss = None self.train_loss = None
self.train_accuracy = None self.train_accuracy = None
self.test_loss = None self.test_loss = None
...@@ -647,8 +648,10 @@ class IcingIntensityNN: ...@@ -647,8 +648,10 @@ class IcingIntensityNN:
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule) optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE: # Not really sure this works properly if TRACK_MOVING_AVERAGE:
optimizer = tfa.optimizers.MovingAverage(optimizer) # Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
self.optimizer = optimizer self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
...@@ -684,6 +687,11 @@ class IcingIntensityNN: ...@@ -684,6 +687,11 @@ class IcingIntensityNN:
total_loss = loss + reg_loss total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables) gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
# TODO: This doesn't seem to work
# for var in self.model.trainable_variables:
# var.assign(self.ema.average(var))
self.train_loss(loss) self.train_loss(loss)
self.train_accuracy(labels, pred) self.train_accuracy(labels, pred)
...@@ -853,8 +861,11 @@ class IcingIntensityNN: ...@@ -853,8 +861,11 @@ class IcingIntensityNN:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------') print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE: # This may not really work properly if TRACK_MOVING_AVERAGE:
self.optimizer.assign_average_vars(self.model.trainable_variables) # This may not really work properly (from tfa)
# self.optimizer.assign_average_vars(self.model.trainable_variables)
for var in self.model.trainable_variables:
var.assign(self.ema.average(var))
tst_loss = self.test_loss.result().numpy() tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss: if tst_loss < best_test_loss:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment