From 570d5571a6e7b36a4901eb8e7f827fa56314748f Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 10 Jul 2023 10:52:45 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_srcnn_abi.py | 44 +++++++++++++++++----
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 2b1082f1..6d328b52 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -721,28 +721,54 @@ class SRCNN:
         gc.collect()
         t0 = time.time()
 
+        s_x = slice(1812, 3612)
+        s_y = slice(1812, 3612)
+
         h5f = h5py.File(in_file, 'r')
 
         refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
+        print('FD dims: ', refl.shape)
+        refl = refl[s_y, s_x]
         LEN_Y, LEN_X = refl.shape
+        print('sub dims: ', refl.shape)
+
         bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
+        bt = bt[s_y, s_x]
+
         cld_opd = get_grid_values_all(h5f, 'cld_opd_dcomp')
+        cld_opd = cld_opd[s_y, s_x]
+
         refl_sub_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
+        refl_sub_lo = refl_sub_lo[s_y, s_x]
+
         refl_sub_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
+        refl_sub_hi = refl_sub_hi[s_y, s_x]
+
         refl_sub_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
+        refl_sub_std = refl_sub_std[s_y, s_x]
         t1 = time.time()
         print('read data time: ', (t1 - t0))
 
-        cld_opd_sres = self.run_inference_(bt, refl, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, 2*LEN_Y, 2*LEN_X)
+        LEN_Y -= 8
+        LEN_X -= 8
+
+        LEN_Y = 2 * (LEN_Y - 8)
+        LEN_X = 2 * (LEN_X - 8)
+
+        t0 = time.time()
+        cld_opd_sres, LEN_Y_in, LEN_X_in = self.run_inference_(bt, refl, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, LEN_Y, LEN_X)
+        t1 = time.time()
+        print('inference time: ', (t1 - t0))
+        print(cld_opd_sres.shape)
 
-        cld_opd_sres_out = np.zeros((LEN_Y, LEN_X), dtype=np.int8)
+        cld_opd_sres_out = np.zeros((LEN_Y_in, LEN_X_in), dtype=np.int8)
         border = int((KERNEL_SIZE - 1) / 2)
-        cld_opd_sres_out[border:LEN_Y - border, border:LEN_X - border] = cld_opd_sres[0, :, :]
+        cld_opd_sres_out[border:LEN_Y_in - border, border:LEN_X_in - border] = cld_opd_sres[0, :, :, 0]
 
         h5f.close()
 
         if out_file is not None:
-            np.save(out_file, (cld_opd_sres, bt, refl, cld_opd))
+            np.save(out_file, (cld_opd_sres_out, bt, refl, cld_opd))
         else:
             return cld_opd_sres
 
@@ -760,6 +786,8 @@ class SRCNN:
         self.s = np.arange(0, int(LEN_Y / 2) + 3, 0.5)
         self.x_k = slice(1, LEN_X + 3)
         self.y_k = slice(1, LEN_Y + 3)
+        self.LEN_X = LEN_X
+        self.LEN_Y = LEN_Y
 
         t0 = time.time()
         bt = np.where(np.isnan(bt), 0, bt)
@@ -787,17 +815,17 @@ class SRCNN:
 
         refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
         refl_sub_lo = upsample_nearest(refl_sub_lo)
-        refl_sub_lo = refl_sub_lo[self.slc_y, self.slc_x]
+        refl_sub_lo = refl_sub_lo[:, self.slc_y, self.slc_x]
         refl_sub_lo = normalize(refl_sub_lo, 'refl_0_65um_nom', mean_std_dct)
 
         refl_sub_hi = np.expand_dims(refl_sub_hi, axis=0)
         refl_sub_hi = upsample_nearest(refl_sub_hi)
-        refl_sub_hi = refl_sub_hi[self.slc_y, self.slc_x]
+        refl_sub_hi = refl_sub_hi[:, self.slc_y, self.slc_x]
         refl_sub_hi = normalize(refl_sub_hi, 'refl_0_65um_nom', mean_std_dct)
 
         refl_sub_std = np.expand_dims(refl_sub_std, axis=0)
         refl_sub_std = upsample_nearest(refl_sub_std)
-        refl_sub_std = refl_sub_std[self.slc_y, self.slc_x]
+        refl_sub_std = refl_sub_std[:, self.slc_y, self.slc_x]
 
         t1 = time.time()
         print('upsample/normalize time: ', (t1 - t0))
@@ -807,7 +835,7 @@ class SRCNN:
         cld_opd_sres = self.do_inference(data)
         cld_opd_sres = denormalize(cld_opd_sres, label_param, mean_std_dct)
 
-        return cld_opd_sres
+        return cld_opd_sres, bt_us.shape[1], bt_us.shape[2]
 
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
-- 
GitLab