From fed205e9b63b94ab13d5a9ae6d25e1d50d50dc8b Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 20 Mar 2023 14:18:11 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cnn_cld_frac.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py
index a2f9f831..2ad505f9 100644
--- a/modules/deeplearning/cnn_cld_frac.py
+++ b/modules/deeplearning/cnn_cld_frac.py
@@ -736,7 +736,7 @@ class SRCNN:
                 trn_ds = trn_ds.batch(BATCH_SIZE)
                 for mini_batch in trn_ds:
                     if self.learningRateSchedule is not None:
-                        loss = self.train_step(mini_batch)
+                        loss = self.train_step(mini_batch[0], mini_batch[1])
 
                     if (step % 100) == 0:
 
@@ -751,7 +751,7 @@ class SRCNN:
                             tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
                             tst_ds = tst_ds.batch(BATCH_SIZE)
                             for mini_batch_test in tst_ds:
-                                self.test_step(mini_batch_test)
+                                self.test_step(mini_batch_test[0], mini_batch_test[1])
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
@@ -780,7 +780,7 @@ class SRCNN:
                 ds = tf.data.Dataset.from_tensor_slices((data, label))
                 ds = ds.batch(BATCH_SIZE)
                 for mini_batch in ds:
-                    self.test_step(mini_batch)
+                    self.test_step(mini_batch[0], mini_batch[1])
 
             print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
-- 
GitLab