From 4751ca868782e34b2c2721ed1255da3d99a97155 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 6 Sep 2023 10:12:04 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_fcn_abi.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_fcn_abi.py b/modules/deeplearning/cloud_opd_fcn_abi.py
index 02baf190..bb43831b 100644
--- a/modules/deeplearning/cloud_opd_fcn_abi.py
+++ b/modules/deeplearning/cloud_opd_fcn_abi.py
@@ -350,17 +350,19 @@ class SRCNN:
         tmp = tmp[:, slc_y, slc_x]
         data_norm.append(tmp)
         # ---------
-        data = np.stack(data_norm, axis=3)
-        data = data.astype(np.float32)
+        # data = np.stack(data_norm, axis=3)
+        # data = data.astype(np.float32)
 
         # -----------------------------------------------------
         # -----------------------------------------------------
         label = input_label[:, label_idx_i, :, :]
         label = label[:, y_64, x_64]
         cld_prob = cld_prob[:, y_64, x_64]
-        if not is_training:
-            cat_cf = get_label_data_5cat(cld_prob)
-            self.test_cat_cf.append(cat_cf)
+        cat_cf = get_label_data_5cat(cld_prob)
+        data_norm.append(cat_cf)
+        data = np.stack(data_norm, axis=3)
+        data = data.astype(np.float32)
+
         label = get_cldy_frac_opd(cld_prob, label)
         # label = scale(label, label_param, mean_std_dct)
         label = np.where(np.isnan(label), 0, label)
@@ -470,6 +472,7 @@ class SRCNN:
         num_filters = 64
 
         input_2d = self.inputs[0]
+        input_2d = input_2d[:, :, :, 0:self.n_chans]
         print('input: ', input_2d.shape)
 
         conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=KERNEL_SIZE, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d)
@@ -568,6 +571,7 @@ class SRCNN:
         self.test_labels.append(labels)
         self.test_preds.append(pred.numpy())
         self.test_input.append(inputs)
+        self.test_cat_cf.append(inputs[:, :, :, self.n_chans])
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
-- 
GitLab