From b98b04a39e3da82412cf0c87fda908f5bb0d16a3 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 12 Jul 2023 11:03:07 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_srcnn_abi.py | 42 +++++++++++++++------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 167580c9..8f277a61 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -29,7 +29,7 @@ EARLY_STOP = True
 
 NOISE_TRAINING = False
 NOISE_STDDEV = 0.01
-DO_AUGMENT = False
+DO_AUGMENT = True
 
 DO_SMOOTH = False
 SIGMA = 1.0
@@ -59,7 +59,8 @@ params_i = ['temp_11_0um_nom', 'refl_0_65um_nom', 'temp_stddev3x3_ch31', 'refl_s
 # data_params_half = ['temp_11_0um_nom', 'refl_0_65um_nom']
 data_params_half = ['temp_11_0um_nom']
 data_params_full = ['refl_0_65um_nom']
-sub_fields = ['refl_submin_ch01', 'refl_submax_ch01', 'refl_substddev_ch01']
+# sub_fields = ['refl_submin_ch01', 'refl_submax_ch01', 'refl_substddev_ch01']
+sub_fields = ['refl_substddev_ch01']
 # sub_fields = ['refl_stddev3x3_ch01']
 
 label_idx_i = params_i.index(label_param)
@@ -210,7 +211,7 @@ class SRCNN:
         self.test_label_files = None
 
         # self.n_chans = len(data_params_half) + len(data_params_full) + 1
-        self.n_chans = 5
+        self.n_chans = 4
 
         self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
 
@@ -271,11 +272,22 @@ class SRCNN:
             data_norm.append(tmp)
 
         # High res refectance ----------
-        # tmp = input_label[:, label_idx_i, :, :]
+        # idx = params_i.index('refl_0_65um_nom')
+        # tmp = input_label[:, idx, :, :]
         # tmp = np.where(np.isnan(tmp), 0, tmp)
         # tmp = normalize(tmp, 'refl_0_65um_nom', mean_std_dct)
         # data_norm.append(tmp[:, self.slc_y, self.slc_x])
 
+        idx = params_i.index('refl_0_65um_nom')
+        tmp = input_label[:, idx, :, :]
+        tmp = tmp.copy()
+        tmp = np.where(np.isnan(tmp), 0.0, tmp)
+        tmp = tmp[:, self.slc_y_2, self.slc_x_2]
+        tmp = self.upsample(tmp)
+        tmp = smooth_2d(tmp)
+        tmp = normalize(tmp, label_param, mean_std_dct)
+        data_norm.append(tmp)
+
         tmp = input_label[:, label_idx_i, :, :]
         tmp = tmp.copy()
         tmp = np.where(np.isnan(tmp), 0.0, tmp)
@@ -328,14 +340,20 @@ class SRCNN:
         label = label.astype(np.float32)
 
         if is_training and DO_AUGMENT:
-            data_ud = np.flip(data, axis=1)
-            label_ud = np.flip(label, axis=1)
-
-            data_lr = np.flip(data, axis=2)
-            label_lr = np.flip(label, axis=2)
-
-            data = np.concatenate([data, data_ud, data_lr])
-            label = np.concatenate([label, label_ud, label_lr])
+            # data_ud = np.flip(data, axis=1)
+            # label_ud = np.flip(label, axis=1)
+            #
+            # data_lr = np.flip(data, axis=2)
+            # label_lr = np.flip(label, axis=2)
+            #
+            # data = np.concatenate([data, data_ud, data_lr])
+            # label = np.concatenate([label, label_ud, label_lr])
+
+            data_rot = np.rot90(data, axes=(1, 2))
+            label_rot = np.rot90(label, axes=(1, 2))
+
+            data = np.concatenate([data, data_rot])
+            label = np.concatenate([label, label_rot])
 
         return data, label
 
-- 
GitLab