From bd20b31ccd52c19ce4193ebf739635f03ed59508 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 17 Jan 2024 15:06:25 -0600
Subject: [PATCH] Replace sub-pixel reflectance stats with a half-km (HKM) reflectance input

---
 .../cloud_fraction_fcn_abi_hkm_refl.py        | 86 +++++++++++--------
 1 file changed, 48 insertions(+), 38 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py b/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py
index fad30b29..e8efe3d9 100644
--- a/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py
+++ b/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py
@@ -86,6 +86,8 @@ if KERNEL_SIZE == 3:
     slc_y = slice(0, int(Y_LEN/4) + 2)
     x_64 = slice(4, X_LEN + 4)
     y_64 = slice(4, Y_LEN + 4)
+    slc_x_hkm = slice(0, X_LEN + 2)
+    slc_y_hkm = slice(0, Y_LEN + 2)
 # ----------------------------------------
 
 
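The two new slices mirror the existing slc_x/slc_y pair but keep the half-km field at full resolution: they crop to the output tile plus a one-pixel rim, which a single 3x3 VALID convolution later consumes. A minimal sketch of the geometry, with an illustrative tile size (the real X_LEN/Y_LEN come from earlier in the module):

    import numpy as np

    X_LEN = Y_LEN = 64                 # illustrative, not the module's values
    slc_x_hkm = slice(0, X_LEN + 2)    # 66 columns: tile plus a 1-pixel rim
    slc_y_hkm = slice(0, Y_LEN + 2)    # 66 rows; a 3x3 VALID conv eats the rim

    hkm_tile = np.zeros((70, 70), dtype=np.float32)  # hypothetical padded HKM tile
    print(hkm_tile[slc_y_hkm, slc_x_hkm].shape)      # (66, 66)
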
@@ -321,8 +323,10 @@ class SRCNN:
         self.n_chans = 5
 
         self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
+        self.X_hkm_img = tf.keras.Input(shape=(None, None, 1))
 
         self.inputs.append(self.X_img)
+        self.inputs.append(self.X_hkm_img)
 
         tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
 
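With a second tf.keras.Input, the network becomes a two-input functional model and must be called with one tensor per input, which is exactly the calling convention the train/test steps below switch to. A minimal sketch of that pattern; the layers here are placeholders, not this module's architecture:

    import tensorflow as tf

    x_img = tf.keras.Input(shape=(None, None, 5))    # low-res channels
    x_hkm = tf.keras.Input(shape=(None, None, 1))    # half-km reflectance

    a = tf.keras.layers.Conv2D(8, 3, padding='same')(x_img)
    b = tf.keras.layers.Conv2D(8, 3, padding='same')(x_hkm)
    # the real model must bring both branches to one resolution before merging;
    # Concatenate here simply assumes the spatial shapes already match
    y = tf.keras.layers.Conv2D(1, 3, padding='same')(
        tf.keras.layers.Concatenate()([a, b]))

    model = tf.keras.Model(inputs=[x_img, x_hkm], outputs=y)
    out = model([tf.zeros((2, 64, 64, 5)), tf.zeros((2, 64, 64, 1))])
    print(out.shape)   # (2, 64, 64, 1)
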
@@ -347,6 +351,7 @@ class SRCNN:
         input_data = np.concatenate(data_s)
         input_label = np.concatenate(label_s)
 
+        data_hkm_norm = []
         data_norm = []
         for param in data_params_half:
             idx = params.index(param)
@@ -356,30 +361,28 @@ class SRCNN:
             # tmp = scale(tmp, param, mean_std_dct)
             data_norm.append(tmp)
 
-        # refl_i = input_label[:, params_i.index('refl_0_65um_nom'), :, :]
-        # refl_avg = get_grid_cell_mean(refl_i)
-        # refl_avg = refl_avg[:, slc_y, slc_x]
-        # refl_avg = normalize(refl_avg, 'refl_0_65um_nom', mean_std_dct)
-        # data_norm.append(refl_avg)
-        #
-        # rlo, rhi, rstd, _ = get_min_max_std(refl_i)
-
-        rlo = input_data[:, params.index('refl_submin_ch01'), :, :]
-        rlo = rlo[:, slc_y, slc_x]
-        rlo = normalize(rlo, 'refl_0_65um_nom', mean_std_dct)
-        # rlo = scale(rlo, 'refl_0_65um_nom', mean_std_dct)
+        refl_i = input_label[:, params_i.index('refl_0_65um_nom'), :, :]
+        refl_i = refl_i[:, slc_y_hkm, slc_x_hkm]
+        refl_norm = normalize(refl_i, 'refl_0_65um_nom', mean_std_dct)
+        data_hkm_norm.append(refl_norm)
 
-        rhi = input_data[:, params.index('refl_submax_ch01'), :, :]
-        rhi = rhi[:, slc_y, slc_x]
-        rhi = normalize(rhi, 'refl_0_65um_nom', mean_std_dct)
-        # rhi = scale(rhi, 'refl_0_65um_nom', mean_std_dct)
-        refl_rng = rhi - rlo
-        data_norm.append(refl_rng)
 
-        rstd = input_data[:, params.index('refl_substddev_ch01'), :, :]
-        rstd = rstd[:, slc_y, slc_x]
-        rstd = scale2(rstd, 0.0, 20.0)
-        data_norm.append(rstd)
+        # rlo = input_data[:, params.index('refl_submin_ch01'), :, :]
+        # rlo = rlo[:, slc_y, slc_x]
+        # rlo = normalize(rlo, 'refl_0_65um_nom', mean_std_dct)
+        # # rlo = scale(rlo, 'refl_0_65um_nom', mean_std_dct)
+        #
+        # rhi = input_data[:, params.index('refl_submax_ch01'), :, :]
+        # rhi = rhi[:, slc_y, slc_x]
+        # rhi = normalize(rhi, 'refl_0_65um_nom', mean_std_dct)
+        # # rhi = scale(rhi, 'refl_0_65um_nom', mean_std_dct)
+        # refl_rng = rhi - rlo
+        # data_norm.append(refl_rng)
+        #
+        # rstd = input_data[:, params.index('refl_substddev_ch01'), :, :]
+        # rstd = rstd[:, slc_y, slc_x]
+        # rstd = scale2(rstd, 0.0, 20.0)
+        # data_norm.append(rstd)
 
         tmp = input_label[:, label_idx_i, :, :]
         tmp = get_grid_cell_mean(tmp)
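The normalize(...) helper is defined elsewhere in the repo; as used here it presumably standardizes by the per-parameter mean and std looked up in mean_std_dct. A hedged sketch of that contract, with the dict layout and values assumed purely for illustration:

    import numpy as np

    def normalize(arr, param, mean_std_dct):
        mean, std = mean_std_dct[param]   # assumed layout: param -> (mean, std)
        return (arr - mean) / std

    mean_std_dct = {'refl_0_65um_nom': (30.0, 25.0)}   # illustrative values
    refl = np.array([[5.0, 55.0], [30.0, 80.0]], dtype=np.float32)
    print(normalize(refl, 'refl_0_65um_nom', mean_std_dct))
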
@@ -389,6 +392,9 @@ class SRCNN:
         data = np.stack(data_norm, axis=3)
         data = data.astype(np.float32)
 
+        data_hkm = np.stack(data_hkm_norm, axis=3)
+        data_hkm = data_hkm.astype(np.float32)
+
         # -----------------------------------------------------
         # -----------------------------------------------------
         label = input_label[:, label_idx_i, :, :]
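The stacking step mirrors the existing data path: each entry of data_hkm_norm is an (N, H, W) array, and np.stack(..., axis=3) appends the channel axis to give the NHWC layout Keras expects. A small sketch:

    import numpy as np

    data_hkm_norm = [np.random.rand(4, 66, 66)]        # one half-km channel
    data_hkm = np.stack(data_hkm_norm, axis=3).astype(np.float32)
    print(data_hkm.shape)   # (4, 66, 66, 1)
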
@@ -416,7 +422,7 @@ class SRCNN:
         #     data = np.concatenate([data, data_ud, data_lr])
         #     label = np.concatenate([label, label_ud, label_lr])
 
-        return data, label
+        return data, data_hkm, label
 
     def get_in_mem_data_batch_train(self, idxs):
         return self.get_in_mem_data_batch(idxs, True)
@@ -426,12 +432,12 @@ class SRCNN:
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.float32])
         return out
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function_test(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.float32])
         return out
 
     def get_train_dataset(self, num_files):
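tf.numpy_function requires one output dtype per array the wrapped function returns, in order, so the Tout list grows with the new (data, data_hkm, label) triple. A self-contained sketch with made-up shapes:

    import numpy as np
    import tensorflow as tf

    def fetch_batch(indexes):
        n = np.asarray(indexes).size
        data = np.zeros((n, 18, 18, 5), np.float32)     # shapes are illustrative
        data_hkm = np.zeros((n, 66, 66, 1), np.float32)
        label = np.zeros((n, 16, 16, 1), np.float32)
        return data, data_hkm, label

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(indexes):
        # one tf.DType per returned array, in the same order
        return tf.numpy_function(fetch_batch, [indexes],
                                 [tf.float32, tf.float32, tf.float32])

    d, dh, lbl = data_function(tf.constant([0, 1, 2], tf.int32))
    print(d.shape, dh.shape, lbl.shape)
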
@@ -505,7 +511,8 @@ class SRCNN:
         num_filters = 64
 
         input_2d = self.inputs[0]
-        print('input: ', input_2d.shape)
+        input_hkm_2d = self.inputs[1]
+        print('input: ', input_2d.shape, input_hkm_2d.shape)
 
         conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=KERNEL_SIZE, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d)
         print(conv.shape)
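This also shows why the half-km crop carries the +2 rim: each 3x3 VALID convolution shaves one pixel from every edge, so one such layer maps the 66-pixel crop onto the 64-pixel output grid. A one-layer sketch:

    import tensorflow as tf

    x = tf.zeros((1, 66, 66, 1))
    y = tf.keras.layers.Conv2D(64, kernel_size=3, padding='valid')(x)
    print(y.shape)   # (1, 64, 64, 64): the 1-pixel rim is consumed once
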
@@ -590,7 +597,8 @@ class SRCNN:
     def train_step(self, inputs, labels):
         labels = tf.squeeze(labels, axis=[3])
         with tf.GradientTape() as tape:
-            pred = self.model([inputs], training=True)
+            # pred = self.model([inputs], training=True)
+            pred = self.model(inputs, training=True)
             loss = self.loss(labels, pred)
             total_loss = loss
             if len(self.model.losses) > 0:
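Dropping the extra brackets matters once there are two Inputs: model([inputs]) would wrap the list of both tensors in another list, whereas model(inputs) hands Keras one tensor per Input. A sketch of the resulting gradient step, with model/optimizer/loss_fn standing in for the class attributes:

    import tensorflow as tf

    def train_step(model, optimizer, loss_fn, inputs, labels):
        # `inputs` is already a list with one tensor per tf.keras.Input
        with tf.GradientTape() as tape:
            pred = model(inputs, training=True)
            loss = loss_fn(labels, pred)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
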
@@ -609,7 +617,8 @@ class SRCNN:
     @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
     def test_step(self, inputs, labels):
         labels = tf.squeeze(labels, axis=[3])
-        pred = self.model([inputs], training=False)
+        # pred = self.model([inputs], training=False)
+        pred = self.model(inputs, training=False)
         t_loss = self.loss(labels, pred)
 
         self.test_loss(t_loss)
@@ -618,7 +627,8 @@ class SRCNN:
     # @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
     # decorator commented out because pred.numpy(): pred not evaluated yet.
     def predict(self, inputs, labels):
-        pred = self.model([inputs], training=False)
+        # pred = self.model([inputs], training=False)
+        pred = self.model(inputs, training=False)
         # t_loss = self.loss(tf.squeeze(labels, axis=[3]), pred)
         t_loss = self.loss(labels, pred)
 
@@ -678,12 +688,12 @@ class SRCNN:
             proc_batch_cnt = 0
             n_samples = 0
 
-            for data, label in self.train_dataset:
-                trn_ds = tf.data.Dataset.from_tensor_slices((data, label))
+            for data, data_hkm, label in self.train_dataset:
+                trn_ds = tf.data.Dataset.from_tensor_slices((data, data_hkm, label))
                 trn_ds = trn_ds.batch(BATCH_SIZE)
                 for mini_batch in trn_ds:
                     if self.learningRateSchedule is not None:
-                        loss = self.train_step(mini_batch[0], mini_batch[1])
+                        loss = self.train_step([mini_batch[0], mini_batch[1]], mini_batch[2])
 
                     if (step % 100) == 0:
 
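from_tensor_slices slices all three tensors in lockstep along axis 0, so each mini-batch is a 3-tuple whose first two members form the model's input list. A minimal sketch with illustrative shapes:

    import numpy as np
    import tensorflow as tf

    data = np.zeros((10, 18, 18, 5), np.float32)
    data_hkm = np.zeros((10, 66, 66, 1), np.float32)
    label = np.zeros((10, 16, 16, 1), np.float32)

    trn_ds = tf.data.Dataset.from_tensor_slices((data, data_hkm, label)).batch(4)
    for mini_batch in trn_ds:
        x = [mini_batch[0], mini_batch[1]]   # the two model inputs
        y = mini_batch[2]
        print(x[0].shape, x[1].shape, y.shape)
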
@@ -694,11 +704,11 @@ class SRCNN:
                             tf.summary.scalar('num_epochs', epoch, step=step)
 
                         self.reset_test_metrics()
-                        for data_tst, label_tst in self.test_dataset:
-                            tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
+                        for data_tst, data_hkm_tst, label_tst in self.test_dataset:
+                            tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, data_hkm_tst, label_tst))
                             tst_ds = tst_ds.batch(BATCH_SIZE)
                             for mini_batch_test in tst_ds:
-                                self.test_step(mini_batch_test[0], mini_batch_test[1])
+                                self.test_step([mini_batch_test[0], mini_batch_test[1]], mini_batch_test[2])
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
@@ -723,11 +733,11 @@ class SRCNN:
             total_time += (t1-t0)
 
             self.reset_test_metrics()
-            for data, label in self.test_dataset:
-                ds = tf.data.Dataset.from_tensor_slices((data, label))
+            for data, data_hkm, label in self.test_dataset:
+                ds = tf.data.Dataset.from_tensor_slices((data, data_hkm, label))
                 ds = ds.batch(BATCH_SIZE)
                 for mini_batch in ds:
-                    self.test_step(mini_batch[0], mini_batch[1])
+                    self.test_step([mini_batch[0], mini_batch[1]], mini_batch[2])
 
             print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
-- 
GitLab