Commit ad23361d authored by tomrink

Use tf.compat.v1.train.Saver to save EWA-averaged trainable variables
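Note: the snippet below is a hypothetical restore sketch, not part of this commit. It shows how the EWA-averaged weights written to ewa_varsdir could be loaded back before inference, assuming the model and its EMA shadow variables are rebuilt exactly as in training (so the variable names match the checkpoint) and that tf.compat.v1.train.Saver.restore accepts sess=None in eager mode, mirroring the save(None, ewa_varsdir) call added below. The build_model()/build_training() method names are placeholders for whatever constructs self.model and self.ema in this class.

import tensorflow as tf
from util.setup import ewa_varsdir

nn = IcingIntensityNN()                      # class defined in this module
nn.build_model()                             # placeholder: builds nn.model
nn.build_training()                          # placeholder: creates nn.ema when TRACK_MOVING_AVERAGE is set
nn.ema.apply(nn.model.trainable_variables)   # create the EMA shadow variables so names/slots exist
vars_ema = [nn.ema.average(v) for v in nn.model.trainable_variables]
saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema)
saver_ewa.restore(None, ewa_varsdir)         # mirrors saver_ewa.save(None, ewa_varsdir) below
# Copy the restored averages into the live weights used for prediction.
for var, avg in zip(nn.model.trainable_variables, vars_ema):
    var.assign(avg)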

parent 24b110a2
 import tensorflow as tf
 import tensorflow_addons as tfa
-from util.setup import logdir, modeldir, cachepath, now
+from util.setup import logdir, modeldir, cachepath, now, ewa_varsdir
 from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict
 from util.geos_nav import get_navigation
@@ -72,7 +72,7 @@ train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'te
 # 'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
 # ---------------------------------------------
-train_params = train_params_l1b
+train_params = train_params_l2
 # -- Zero out params (Experimentation Only) ------------
 zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
 DO_ZERO_OUT = False
@@ -652,7 +652,7 @@ class IcingIntensityNN:
         if TRACK_MOVING_AVERAGE:
             # Not really sure this works properly (from tfa)
             # optimizer = tfa.optimizers.MovingAverage(optimizer)
-            self.ema = tf.train.ExponentialMovingAverage(decay=0.999)
+            self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
         self.optimizer = optimizer
         self.initial_learning_rate = initial_learning_rate
@@ -690,9 +690,6 @@ class IcingIntensityNN:
         self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
         if TRACK_MOVING_AVERAGE:
             self.ema.apply(self.model.trainable_variables)
-            # TODO: This doesn't seem to work
-            # for var in self.model.trainable_variables:
-            #     var.assign(self.ema.average(var))
         self.train_loss(loss)
         self.train_accuracy(labels, pred)
@@ -862,12 +859,6 @@ class IcingIntensityNN:
             print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
-            if TRACK_MOVING_AVERAGE:
-                # This may not really work properly (from tfa)
-                # self.optimizer.assign_average_vars(self.model.trainable_variables)
-                for var in self.model.trainable_variables:
-                    var.assign(self.ema.average(var))
             tst_loss = self.test_loss.result().numpy()
             if tst_loss < best_test_loss:
                 best_test_loss = tst_loss
@@ -880,6 +871,14 @@ class IcingIntensityNN:
                 ckpt_manager.save()
+                if TRACK_MOVING_AVERAGE:
+                    vars_ema = []
+                    for var in self.model.trainable_variables:
+                        vars_ema.append(self.ema.average(var))
+                    saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema)
+                    saver_ewa.save(None, ewa_varsdir)
             if self.DISK_CACHE and epoch == 0:
                 f = open(cachepath, 'wb')
                 pickle.dump(self.in_mem_data_cache, f)
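For a quick sanity check of what the new Saver wrote, the checkpoint prefix can be inspected with tf.train.load_checkpoint. This is a sketch, assuming ewa_varsdir is the same prefix passed to saver_ewa.save above:

import tensorflow as tf
from util.setup import ewa_varsdir

# List the EWA-averaged variables stored at the checkpoint prefix written above.
reader = tf.train.load_checkpoint(ewa_varsdir)
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)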