From 3c652a43dd102f0c5af56c60f74a94b1a49313a6 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 6 Feb 2023 16:02:46 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/srcnn_cld_frac.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py
index 57b9f948..39d50972 100644
--- a/modules/deeplearning/srcnn_cld_frac.py
+++ b/modules/deeplearning/srcnn_cld_frac.py
@@ -179,13 +179,20 @@ def get_grid_cell_mean(grd_k):
     return s
 
 
-def get_label_data(grd_k):
-    grd_k = np.where(np.isnan(grd_k), 0, grd_k)
-    grd_k = np.where((grd_k >= 0.0) & (grd_k < 0.3), 0, grd_k)
-    grd_k = np.where((grd_k >= 0.3) & (grd_k < 0.7), 1, grd_k)
-    grd_k = np.where((grd_k >= 0.7) & (grd_k <= 1.0), 2, grd_k)
-
-    return grd_k
+def get_label_data(grd):
+    blen, ylen, xlen = grd.shape
+    grd = grd.flatten()
+    grd = np.where(np.isnan(grd), 0, grd)
+    cat_0 = np.logical_and(grd >= 0.0, grd < 0.3)
+    cat_1 = np.logical_and(grd >= 0.3, grd < 0.7)
+    cat_2 = np.logical_and(grd >= 0.7, grd <= 1.0)
+
+    grd[cat_0] = 0
+    grd[cat_1] = 1
+    grd[cat_2] = 2
+
+    grd = grd.reshape((blen, ylen, xlen))
+    return grd
 
 
 class SRCNN:
-- 
GitLab