esrgan_exp.py 2.73 KiB
import os
import time
from PIL import Image
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import h5py
from util.util import scale, descale, get_grid_values_all
from util.setup import home_dir
# Name of the gridded variable that get_image_from_file() reads from the HDF5 input.
target_param = 'cld_opd_dcomp'
# Remote TF-Hub handle, kept for reference only; the local path below is what is loaded
# (presumably a downloaded copy of the same model -- confirm).
# SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1'
# Loaded once at import time and used by run(); importing this module therefore
# attempts to load the model from SAVED_MODEL_PATH.
model = hub.load(SAVED_MODEL_PATH)
def get_image_from_file(in_file, s_x=slice(2110, 3646), s_y=slice(2110, 3646)):
    """Read a window of ``target_param`` from an HDF5 file and normalize it.

    Args:
        in_file: path of the HDF5 file to open (read-only).
        s_x: slice applied to the second index of the grid. The default is
            the original 1536-pixel window; windows used in earlier
            experiments were slice(2622, 3134) and slice(1854, 3902).
        s_y: slice applied to the first index of the grid (same default).

    Returns:
        The window normalized by ``get_image`` (scaled by 1/160 with NaNs
        replaced by 0).
    """
    # The context manager guarantees the HDF5 handle is closed; the previous
    # version never closed it. The slice is taken inside the ``with`` in case
    # get_grid_values_all returns a lazily-read dataset -- TODO confirm.
    with h5py.File(in_file, 'r') as h5f:
        cld_opd = get_grid_values_all(h5f, target_param)[s_y, s_x]
    # Reuse the shared normalization instead of duplicating it here.
    return get_image(cld_opd)
def get_image(data, lo=0.0, hi=160.0):
    """Linearly rescale an array so that ``lo`` -> 0 and ``hi`` -> 1; NaN -> 0.

    Values outside ``[lo, hi]`` are not clipped.

    Args:
        data: array-like of numbers. It is copied and never modified.
        lo: value mapped to 0.0 (default 0.0, the original hard-coded range).
        hi: value mapped to 1.0 (default 160.0).

    Returns:
        A new ndarray with the same shape as ``data``. Floating/complex inputs
        keep their dtype; integer and bool inputs are promoted to float64
        (the in-place division used to raise for them).

    Raises:
        ValueError: if ``hi == lo`` (the scale would divide by zero).
    """
    if hi == lo:
        raise ValueError('hi and lo must differ, got lo=hi=%r' % (lo,))
    # Copy so the caller's array is untouched (flatten() used to provide the copy).
    arr = np.array(data, copy=True)
    if not np.issubdtype(arr.dtype, np.inexact):
        arr = arr.astype(np.float64)
    # In-place ops keep the input's float precision (e.g. float32 stays float32).
    arr -= lo
    arr /= (hi - lo)
    arr[np.isnan(arr)] = 0
    return arr
def preprocess_image(hr_image):
    """Turn one image array into a float32 batch tensor for the model.

    Args:
        hr_image: 2-D array (single channel, replicated to 3 channels), or a
            3-D array whose last axis is the channel axis; a 4-channel input
            has its last (alpha) channel dropped. Values are multiplied by
            255, so inputs are presumably in [0, 1] -- confirm for callers
            other than get_image_from_file(). The input is not modified.

    Returns:
        float32 tf.Tensor of shape (1, H', W', C), where H' and W' are the
        height and width rounded down to a multiple of 4 (crop anchored at
        the top-left corner).
    """
    # Check the rank first: a 2-D image that happens to be 4 pixels wide must
    # not be mistaken for an RGBA image and lose its last column.
    if len(hr_image.shape) == 2:
        # The model only supports 3 colour channels: replicate the single
        # channel (as float64, like the zero-filled buffer used previously).
        hr_image = np.repeat(
            np.asarray(hr_image, dtype=np.float64)[..., np.newaxis], 3, axis=-1)
    elif hr_image.shape[-1] == 4:
        # e.g. a PNG with alpha: drop the alpha channel.
        hr_image = hr_image[..., :-1]
    # Out-of-place multiply: the RGBA slice above is a view of the caller's
    # array, and an in-place *= would also raise for integer (e.g. uint8) input.
    hr_image = hr_image * 255.0
    # Crop height/width to multiples of 4 (presumably because the model
    # upsamples by 4x -- confirm).
    hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
    hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
    hr_image = tf.cast(hr_image, tf.float32)
    return tf.expand_dims(hr_image, 0)
def run(in_file, out_file):
    """Super-resolve the ``target_param`` field of one HDF5 file.

    Prints the preprocessing and inference durations and the input shape.

    Args:
        in_file: path of the HDF5 input, read via get_image_from_file().
        out_file: destination for np.save(), or None to get the tensors back.

    Returns:
        None when ``out_file`` is given (the squeezed model output is saved
        there); otherwise a tuple ``(fake_image, hr_image)`` holding the
        squeezed model output and the batched input tensor.
    """
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump, giving wrong or negative durations.
    t0 = time.perf_counter()
    hr_image = preprocess_image(get_image_from_file(in_file))
    t1 = time.perf_counter()
    print('processing time: ', (t1-t0))
    print('start inference: ', hr_image.shape)

    t0 = time.perf_counter()
    fake_image = model(hr_image)
    t1 = time.perf_counter()
    print('inference time: ', (t1-t0))

    # Remove all size-1 axes, including (presumably) the batch axis added
    # by preprocess_image().
    fake_image = tf.squeeze(fake_image)
    if out_file is not None:
        np.save(out_file, np.asarray(fake_image))
        return None
    return fake_image, hr_image