From 070b553222d12882e02e64ddf1f36ed344171238 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 20 Mar 2023 09:50:27 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cnn_cld_frac.py | 31 ++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index c023a17c..abd9d0ba 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -21,7 +21,7 @@ LOG_DEVICE_PLACEMENT = False PROC_BATCH_SIZE = 4 PROC_BATCH_BUFFER_SIZE = 50000 -NumClasses = 3 +NumClasses = 5 if NumClasses == 2: NumLogits = 1 else: @@ -237,6 +237,30 @@ def get_label_data(grd_k): return s +def get_label_data_5cat(grd_k): + grd_k = np.where(np.isnan(grd_k), 0, grd_k) + grd_k = np.where(grd_k < 0.5, 0, 1) + + a = grd_k[:, 0::2, 0::2] + b = grd_k[:, 1::2, 0::2] + c = grd_k[:, 0::2, 1::2] + d = grd_k[:, 1::2, 1::2] + s = a + b + c + d + + cat_0 = (s == 0) + cat_1 = (s == 1) + cat_2 = (s == 2) + cat_3 = (s == 3) + cat_4 = (s == 4) + + s[cat_0] = 0 + s[cat_1] = 1 + s[cat_2] = 2 + s[cat_3] = 3 + s[cat_4] = 4 + + return s + class SRCNN: def __init__(self): @@ -405,7 +429,10 @@ class SRCNN: label = input_data[:, label_idx, :, :] label = label.copy() label = label[:, y_128, x_128] - label = get_label_data(label) + if NumClasses == 5: + label = get_label_data_5cat(label) + else: + label = get_label_data(label) if label_param != 'cloud_probability': label = normalize(label, label_param, mean_std_dct) -- GitLab