diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index c023a17c2fc59a70db7b5c437e16df764f455630..abd9d0bac205ed5a1d7374675c778ae7cbcd1ed1 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -21,7 +21,7 @@ LOG_DEVICE_PLACEMENT = False PROC_BATCH_SIZE = 4 PROC_BATCH_BUFFER_SIZE = 50000 -NumClasses = 3 +NumClasses = 5 if NumClasses == 2: NumLogits = 1 else: @@ -237,6 +237,30 @@ def get_label_data(grd_k): return s +def get_label_data_5cat(grd_k): + grd_k = np.where(np.isnan(grd_k), 0, grd_k) + grd_k = np.where(grd_k < 0.5, 0, 1) + + a = grd_k[:, 0::2, 0::2] + b = grd_k[:, 1::2, 0::2] + c = grd_k[:, 0::2, 1::2] + d = grd_k[:, 1::2, 1::2] + s = a + b + c + d + + cat_0 = (s == 0) + cat_1 = (s == 1) + cat_2 = (s == 2) + cat_3 = (s == 3) + cat_4 = (s == 4) + + s[cat_0] = 0 + s[cat_1] = 1 + s[cat_2] = 2 + s[cat_3] = 3 + s[cat_4] = 4 + + return s + class SRCNN: def __init__(self): @@ -405,7 +429,10 @@ class SRCNN: label = input_data[:, label_idx, :, :] label = label.copy() label = label[:, y_128, x_128] - label = get_label_data(label) + if NumClasses == 5: + label = get_label_data_5cat(label) + else: + label = get_label_data(label) if label_param != 'cloud_probability': label = normalize(label, label_param, mean_std_dct)