From b9ca1991d52fff5a3d3fcb4951b3f2be29b19cb0 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 1 Aug 2023 11:06:01 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_srcnn_abi.py | 83 +++++++++++----------
 1 file changed, 44 insertions(+), 39 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 146cd75d..67bbdbc0 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -212,7 +212,7 @@ class SRCNN:
         self.test_label_files = None
 
         # self.n_chans = len(data_params_half) + len(data_params_full) + 1
-        self.n_chans = 3
+        self.n_chans = 1
 
         self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
 
@@ -264,23 +264,25 @@ class SRCNN:
         input_label = np.concatenate(label_s)
 
         data_norm = []
-        for param in data_params_half:
-            idx = params.index(param)
-            tmp = input_data[:, idx, :, :]
-            tmp = np.where(np.isnan(tmp), 0.0, tmp)
-            tmp = tmp[:, self.slc_y_m, self.slc_x_m]
-            tmp = self.upsample(tmp)
-            if DO_SMOOTH:
-                tmp = smooth_2d(tmp)
-            tmp = normalize(tmp, param, mean_std_dct)
-            data_norm.append(tmp)
+        # for param in data_params_half:
+        #     idx = params.index(param)
+        #     tmp = input_data[:, idx, :, :]
+        #     tmp = np.where(np.isnan(tmp), 0.0, tmp)
+        #     tmp = tmp[:, self.slc_y_m, self.slc_x_m]
+        #     tmp = self.upsample(tmp)
+        #     if DO_SMOOTH:
+        #         tmp = smooth_2d(tmp)
+        #     tmp = normalize(tmp, param, mean_std_dct)
+        #     # tmp = scale(tmp, param, mean_std_dct)
+        #     data_norm.append(tmp)
 
         # High res refectance ----------
-        idx = params_i.index('refl_0_65um_nom')
-        tmp = input_label[:, idx, ::2, ::2]
-        tmp = np.where(np.isnan(tmp), 0, tmp)
-        tmp = normalize(tmp, 'refl_0_65um_nom', mean_std_dct)
-        data_norm.append(tmp[:, self.slc_y, self.slc_x])
+        # idx = params_i.index('refl_0_65um_nom')
+        # tmp = input_label[:, idx, ::2, ::2]
+        # tmp = np.where(np.isnan(tmp), 0, tmp)
+        # tmp = normalize(tmp, 'refl_0_65um_nom', mean_std_dct)
+        # # tmp = scale(tmp, 'refl_0_65um_nom', mean_std_dct)
+        # data_norm.append(tmp[:, self.slc_y, self.slc_x])
 
         # High res reflectance down 2 ---------
         # idx = params_i.index('refl_0_65um_nom')
@@ -301,7 +303,8 @@ class SRCNN:
         tmp = self.upsample(tmp)
         if DO_SMOOTH:
             tmp = smooth_2d(tmp)
-        tmp = normalize(tmp, label_param, mean_std_dct)
+        # tmp = normalize(tmp, label_param, mean_std_dct)
+        tmp = scale(tmp, label_param, mean_std_dct)
         data_norm.append(tmp)
 
         # for param in sub_fields:
@@ -336,8 +339,8 @@ class SRCNN:
         # -----------------------------------------------------
         label = input_label[:, label_idx_i, ::2, ::2]
         label = label.copy()
-        label = normalize(label, label_param, mean_std_dct)
-        # label = scale(label, label_param, mean_std_dct)
+        # label = normalize(label, label_param, mean_std_dct)
+        label = scale(label, label_param, mean_std_dct)
         label = label[:, self.y_128, self.x_128]
 
         label = np.where(np.isnan(label), 0.0, label)
@@ -870,24 +873,24 @@ class SRCNN:
         self.LEN_Y = LEN_Y
 
         t0 = time.time()
-        bt = np.where(np.isnan(bt), 0, bt)
-        bt = bt[self.slc_y_m, self.slc_x_m]
-        bt = np.expand_dims(bt, axis=0)
-        # bt_us = upsample_static(bt, x_2, y_2, t, s, None, None)
-        bt_us = self.upsample(bt)
-        if DO_SMOOTH:
-            bt_us = smooth_2d(bt_us)
-        bt_us = normalize(bt_us, 'temp_11_0um_nom', mean_std_dct)
-
-        refl = np.where(np.isnan(refl), 0, refl)
-        # refl = refl[self.slc_y_m, self.slc_x_m]
-        refl = refl[self.slc_y, self.slc_x]
-        refl = np.expand_dims(refl, axis=0)
-        # refl_us = self.upsample(refl)
-        refl_us = refl
-        if DO_SMOOTH:
-            refl_us = smooth_2d(refl)
-        refl_us = normalize(refl_us, 'refl_0_65um_nom', mean_std_dct)
+        # bt = np.where(np.isnan(bt), 0, bt)
+        # bt = bt[self.slc_y_m, self.slc_x_m]
+        # bt = np.expand_dims(bt, axis=0)
+        # # bt_us = upsample_static(bt, x_2, y_2, t, s, None, None)
+        # bt_us = self.upsample(bt)
+        # if DO_SMOOTH:
+        #     bt_us = smooth_2d(bt_us)
+        # bt_us = normalize(bt_us, 'temp_11_0um_nom', mean_std_dct)
+
+        # refl = np.where(np.isnan(refl), 0, refl)
+        # # refl = refl[self.slc_y_m, self.slc_x_m]
+        # refl = refl[self.slc_y, self.slc_x]
+        # refl = np.expand_dims(refl, axis=0)
+        # # refl_us = self.upsample(refl)
+        # refl_us = refl
+        # if DO_SMOOTH:
+        #     refl_us = smooth_2d(refl)
+        # refl_us = normalize(refl_us, 'refl_0_65um_nom', mean_std_dct)
 
         cld_opd = np.where(np.isnan(cld_opd), 0, cld_opd)
         cld_opd = cld_opd[self.slc_y_m, self.slc_x_m]
@@ -896,7 +899,8 @@ class SRCNN:
         cld_opd_us = self.upsample(cld_opd)
         if DO_SMOOTH:
             cld_opd_us = smooth_2d(cld_opd_us)
-        cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
+        # cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
+        cld_opd_us = scale(cld_opd_us, label_param, mean_std_dct)
 
         # refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
         # refl_sub_lo = upsample_nearest(refl_sub_lo)
@@ -917,7 +921,8 @@ class SRCNN:
 
         # data = np.stack([bt_us, refl_us, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd_us], axis=3)
         # data = np.stack([bt_us, refl_us, cld_opd_us, refl_sub_std], axis=3)
-        data = np.stack([bt_us, refl_us, cld_opd_us], axis=3)
+        # data = np.stack([bt_us, refl_us, cld_opd_us], axis=3)
+        data = np.stack([cld_opd_us], axis=3)
         print('data in: ', data.shape)
 
         cld_opd_sres = self.do_inference(data)
-- 
GitLab