Skip to content
Snippets Groups Projects
Commit 868e06cd authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 1980c56b
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,6 @@ else:
BATCH_SIZE = 128
NUM_EPOCHS = 80
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True
NOISE_TRAINING = False
......@@ -41,6 +40,8 @@ DO_AUGMENT = True
DO_SMOOTH = False
SIGMA = 1.0
DO_ZERO_OUT = False
CACHE_FILE = '/scratch/long/rink/cld_opd_abi_128x128_cache'
USE_EMA = True
# setup scaling parameters dictionary
mean_std_dct = {}
......@@ -481,12 +482,7 @@ class SRCNN:
self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)
if TRACK_MOVING_AVERAGE:
# Not sure that this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule, use_ema=USE_EMA)
self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate
......@@ -509,8 +505,6 @@ class SRCNN:
total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
self.train_loss(loss)
self.train_accuracy(labels, pred)
......@@ -575,7 +569,7 @@ class SRCNN:
step = 0
total_time = 0
best_test_loss = np.finfo(dtype=np.float).max
best_test_loss = np.finfo(dtype=np.float64).max
if EARLY_STOP:
es = EarlyStop()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment