From e673694be5899f86184f759c2d9d0bdb1890e1e6 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 6 Jul 2023 12:39:45 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_srcnn_abi.py | 52 ++++++++++++---------
 1 file changed, 30 insertions(+), 22 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 0fb761bd..40d2020a 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -729,16 +729,22 @@ class SRCNN:
         t1 = time.time()
         print('read data time: ', (t1 - t0))
 
-        self.run_inference_(bt, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, LEN_Y, LEN_X)
+        self.run_inference_(bt, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, 2*LEN_Y, 2*LEN_X)
 
     def run_inference_(self, bt, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, LEN_Y, LEN_X):
 
-        slc_x = slice(0, (LEN_X - 16) + 4)
-        slc_y = slice(0, (LEN_Y - 16) + 4)
-        x_2 = np.arange((LEN_X - 16) + 4)
-        y_2 = np.arange((LEN_Y - 16) + 4)
-        t = np.arange(0, (LEN_X - 16) + 4, 0.5)
-        s = np.arange(0, (LEN_Y - 16) + 4, 0.5)
+        self.slc_x_m = slice(1, int(LEN_X / 2) + 4)
+        self.slc_y_m = slice(1, int(LEN_Y / 2) + 4)
+        self.slc_x = slice(3, LEN_X + 5)
+        self.slc_y = slice(3, LEN_Y + 5)
+        self.slc_x_2 = slice(2, LEN_X + 7, 2)
+        self.slc_y_2 = slice(2, LEN_Y + 7, 2)
+        self.x_2 = np.arange(int(LEN_X / 2) + 3)
+        self.y_2 = np.arange(int(LEN_Y / 2) + 3)
+        self.t = np.arange(0, int(LEN_X / 2) + 3, 0.5)
+        self.s = np.arange(0, int(LEN_Y / 2) + 3, 0.5)
+        self.x_k = slice(1, LEN_X + 3)
+        self.y_k = slice(1, LEN_Y + 3)
 
         # refl = np.where(np.isnan(refl), 0, refl)
         # refl = refl[slc_y, slc_x]
@@ -750,42 +756,44 @@ class SRCNN:
 
         t0 = time.time()
         bt = np.where(np.isnan(bt), 0, bt)
-        bt = bt[slc_y, slc_x]
+        bt = bt[self.slc_y_m, self.slc_x_m]
         bt = np.expand_dims(bt, axis=0)
-        bt_us = upsample_static(bt, x_2, y_2, t, s, None, None)
+        # bt_us = upsample_static(bt, x_2, y_2, t, s, None, None)
+        bt_us = self.upsample(bt)
         bt_us = smooth_2d(bt_us)
         bt_us = normalize(bt_us, 'temp_11_0um_nom', mean_std_dct)
-        print('BT done')
 
-        refl_sub_lo = refl_sub_lo[slc_y, slc_x]
+        cld_opd = np.where(np.isnan(cld_opd), 0, cld_opd)
+        cld_opd = cld_opd[self.slc_y_m, self.slc_x_m]
+        cld_opd = np.expand_dims(cld_opd, axis=0)
+        # cld_opd_us = upsample_static(cld_opd, x_2, y_2, t, s, None, None)
+        cld_opd_us = self.upsample(cld_opd)
+        cld_opd_us = smooth_2d(cld_opd_us)
+        cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
+
         refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
         refl_sub_lo = upsample_nearest(refl_sub_lo)
+        refl_sub_lo = refl_sub_lo[self.slc_y, self.slc_x]
         refl_sub_lo = normalize(refl_sub_lo, 'refl_0_65um_nom', mean_std_dct)
 
-        refl_sub_hi = refl_sub_hi[slc_y, slc_x]
         refl_sub_hi = np.expand_dims(refl_sub_hi, axis=0)
         refl_sub_hi = upsample_nearest(refl_sub_hi)
+        refl_sub_hi = refl_sub_hi[self.slc_y, self.slc_x]
         refl_sub_hi = normalize(refl_sub_hi, 'refl_0_65um_nom', mean_std_dct)
 
-        refl_sub_std = refl_sub_std[slc_y, slc_x]
         refl_sub_std = np.expand_dims(refl_sub_std, axis=0)
         refl_sub_std = upsample_nearest(refl_sub_std)
+        refl_sub_std = refl_sub_std[self.slc_y, self.slc_x]
 
-        cld_opd = np.where(np.isnan(cld_opd), 0, cld_opd)
-        cld_opd = cld_opd[slc_y, slc_x]
-        cld_opd = np.expand_dims(cld_opd, axis=0)
-        cld_opd_us = upsample_static(cld_opd, x_2, y_2, t, s, None, None)
-        cld_opd_us = smooth_2d(cld_opd_us)
-        cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
-        print('OPD done')
         t1 = time.time()
         print('upsample/normalize time: ', (t1 - t0))
 
         data = np.stack([bt_us, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd_us], axis=3)
 
-        # data = self.do_inference(data)
+        cld_opd_sres = self.do_inference(data)
+        cld_opd_sres = denormalize(cld_opd_sres, label_param, mean_std_dct)
 
-        return None
+        return cld_opd_sres
 
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
-- 
GitLab