From 57daf95ab5a085f66074cc3c623fb0809cdcf6fc Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 14 Feb 2023 10:56:13 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cnn_cld_frac.py | 59 ++++++++++++++++++++--------
 1 file changed, 42 insertions(+), 17 deletions(-)

diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py
index d51c45ae..12f7147e 100644
--- a/modules/deeplearning/cnn_cld_frac.py
+++ b/modules/deeplearning/cnn_cld_frac.py
@@ -201,29 +201,54 @@ def get_grid_cell_mean(grd_k):
     return s
 
 
+# def get_label_data(grd_k):
+#     grd_k = np.where(np.isnan(grd_k), 0, grd_k)
+#
+#     a = grd_k[:, 0::2, 0::2]
+#     b = grd_k[:, 1::2, 0::2]
+#     c = grd_k[:, 0::2, 1::2]
+#     d = grd_k[:, 1::2, 1::2]
+#     s_t = a + b + c + d
+#     s_t /= 4.0
+#
+#     blen, ylen, xlen = s_t.shape
+#     s_t = s_t.flatten()
+#     cat_0 = np.logical_and(s_t >= 0.0, s_t < 0.2)
+#     cat_1 = np.logical_and(s_t >= 0.2, s_t < 0.7)
+#     cat_2 = np.logical_and(s_t >= 0.7, s_t <= 1.0)
+#
+#     s_t[cat_0] = 0
+#     s_t[cat_1] = 1
+#     s_t[cat_2] = 2
+#
+#     s_t = s_t.reshape((blen, ylen, xlen))
+#
+#     return s_t
+
+
 def get_label_data(grd_k):
     grd_k = np.where(np.isnan(grd_k), 0, grd_k)
+    cat_0 = np.logical_and(grd_k >= 0.0, grd_k < 0.2)
+    cat_1 = np.logical_and(grd_k >= 0.2, grd_k < 0.7)
+    cat_2 = np.logical_and(grd_k >= 0.7, grd_k <= 1.0)
+    grd_k[cat_0] = -1
+    grd_k[cat_1] = 0
+    grd_k[cat_2] = 1
 
     a = grd_k[:, 0::2, 0::2]
     b = grd_k[:, 1::2, 0::2]
     c = grd_k[:, 0::2, 1::2]
     d = grd_k[:, 1::2, 1::2]
-    s_t = a + b + c + d
-    s_t /= 4.0
-
-    blen, ylen, xlen = s_t.shape
-    s_t = s_t.flatten()
-    cat_0 = np.logical_and(s_t >= 0.0, s_t < 0.2)
-    cat_1 = np.logical_and(s_t >= 0.2, s_t < 0.7)
-    cat_2 = np.logical_and(s_t >= 0.7, s_t <= 1.0)
-
-    s_t[cat_0] = 0
-    s_t[cat_1] = 1
-    s_t[cat_2] = 2
+    s = a + b + c + d
 
-    s_t = s_t.reshape((blen, ylen, xlen))
+    cat_0 = s <= -3
+    cat_1 = np.logical_and(s > -3, s < 3)
+    cat_2 = s >= 3
+    s[cat_0] = 0
+    s[cat_1] = 1
+    s[cat_2] = 2
 
-    return s_t
+    return s
 
 
 class SRCNN:
@@ -384,9 +409,9 @@ class SRCNN:
         if DO_ESPCN:
             tmp = tmp[:, slc_y_2, slc_x_2]
         else:  # Half res upsampled to full res:
-            tmp = upsample(tmp)
-            # tmp = upsample_mean(tmp)
-            # tmp = tmp[:, slc_y, slc_x]
+            # tmp = upsample(tmp)
+            tmp = upsample_mean(tmp)
+            tmp = tmp[:, slc_y, slc_x]
         if label_param != 'cloud_probability':
             tmp = normalize(tmp, label_param, mean_std_dct)
             if DO_ADD_NOISE:
-- 
GitLab