diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index 87ee61a204082929fff73f8bed0b103c32c3c3ef..ed8a26ae38e865dbe6e0c875841ff2408bc58820 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -602,7 +602,7 @@ class SRCNN: @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) def train_step(self, inputs, labels): - labels = tf.squeeze(labels) + labels = tf.squeeze(labels, axis=[3]) with tf.GradientTape() as tape: pred = self.model([inputs], training=True) loss = self.loss(labels, pred) @@ -622,7 +622,7 @@ class SRCNN: @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) def test_step(self, inputs, labels): - labels = tf.squeeze(labels) + labels = tf.squeeze(labels, axis=[3]) pred = self.model([inputs], training=False) t_loss = self.loss(labels, pred)