Skip to content
Snippets Groups Projects
Commit ad23361d authored by tomrink's avatar tomrink
Browse files

Use tf.compat.v1.train.Saver to save EWAaveraged trainable variables

parent 24b110a2
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf import tensorflow as tf
import tensorflow_addons as tfa import tensorflow_addons as tfa
from util.setup import logdir, modeldir, cachepath, now from util.setup import logdir, modeldir, cachepath, now, ewa_varsdir
from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict
from util.geos_nav import get_navigation from util.geos_nav import get_navigation
...@@ -72,7 +72,7 @@ train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'te ...@@ -72,7 +72,7 @@ train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'te
# 'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp'] # 'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
# --------------------------------------------- # ---------------------------------------------
train_params = train_params_l1b train_params = train_params_l2
# -- Zero out params (Experimentation Only) ------------ # -- Zero out params (Experimentation Only) ------------
zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp'] zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
DO_ZERO_OUT = False DO_ZERO_OUT = False
...@@ -652,7 +652,7 @@ class IcingIntensityNN: ...@@ -652,7 +652,7 @@ class IcingIntensityNN:
if TRACK_MOVING_AVERAGE: if TRACK_MOVING_AVERAGE:
# Not really sure this works properly (from tfa) # Not really sure this works properly (from tfa)
# optimizer = tfa.optimizers.MovingAverage(optimizer) # optimizer = tfa.optimizers.MovingAverage(optimizer)
self.ema = tf.train.ExponentialMovingAverage(decay=0.999) self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.optimizer = optimizer self.optimizer = optimizer
self.initial_learning_rate = initial_learning_rate self.initial_learning_rate = initial_learning_rate
...@@ -690,9 +690,6 @@ class IcingIntensityNN: ...@@ -690,9 +690,6 @@ class IcingIntensityNN:
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
if TRACK_MOVING_AVERAGE: if TRACK_MOVING_AVERAGE:
self.ema.apply(self.model.trainable_variables) self.ema.apply(self.model.trainable_variables)
# TODO: This doesn't seem to work
# for var in self.model.trainable_variables:
# var.assign(self.ema.average(var))
self.train_loss(loss) self.train_loss(loss)
self.train_accuracy(labels, pred) self.train_accuracy(labels, pred)
...@@ -862,12 +859,6 @@ class IcingIntensityNN: ...@@ -862,12 +859,6 @@ class IcingIntensityNN:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------') print('------------------------------------------------------')
if TRACK_MOVING_AVERAGE:
# This may not really work properly (from tfa)
# self.optimizer.assign_average_vars(self.model.trainable_variables)
for var in self.model.trainable_variables:
var.assign(self.ema.average(var))
tst_loss = self.test_loss.result().numpy() tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss: if tst_loss < best_test_loss:
best_test_loss = tst_loss best_test_loss = tst_loss
...@@ -880,6 +871,14 @@ class IcingIntensityNN: ...@@ -880,6 +871,14 @@ class IcingIntensityNN:
ckpt_manager.save() ckpt_manager.save()
if TRACK_MOVING_AVERAGE:
vars_ema = []
for var in self.model.trainable_variables:
vars_ema.append(self.ema.average(var))
saver_ewa = tf.compat.v1.train.Saver(var_list=vars_ema)
saver_ewa.save(None, ewa_varsdir)
if self.DISK_CACHE and epoch == 0: if self.DISK_CACHE and epoch == 0:
f = open(cachepath, 'wb') f = open(cachepath, 'wb')
pickle.dump(self.in_mem_data_cache, f) pickle.dump(self.in_mem_data_cache, f)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment