diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 2ee9f47d8fa1a74e578ecbb73c25576644a612c7..78e551d7ddb6722174ee06b9f0bce8dc1a851f8f 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -178,6 +178,7 @@ class IcingIntensityNN:
 
         self.model = None
         self.optimizer = None
+        self.ema = None
         self.train_loss = None
         self.train_accuracy = None
         self.test_loss = None
@@ -647,8 +648,10 @@ class IcingIntensityNN:
 
         optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
 
-        if TRACK_MOVING_AVERAGE:  # Not really sure this works properly
-            optimizer = tfa.optimizers.MovingAverage(optimizer)
+        if TRACK_MOVING_AVERAGE:
+            # Not really sure this works properly (from tfa)
+            # optimizer = tfa.optimizers.MovingAverage(optimizer)
+            self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
 
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -684,6 +687,11 @@ class IcingIntensityNN:
         total_loss = loss + reg_loss
         gradients = tape.gradient(total_loss, self.model.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+        if TRACK_MOVING_AVERAGE:
+            self.ema.apply(self.model.trainable_variables)
+            # TODO: This doesn't seem to work
+            # for var in self.model.trainable_variables:
+            #     var.assign(self.ema.average(var))
 
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -853,8 +861,11 @@ class IcingIntensityNN:
             print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
 
-        if TRACK_MOVING_AVERAGE:  # This may not really work properly
-            self.optimizer.assign_average_vars(self.model.trainable_variables)
+        if TRACK_MOVING_AVERAGE:
+            # This may not really work properly (from tfa)
+            # self.optimizer.assign_average_vars(self.model.trainable_variables)
+            for var in self.model.trainable_variables:
+                var.assign(self.ema.average(var))
 
         tst_loss = self.test_loss.result().numpy()
         if tst_loss < best_test_loss: