"""Run a saved ESRGAN super-resolution model on GOES cloud-product imagery.

Reads reflectance and a target cloud parameter from an HDF5 file, normalizes
both to [0, 1], scales them to the 0-255 range and feeds them to the model
loaded from ``SAVED_MODEL_PATH``.
"""
import os
import time

from PIL import Image
import numpy as np
import tensorflow as tf
# import tensorflow_hub as hub
import matplotlib.pyplot as plt
import h5py

from util.util import scale, descale, get_grid_values_all
from util.setup import home_dir

target_param = 'cld_opd_dcomp'

# Alternative model sources:
# SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
# SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1'
SAVED_MODEL_PATH = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan/esrgan'

# model = hub.load(SAVED_MODEL_PATH)
model = tf.saved_model.load(SAVED_MODEL_PATH)

# Sub-region of the full grid that is extracted from every input file.
# Other windows that have been used: slice(2622, 3134), slice(1854, 3902).
_X_SLICE = slice(2110, 3646)
_Y_SLICE = slice(2110, 3646)

# (lo, hi) bounds used to map each variable linearly onto [0, 1].
_CLD_OPD_RANGE = (0.0, 160.0)
_REFL_RANGE = (-2.0, 120.0)


def _normalize(data, lo, hi):
    """Return a copy of *data* mapped linearly from [lo, hi] to [0, 1].

    NaN entries are replaced by 0. The input array is not modified.
    """
    scaled = (data - lo) / (hi - lo)
    scaled[np.isnan(scaled)] = 0
    return scaled


def get_image_from_file(in_file):
    """Read reflectance and the target parameter from an HDF5 file.

    Args:
        in_file: Path of the HDF5 file to read.

    Returns:
        Array of shape (rows, cols, 2) holding the normalized
        'refl_0_65um_nom' field in channel 0 and the normalized
        ``target_param`` field in channel 1, both restricted to the
        module-level sub-region.
    """
    # The context manager guarantees the file handle is released even if a
    # read fails; slicing happens inside it so the data is in memory before
    # the file is closed.
    with h5py.File(in_file, 'r') as h5f:
        refl = get_grid_values_all(h5f, 'refl_0_65um_nom')[_Y_SLICE, _X_SLICE]
        cld_opd = get_grid_values_all(h5f, target_param)[_Y_SLICE, _X_SLICE]

    cld_opd = _normalize(cld_opd, *_CLD_OPD_RANGE)
    refl = _normalize(refl, *_REFL_RANGE)

    return np.stack([refl, cld_opd], axis=2)


def get_image(data, lo=0.0, hi=160.0):
    """Normalize *data* from [lo, hi] to [0, 1], replacing NaN with 0.

    Args:
        data: NumPy array of values to normalize; it is left unmodified.
        lo: Value mapped to 0 (defaults to the cloud optical depth minimum).
        hi: Value mapped to 1 (defaults to the cloud optical depth maximum).

    Returns:
        A new array with the same shape as *data*.
    """
    return _normalize(data, lo, hi)


def preprocess_image(hr_image):
    """Convert a normalized image array into a model-ready tensor.

    A trailing alpha channel is dropped from 4-channel input and 2-D input is
    replicated into 3 identical channels. The values are then scaled by 255,
    the spatial dimensions are cropped to a multiple of 4, and the result is
    cast to float32 with a leading batch dimension added.

    Args:
        hr_image: Image array of shape (rows, cols) or (rows, cols, channels).
            It is not modified.

    Returns:
        A float32 tensor of shape (1, rows', cols', channels).
    """
    if hr_image.shape[-1] == 4:
        # Drop the alpha channel.
        hr_image = hr_image[..., :-1]
    elif len(hr_image.shape) == 2:
        # Gray-scale input: replicate the single plane into three channels.
        hr_image = np.stack([hr_image] * 3, axis=-1).astype(np.float64)

    # Scale exactly once, out of place. The previous code applied the in-place
    # ``*= 255.0`` twice (scaling by 255**2) and overwrote the caller's array.
    # NOTE(review): presumably the model expects inputs in [0, 255] - confirm.
    hr_image = hr_image * 255.0

    hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
    hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
    hr_image = tf.cast(hr_image, tf.float32)
    return tf.expand_dims(hr_image, 0)


def run(in_file, out_file):
    """Run the model on one input file and report timings on stdout.

    Args:
        in_file: Path of the HDF5 file to process.
        out_file: Destination path, or None. When given, three arrays are
            written back to back with ``np.save``: the model output, the
            model input tensor and the normalized input image.

    Returns:
        ``(sres_image, hr_image)`` when *out_file* is None, otherwise None.
    """
    t0 = time.time()
    image = get_image_from_file(in_file)
    hr_image = preprocess_image(image)
    t1 = time.time()
    print('processing time: ', (t1-t0))

    print('start inference: ', hr_image.shape)
    t0 = time.time()
    sres_image = model(hr_image)
    t1 = time.time()
    print('inference time: ', (t1-t0))

    sres_image = tf.squeeze(sres_image)

    if out_file is None:
        return sres_image, hr_image

    with open(out_file, 'wb') as f:
        np.save(f, sres_image)
        np.save(f, hr_image)
        # `image` is still the normalized [0, 1] data because
        # preprocess_image no longer scales its argument in place.
        np.save(f, image)
    return None