From 861a6b4952240fc867f172f5a24a8a212ccbd2b9 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 17 Jul 2023 10:31:10 -0500
Subject: [PATCH] `snapshot...`

---
 modules/deeplearning/cloud_opd_srcnn_abi.py | 75 +++++++++++++++++----
 1 file changed, 62 insertions(+), 13 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 773f3f4d..b832b715 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -796,6 +796,55 @@ class SRCNN:
         else:
             return cld_opd_sres
 
+    def run_inference_test(self, in_file, out_file):
+        gc.collect()
+        t0 = time.time()
+
+        group_name_i = 'super/'
+        group_name_m = 'orig/'
+        target_param = 'cld_opd_dcomp_1'
+
+        h5f = h5py.File(in_file, 'r')
+
+        refl = get_grid_values_all(h5f, group_name_i+'refl_ch01')
+        print('FD dims: ', refl.shape)
+        LEN_Y, LEN_X = refl.shape
+
+        bt = get_grid_values_all(h5f, group_name_m+'temp_ch38')
+
+        cld_opd = get_grid_values_all(h5f, group_name_m+target_param)
+
+        # refl_sub_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
+        # refl_sub_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
+        # refl_sub_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
+
+        t1 = time.time()
+        print('read data time: ', (t1 - t0))
+
+        LEN_Y -= 8
+        LEN_X -= 8
+
+        LEN_Y = 2 * (LEN_Y - 8)
+        LEN_X = 2 * (LEN_X - 8)
+
+        t0 = time.time()
+        # cld_opd_sres, LEN_Y_in, LEN_X_in = self.run_inference_(bt, refl, cld_opd, refl_sub_lo, refl_sub_hi, refl_sub_std, LEN_Y, LEN_X)
+        cld_opd_sres, LEN_Y_in, LEN_X_in = self.run_inference_(bt, refl, cld_opd, None, None, None, LEN_Y, LEN_X)
+        t1 = time.time()
+        print('inference time: ', (t1 - t0))
+        print(cld_opd_sres.shape)
+
+        cld_opd_sres_out = np.zeros((LEN_Y_in, LEN_X_in), dtype=np.int8)
+        border = int((KERNEL_SIZE - 1) / 2)
+        cld_opd_sres_out[border:LEN_Y_in - border, border:LEN_X_in - border] = cld_opd_sres[0, :, :, 0]
+
+        h5f.close()
+
+        if out_file is not None:
+            np.save(out_file, (cld_opd_sres_out, bt, refl, cld_opd))
+        else:
+            return cld_opd_sres
+
     def run_inference_(self, bt, refl, cld_opd, refl_sub_lo, refl_sub_hi, refl_sub_std, LEN_Y, LEN_X):
 
         self.slc_x_m = slice(1, int(LEN_X / 2) + 4)
@@ -841,19 +890,19 @@ class SRCNN:
             cld_opd_us = smooth_2d(cld_opd_us)
         cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
 
-        refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
-        refl_sub_lo = upsample_nearest(refl_sub_lo)
-        refl_sub_lo = refl_sub_lo[:, self.slc_y, self.slc_x]
-        refl_sub_lo = normalize(refl_sub_lo, 'refl_0_65um_nom', mean_std_dct)
-
-        refl_sub_hi = np.expand_dims(refl_sub_hi, axis=0)
-        refl_sub_hi = upsample_nearest(refl_sub_hi)
-        refl_sub_hi = refl_sub_hi[:, self.slc_y, self.slc_x]
-        refl_sub_hi = normalize(refl_sub_hi, 'refl_0_65um_nom', mean_std_dct)
-
-        refl_sub_std = np.expand_dims(refl_sub_std, axis=0)
-        refl_sub_std = upsample_nearest(refl_sub_std)
-        refl_sub_std = refl_sub_std[:, self.slc_y, self.slc_x]
+        # refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
+        # refl_sub_lo = upsample_nearest(refl_sub_lo)
+        # refl_sub_lo = refl_sub_lo[:, self.slc_y, self.slc_x]
+        # refl_sub_lo = normalize(refl_sub_lo, 'refl_0_65um_nom', mean_std_dct)
+        #
+        # refl_sub_hi = np.expand_dims(refl_sub_hi, axis=0)
+        # refl_sub_hi = upsample_nearest(refl_sub_hi)
+        # refl_sub_hi = refl_sub_hi[:, self.slc_y, self.slc_x]
+        # refl_sub_hi = normalize(refl_sub_hi, 'refl_0_65um_nom', mean_std_dct)
+        #
+        # refl_sub_std = np.expand_dims(refl_sub_std, axis=0)
+        # refl_sub_std = upsample_nearest(refl_sub_std)
+        # refl_sub_std = refl_sub_std[:, self.slc_y, self.slc_x]
 
         t1 = time.time()
         print('upsample/normalize time: ', (t1 - t0))
-- 
GitLab