diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index aff1f2a5d2701c51ea11940e0272b165e2699401..703865d8c51e14cc3370a1c909afde339bec453a 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -402,7 +402,10 @@ class SRCNN: label = input_label label = label.copy() label = label[:, y_128, x_128] - label = get_label_data(label) + if NumClasses == 5: + label = get_label_data_5cat(label) + else: + label = get_label_data(label) if label_param != 'cloud_probability': label = normalize(label, label_param, mean_std_dct)