From 24e8c721182099ff2031a1c34bf318c88cd6b075 Mon Sep 17 00:00:00 2001
From: rink <rink@ssec.wisc.edu>
Date: Wed, 14 Oct 2020 17:10:27 -0500
Subject: [PATCH] minor

---
 modules/deeplearning/cloudheight.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index fe7c878b..0210c863 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -637,7 +637,7 @@ class CloudHeightNN:
             self.variable_averages = tf.train.ExponentialMovingAverage(0.999, self.global_step)
             self.variable_averages.apply(self.model.trainable_variables)
 
-    # @tf.function
+    @tf.function
     def train_step(self, mini_batch):
         inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
         labels = mini_batch[2]
@@ -656,7 +656,7 @@ class CloudHeightNN:
 
         return loss
 
-    # @tf.function
+    @tf.function
     def test_step(self, mini_batch):
         inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
         labels = mini_batch[2]
@@ -806,7 +806,7 @@ class CloudHeightNN:
             ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst))
             ds = ds.batch(BATCH_SIZE)
             for mini_batch_test in ds:
-                self.test_step(mini_batch_test)
+                self.predict(mini_batch_test)
         print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
         print('acc_0', self.num_0, self.accuracy_0.result())
         print('acc_1', self.num_1, self.accuracy_1.result())
-- 
GitLab