From fed205e9b63b94ab13d5a9ae6d25e1d50d50dc8b Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 20 Mar 2023 14:18:11 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cnn_cld_frac.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index a2f9f831..2ad505f9 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -736,7 +736,7 @@ class SRCNN: trn_ds = trn_ds.batch(BATCH_SIZE) for mini_batch in trn_ds: if self.learningRateSchedule is not None: - loss = self.train_step(mini_batch) + loss = self.train_step(mini_batch[0], mini_batch[1]) if (step % 100) == 0: @@ -751,7 +751,7 @@ class SRCNN: tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst)) tst_ds = tst_ds.batch(BATCH_SIZE) for mini_batch_test in tst_ds: - self.test_step(mini_batch_test) + self.test_step(mini_batch_test[0], mini_batch_test[1]) with self.writer_valid.as_default(): tf.summary.scalar('loss_val', self.test_loss.result(), step=step) @@ -780,7 +780,7 @@ class SRCNN: ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch in ds: - self.test_step(mini_batch) + self.test_step(mini_batch[0], mini_batch[1]) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('------------------------------------------------------') -- GitLab