diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index 6ed2173a82ceadc1381c6d36b87db658d18c928a..8cd6268877343be4d362e85254517fc7aed9f023 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -602,10 +602,10 @@ class SRCNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) - def predict(self, mini_batch): - inputs = [mini_batch[0]] - labels = mini_batch[1] - pred = self.model(inputs, training=False) + @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) + def predict(self, inputs, labels): + labels = tf.squeeze(labels) + pred = self.model([inputs], training=False) t_loss = self.loss(labels, pred) self.test_labels.append(labels) @@ -746,7 +746,7 @@ class SRCNN: ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: - self.predict(mini_batch_test) + self.predict(mini_batch_test[0], mini_batch_test[1]) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) @@ -754,10 +754,6 @@ class SRCNN: preds = np.concatenate(self.test_preds) print(labels.shape, preds.shape) - # if label_param != 'cloud_probability': - # labels_denorm = denormalize(labels, label_param, mean_std_dct) - # preds_denorm = denormalize(preds, label_param, mean_std_dct) - return labels, preds def do_evaluate(self, data, ckpt_dir):