From 92bddedbce468382620a4d47af0d36f4a9287622 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Sat, 4 Mar 2023 11:31:49 -0600 Subject: [PATCH] snapshot... --- modules/deeplearning/cnn_cld_frac_mod_res.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index 6ed2173a..8cd62688 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -602,10 +602,10 @@ class SRCNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) - def predict(self, mini_batch): - inputs = [mini_batch[0]] - labels = mini_batch[1] - pred = self.model(inputs, training=False) + @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) + def predict(self, inputs, labels): + labels = tf.squeeze(labels) + pred = self.model([inputs], training=False) t_loss = self.loss(labels, pred) self.test_labels.append(labels) @@ -746,7 +746,7 @@ class SRCNN: ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: - self.predict(mini_batch_test) + self.predict(mini_batch_test[0], mini_batch_test[1]) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) @@ -754,10 +754,6 @@ class SRCNN: preds = np.concatenate(self.test_preds) print(labels.shape, preds.shape) - # if label_param != 'cloud_probability': - # labels_denorm = denormalize(labels, label_param, mean_std_dct) - # preds_denorm = denormalize(preds, label_param, mean_std_dct) - return labels, preds def do_evaluate(self, data, ckpt_dir): -- GitLab