diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py index 57b9f948db4dc97d78b19385f1bf3bb8ef14c9dc..39d509728732e529eb78e946409e6ff46013bd6f 100644 --- a/modules/deeplearning/srcnn_cld_frac.py +++ b/modules/deeplearning/srcnn_cld_frac.py @@ -179,13 +179,20 @@ def get_grid_cell_mean(grd_k): return s -def get_label_data(grd_k): - grd_k = np.where(np.isnan(grd_k), 0, grd_k) - grd_k = np.where((grd_k >= 0.0) & (grd_k < 0.3), 0, grd_k) - grd_k = np.where((grd_k >= 0.3) & (grd_k < 0.7), 1, grd_k) - grd_k = np.where((grd_k >= 0.7) & (grd_k <= 1.0), 2, grd_k) - - return grd_k +def get_label_data(grd): + blen, ylen, xlen = grd.shape + grd = grd.flatten() + grd = np.where(np.isnan(grd), 0, grd) + cat_0 = np.logical_and(grd >= 0.0, grd < 0.3) + cat_1 = np.logical_and(grd >= 0.3, grd < 0.7) + cat_2 = np.logical_and(grd >= 0.7, grd <= 1.0) + + grd[cat_0] = 0 + grd[cat_1] = 1 + grd[cat_2] = 2 + + grd = grd.reshape((blen, ylen, xlen)) + return grd class SRCNN: