From 92bddedbce468382620a4d47af0d36f4a9287622 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sat, 4 Mar 2023 11:31:49 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cnn_cld_frac_mod_res.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py
index 6ed2173a..8cd62688 100644
--- a/modules/deeplearning/cnn_cld_frac_mod_res.py
+++ b/modules/deeplearning/cnn_cld_frac_mod_res.py
@@ -602,10 +602,10 @@ class SRCNN:
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
 
-    def predict(self, mini_batch):
-        inputs = [mini_batch[0]]
-        labels = mini_batch[1]
-        pred = self.model(inputs, training=False)
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
+    def predict(self, inputs, labels):
+        labels = tf.squeeze(labels)
+        pred = self.model([inputs], training=False)
         t_loss = self.loss(labels, pred)
 
         self.test_labels.append(labels)
@@ -746,7 +746,7 @@ class SRCNN:
             ds = tf.data.Dataset.from_tensor_slices((data, label))
             ds = ds.batch(BATCH_SIZE)
             for mini_batch_test in ds:
-                self.predict(mini_batch_test)
+                self.predict(mini_batch_test[0], mini_batch_test[1])
 
         print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
 
@@ -754,10 +754,6 @@ class SRCNN:
         preds = np.concatenate(self.test_preds)
         print(labels.shape, preds.shape)
 
-        # if label_param != 'cloud_probability':
-        #     labels_denorm = denormalize(labels, label_param, mean_std_dct)
-        #     preds_denorm = denormalize(preds, label_param, mean_std_dct)
-
         return labels, preds
 
     def do_evaluate(self, data, ckpt_dir):
-- 
GitLab