diff --git a/modules/deeplearning/esrgan_exp.py b/modules/deeplearning/esrgan_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3d3ac2146511f1bad4ce0000c0e6f9af077151 --- /dev/null +++ b/modules/deeplearning/esrgan_exp.py @@ -0,0 +1,85 @@ +import os +import time +from PIL import Image +import numpy as np +import tensorflow as tf +import tensorflow_hub as hub +import matplotlib.pyplot as plt +import h5py +from util.util import scale, descale, get_grid_values_all +from util.setup import home_dir + +# SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1" +# SAVED_MODEL_PATH = '/Users/tomrink/Downloads/esrgan-tf2_1' +SAVED_MODEL_PATH = home_dir+'esrgan-tf2_1' + +model = hub.load(SAVED_MODEL_PATH) + + +def get_image(in_file): + h5f = h5py.File(in_file, 'r') + + # s_x = slice(2622, 3134) + # s_y = slice(2622, 3134) + s_x = slice(1854, 3902) + s_y = slice(1854, 3902) + + # bt = get_grid_values_all(h5f, 'temp_11_0um_nom') + # bt = bt[s_y, s_x] + + cld_opd = get_grid_values_all(h5f, 'cld_opd_dcomp') + cld_opd = cld_opd[s_y, s_x] + + shape = cld_opd.shape + data = cld_opd.flatten() + + lo = 0.0 + hi = 160.0 + + data -= lo + data /= (hi - lo) + + not_valid = np.isnan(data) + data[not_valid] = 0 + + data = np.reshape(data, shape) + + return data + + +def preprocess_image(hr_image): + """ Loads image from path and preprocesses to make it model ready + Args: + image_path: Path to the image file + """ + # hr_image = tf.image.decode_image(tf.io.read_file(image_path)) + # If PNG, remove the alpha channel. The model only supports + # images with 3 color channels. + if hr_image.shape[-1] == 4: + hr_image = hr_image[..., :-1] + elif len(hr_image.shape) == 2: + tmp_image = np.zeros([hr_image.shape[0], hr_image.shape[1], 3]) + tmp_image[:, :, 0] = hr_image + tmp_image[:, :, 1] = hr_image + tmp_image[:, :, 2] = hr_image + hr_image = tmp_image + hr_image *= 255.0 + + hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4 + hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1]) + hr_image = tf.cast(hr_image, tf.float32) + + return tf.expand_dims(hr_image, 0) + + +def run(in_file): + # model = hub.load(SAVED_MODEL_PATH) + hr_image = get_image(in_file) + hr_image = preprocess_image(hr_image) + t0 = time.time() + fake_image = model(hr_image) + t1 = time.time() + print('inference time: ', (t1-t0)) + fake_image = tf.squeeze(fake_image) + + return fake_image \ No newline at end of file