From 9477cc98cbfa3b8b2b9d1be23e1952c3a1b97818 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 14 Mar 2023 19:03:08 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cnn_cld_frac_mod_res.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py
index 87ee61a2..ed8a26ae 100644
--- a/modules/deeplearning/cnn_cld_frac_mod_res.py
+++ b/modules/deeplearning/cnn_cld_frac_mod_res.py
@@ -602,7 +602,7 @@ class SRCNN:
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
     def train_step(self, inputs, labels):
-        labels = tf.squeeze(labels)
+        labels = tf.squeeze(labels, axis=[3])
         with tf.GradientTape() as tape:
             pred = self.model([inputs], training=True)
             loss = self.loss(labels, pred)
@@ -622,7 +622,7 @@ class SRCNN:
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
     def test_step(self, inputs, labels):
-        labels = tf.squeeze(labels)
+        labels = tf.squeeze(labels, axis=[3])
         pred = self.model([inputs], training=False)
         t_loss = self.loss(labels, pred)
 
-- 
GitLab