diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py index 56b816f24b2c5b3867d0134e2ab409c1c6535f30..5d8887121fd7dfaaa4fad8d342bafe28138c4448 100644 --- a/modules/deeplearning/cloud_opd_srcnn_abi.py +++ b/modules/deeplearning/cloud_opd_srcnn_abi.py @@ -141,10 +141,10 @@ def get_min_max_std(grd_k): return lo, hi, std, avg -# def upsample_static(grd, x_2, y_2, t, s, y_k, x_k): -# grd = resample_2d_linear(x_2, y_2, grd, t, s, y_k, x_k) -# grd = grd[:, y_k, x_k] -# return grd +def upsample_static(grd, x_2, y_2, t, s, y_k, x_k): + grd = resample_2d_linear(x_2, y_2, grd, t, s) + # grd = grd[:, y_k, x_k] + return grd class SRCNN: @@ -705,12 +705,19 @@ def run_evaluate_static(in_file, out_file, ckpt_dir): cld_opd = cld_opd[int(ylen/2):ylen, :] cld_opd_hres = cld_opd.copy() - nn = SRCNN(LEN_Y=2*LEN_Y, LEN_X=2*LEN_X) + nn = SRCNN() + + slc_x = slice(0, (LEN_X - 16) + 4) + slc_y = slice(0, (LEN_Y - 16) + 4) + x_2 = np.arange((LEN_X - 16) + 4) + y_2 = np.arange((LEN_Y - 16) + 4) + t = np.arange(0, (LEN_X - 16) + 4, 0.5) + s = np.arange(0, (LEN_Y - 16) + 4, 0.5) refl = np.where(np.isnan(refl), 0, bt) - refl = refl[nn.slc_y, nn.slc_x] + refl = refl[slc_y, slc_x] refl = np.expand_dims(refl, axis=0) - refl = nn.upsample(refl) + refl = upsample_static(refl, x_2, y_2, t, s) print(refl.shape) refl = normalize(refl, 'refl_0_65um_nom', mean_std_dct) print('REFL done') @@ -718,7 +725,7 @@ def run_evaluate_static(in_file, out_file, ckpt_dir): bt = np.where(np.isnan(bt), 0, bt) bt = bt[nn.slc_y, nn.slc_x] bt = np.expand_dims(bt, axis=0) - bt = nn.upsample(bt) + bt = upsample_static(bt, x_2, y_2, t, s) bt = normalize(bt, 'temp_11_0um_nom', mean_std_dct) print('BT done') @@ -736,7 +743,7 @@ def run_evaluate_static(in_file, out_file, ckpt_dir): cld_opd = np.where(np.isnan(cld_opd), 0, cld_opd) cld_opd = cld_opd[nn.slc_y, nn.slc_x] cld_opd = np.expand_dims(cld_opd, axis=0) - cld_opd = nn.upsample(cld_opd) + cld_opd = upsample_static(cld_opd, x_2, y_2, t, s) cld_opd = normalize(cld_opd, label_param, mean_std_dct) print('OPD done')