diff --git a/modules/deeplearning/esrgan_exp.py b/modules/deeplearning/esrgan_exp.py index 4d0380bf3cc3525756b11b3c8430e06481ffb0d5..7d87331ccb2a2af7afd716b644c6942318a18bfa 100644 --- a/modules/deeplearning/esrgan_exp.py +++ b/modules/deeplearning/esrgan_exp.py @@ -12,7 +12,8 @@ from util.setup import home_dir target_param = 'cld_opd_dcomp' # SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1" -SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1' +# SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1' +SAVED_MODEL_PATH = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan' model = hub.load(SAVED_MODEL_PATH) @@ -27,25 +28,38 @@ def get_image_from_file(in_file): # s_x = slice(1854, 3902) # s_y = slice(1854, 3902) - # bt = get_grid_values_all(h5f, 'temp_11_0um_nom') - # bt = bt[s_y, s_x] + refl = get_grid_values_all(h5f, 'refl_0_65um_nom') + refl = refl[s_y, s_x] cld_opd = get_grid_values_all(h5f, target_param) cld_opd = cld_opd[s_y, s_x] shape = cld_opd.shape - data = cld_opd.flatten() + cld_opd = cld_opd.flatten() lo = 0.0 hi = 160.0 + cld_opd -= lo + cld_opd /= (hi - lo) - data -= lo - data /= (hi - lo) + not_valid = np.isnan(cld_opd) + cld_opd[not_valid] = 0 - not_valid = np.isnan(data) - data[not_valid] = 0 + cld_opd = np.reshape(cld_opd, shape) + # ----------------------------------------------------------- + refl = refl.flatten() - data = np.reshape(data, shape) + lo = -2.0 + hi = 120.0 + refl -= lo + refl /= (hi - lo) + + not_valid = np.isnan(refl) + refl[not_valid] = 0 + + refl = np.reshape(refl, shape) + + data = np.stack([refl, cld_opd], axis=2) return data @@ -86,6 +100,8 @@ def preprocess_image(hr_image): hr_image = tmp_image hr_image *= 255.0 + hr_image *= 255.0 + hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4 hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1]) hr_image = tf.cast(hr_image, tf.float32) @@ -94,10 +110,8 @@ def preprocess_image(hr_image): def run(in_file, out_file): - # model = hub.load(SAVED_MODEL_PATH) t0 = time.time() hr_image = get_image_from_file(in_file) - # hr_image = get_image(in_file) hr_image = preprocess_image(hr_image) t1 = time.time() print('processing time: ', (t1-t0))