From 24e8c721182099ff2031a1c34bf318c88cd6b075 Mon Sep 17 00:00:00 2001 From: rink <rink@ssec.wisc.edu> Date: Wed, 14 Oct 2020 17:10:27 -0500 Subject: [PATCH] minor --- modules/deeplearning/cloudheight.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index fe7c878b..0210c863 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -637,7 +637,7 @@ class CloudHeightNN: self.variable_averages = tf.train.ExponentialMovingAverage(0.999, self.global_step) self.variable_averages.apply(self.model.trainable_variables) - # @tf.function + @tf.function def train_step(self, mini_batch): inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] labels = mini_batch[2] @@ -656,7 +656,7 @@ class CloudHeightNN: return loss - # @tf.function + @tf.function def test_step(self, mini_batch): inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] labels = mini_batch[2] @@ -806,7 +806,7 @@ class CloudHeightNN: ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: - self.test_step(mini_batch_test) + self.predict(mini_batch_test) print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result()) print('acc_0', self.num_0, self.accuracy_0.result()) print('acc_1', self.num_1, self.accuracy_1.result()) -- GitLab