Skip to content
Snippets Groups Projects
Commit 868e06cd authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 1980c56b
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,6 @@ else: ...@@ -31,7 +31,6 @@ else:
BATCH_SIZE = 128 BATCH_SIZE = 128
NUM_EPOCHS = 80 NUM_EPOCHS = 80
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True EARLY_STOP = True
NOISE_TRAINING = False NOISE_TRAINING = False
...@@ -41,6 +40,8 @@ DO_AUGMENT = True ...@@ -41,6 +40,8 @@ DO_AUGMENT = True
DO_SMOOTH = False DO_SMOOTH = False
SIGMA = 1.0 SIGMA = 1.0
DO_ZERO_OUT = False DO_ZERO_OUT = False
CACHE_FILE = '/scratch/long/rink/cld_opd_abi_128x128_cache'
USE_EMA = True
# setup scaling parameters dictionary # setup scaling parameters dictionary
mean_std_dct = {} mean_std_dct = {}
...@@ -481,12 +482,7 @@ class SRCNN: ...@@ -481,12 +482,7 @@ class SRCNN:
self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate) self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule) optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule, use_ema=USE_EMA)
if TRACK_MOVING_AVERAGE:
# Not sure that this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.optimizer = optimizer self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
...@@ -509,8 +505,6 @@ class SRCNN: ...@@ -509,8 +505,6 @@ class SRCNN:
total_loss = loss + reg_loss total_loss = loss + reg_loss
gradients = tape.gradient(total_loss, self.model.trainable_variables) gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables)
self.train_loss(loss) self.train_loss(loss)
self.train_accuracy(labels, pred) self.train_accuracy(labels, pred)
...@@ -575,7 +569,7 @@ class SRCNN: ...@@ -575,7 +569,7 @@ class SRCNN:
step = 0 step = 0
total_time = 0 total_time = 0
best_test_loss = np.finfo(dtype=np.float).max best_test_loss = np.finfo(dtype=np.float64).max
if EARLY_STOP: if EARLY_STOP:
es = EarlyStop() es = EarlyStop()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment