From 5e60c876227ef3b090dd7b5a96e9de585aa392a4 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 28 Dec 2023 16:19:27 -0600
Subject: [PATCH] snapshot...

---
 .../deeplearning/cloud_fraction_fcn_abi.py    | 21 ++++++++++++-------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_abi.py b/modules/deeplearning/cloud_fraction_fcn_abi.py
index 3e7008d7..631f6e84 100644
--- a/modules/deeplearning/cloud_fraction_fcn_abi.py
+++ b/modules/deeplearning/cloud_fraction_fcn_abi.py
@@ -20,7 +20,8 @@ LOG_DEVICE_PLACEMENT = False
 PROC_BATCH_SIZE = 4
 PROC_BATCH_BUFFER_SIZE = 5000
 
-NumClasses = 5
+NumClasses = 7
+# NumClasses = 5
 if NumClasses == 2:
     NumLogits = 1
 else:
@@ -186,7 +187,7 @@ def get_label_data_5cat(grd_k):
     return s
 
 
-def get_label_data_8cat(grd_k):
+def get_label_data_7cat(grd_k):
     grd_k = np.where(np.isnan(grd_k), 0, grd_k)
     grd_k = np.where(grd_k < 0.5, 0, 1)
 
@@ -197,21 +198,23 @@ def get_label_data_8cat(grd_k):
 
     cat_0 = np.logical_and(s >= 0, s < 1)  # CLDY
     cat_1 = np.logical_and(s >= 1, s < 4)
-    cat_2 = np.logical_and(s >= 4, s < 6)
-    cat_3 = np.logical_and(s >= 6, s < 9)
-    cat_4 = np.logical_and(s >= 9, s < 11)
-    cat_5 = np.logical_and(s >= 11, s < 14)
-    cat_6 = np.logical_and(s >= 14, s <= 15)
-    cat_7 = np.logical_and(s > 15, s <= 16)  # CLR
+    cat_2 = np.logical_and(s >= 4, s < 7)
+    cat_3 = np.logical_and(s >= 7, s < 10)
+    cat_4 = np.logical_and(s >= 10, s < 13)
+    cat_5 = np.logical_and(s >= 13, s <= 15)
+    cat_6 = np.logical_and(s > 15, s <= 16)  # CLR
 
     s[cat_0] = 0
     s[cat_1] = 1
     s[cat_2] = 2
     s[cat_3] = 3
     s[cat_4] = 4
+    s[cat_5] = 5
+    s[cat_6] = 6
 
     return s
 
+
 def get_label_data(grd_k):
     grd_k = np.where(np.isnan(grd_k), 0, grd_k)
     grd_k = np.where(grd_k < 0.5, 0, 1)
@@ -391,6 +394,8 @@ class SRCNN:
         label = label[:, y_64, x_64]
         if NumClasses == 5:
             label = get_label_data_5cat(label)
+        elif NumClasses == 7:
+            label = get_label_data_7cat(label)
         else:
             label = get_label_data(label)
 
-- 
GitLab