diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index 8cd6268877343be4d362e85254517fc7aed9f023..eb12f2313789f3eccd074f0d6ea79fb03bf49ffc 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -602,11 +602,11 @@ class SRCNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) - @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) + # @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) def predict(self, inputs, labels): labels = tf.squeeze(labels) pred = self.model([inputs], training=False) - t_loss = self.loss(labels, pred) + t_loss = self.loss(tf.squeeze(labels), pred) self.test_labels.append(labels) self.test_preds.append(pred.numpy())