From 070b553222d12882e02e64ddf1f36ed344171238 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 20 Mar 2023 09:50:27 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cnn_cld_frac.py | 31 ++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py
index c023a17c..abd9d0ba 100644
--- a/modules/deeplearning/cnn_cld_frac.py
+++ b/modules/deeplearning/cnn_cld_frac.py
@@ -21,7 +21,7 @@ LOG_DEVICE_PLACEMENT = False
 PROC_BATCH_SIZE = 4
 PROC_BATCH_BUFFER_SIZE = 50000
 
-NumClasses = 3
+NumClasses = 5
 if NumClasses == 2:
     NumLogits = 1
 else:
@@ -237,6 +237,30 @@ def get_label_data(grd_k):
     return s
 
 
+def get_label_data_5cat(grd_k):
+    grd_k = np.where(np.isnan(grd_k), 0, grd_k)
+    grd_k = np.where(grd_k < 0.5, 0, 1)
+
+    a = grd_k[:, 0::2, 0::2]
+    b = grd_k[:, 1::2, 0::2]
+    c = grd_k[:, 0::2, 1::2]
+    d = grd_k[:, 1::2, 1::2]
+    s = a + b + c + d
+
+    cat_0 = (s == 0)
+    cat_1 = (s == 1)
+    cat_2 = (s == 2)
+    cat_3 = (s == 3)
+    cat_4 = (s == 4)
+
+    s[cat_0] = 0
+    s[cat_1] = 1
+    s[cat_2] = 2
+    s[cat_3] = 3
+    s[cat_4] = 4
+
+    return s
+
 class SRCNN:
     
     def __init__(self):
@@ -405,7 +429,10 @@ class SRCNN:
         label = input_data[:, label_idx, :, :]
         label = label.copy()
         label = label[:, y_128, x_128]
-        label = get_label_data(label)
+        if NumClasses == 5:
+            label = get_label_data_5cat(label)
+        else:
+            label = get_label_data(label)
 
         if label_param != 'cloud_probability':
             label = normalize(label, label_param, mean_std_dct)
-- 
GitLab