diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index d51c45ae55cb6998fb2d2893e81ba3ee4f892b72..12f7147e6e01c73fabd98f2e32017f379e612bfe 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -201,29 +201,54 @@ def get_grid_cell_mean(grd_k): return s +# def get_label_data(grd_k): +# grd_k = np.where(np.isnan(grd_k), 0, grd_k) +# +# a = grd_k[:, 0::2, 0::2] +# b = grd_k[:, 1::2, 0::2] +# c = grd_k[:, 0::2, 1::2] +# d = grd_k[:, 1::2, 1::2] +# s_t = a + b + c + d +# s_t /= 4.0 +# +# blen, ylen, xlen = s_t.shape +# s_t = s_t.flatten() +# cat_0 = np.logical_and(s_t >= 0.0, s_t < 0.2) +# cat_1 = np.logical_and(s_t >= 0.2, s_t < 0.7) +# cat_2 = np.logical_and(s_t >= 0.7, s_t <= 1.0) +# +# s_t[cat_0] = 0 +# s_t[cat_1] = 1 +# s_t[cat_2] = 2 +# +# s_t = s_t.reshape((blen, ylen, xlen)) +# +# return s_t + + def get_label_data(grd_k): grd_k = np.where(np.isnan(grd_k), 0, grd_k) + cat_0 = np.logical_and(grd_k >= 0.0, grd_k < 0.2) + cat_1 = np.logical_and(grd_k >= 0.2, grd_k < 0.7) + cat_2 = np.logical_and(grd_k >= 0.7, grd_k <= 1.0) + grd_k[cat_0] = -1 + grd_k[cat_1] = 0 + grd_k[cat_2] = 1 a = grd_k[:, 0::2, 0::2] b = grd_k[:, 1::2, 0::2] c = grd_k[:, 0::2, 1::2] d = grd_k[:, 1::2, 1::2] - s_t = a + b + c + d - s_t /= 4.0 - - blen, ylen, xlen = s_t.shape - s_t = s_t.flatten() - cat_0 = np.logical_and(s_t >= 0.0, s_t < 0.2) - cat_1 = np.logical_and(s_t >= 0.2, s_t < 0.7) - cat_2 = np.logical_and(s_t >= 0.7, s_t <= 1.0) - - s_t[cat_0] = 0 - s_t[cat_1] = 1 - s_t[cat_2] = 2 + s = a + b + c + d - s_t = s_t.reshape((blen, ylen, xlen)) + cat_0 = s <= -3 + cat_1 = np.logical_and(s > -3, s < 3) + cat_2 = s >= 3 + s[cat_0] = 0 + s[cat_1] = 1 + s[cat_2] = 2 - return s_t + return s class SRCNN: @@ -384,9 +409,9 @@ class SRCNN: if DO_ESPCN: tmp = tmp[:, slc_y_2, slc_x_2] else: # Half res upsampled to full res: - tmp = upsample(tmp) - # tmp = upsample_mean(tmp) - # tmp = tmp[:, slc_y, slc_x] + # tmp = upsample(tmp) + tmp = upsample_mean(tmp) + tmp = tmp[:, slc_y, slc_x] if label_param != 'cloud_probability': tmp = normalize(tmp, label_param, mean_std_dct) if DO_ADD_NOISE: