Skip to content
Snippets Groups Projects
Commit b1191322 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 95048aa8
No related branches found
No related tags found
No related merge requests found
import os
import time
from PIL import Image
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import h5py
from util.util import scale, descale, get_grid_values_all
from util.setup import home_dir
# SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
# SAVED_MODEL_PATH = '/Users/tomrink/Downloads/esrgan-tf2_1'
SAVED_MODEL_PATH = home_dir+'esrgan-tf2_1'
model = hub.load(SAVED_MODEL_PATH)
def get_image(in_file):
h5f = h5py.File(in_file, 'r')
# s_x = slice(2622, 3134)
# s_y = slice(2622, 3134)
s_x = slice(1854, 3902)
s_y = slice(1854, 3902)
# bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
# bt = bt[s_y, s_x]
cld_opd = get_grid_values_all(h5f, 'cld_opd_dcomp')
cld_opd = cld_opd[s_y, s_x]
shape = cld_opd.shape
data = cld_opd.flatten()
lo = 0.0
hi = 160.0
data -= lo
data /= (hi - lo)
not_valid = np.isnan(data)
data[not_valid] = 0
data = np.reshape(data, shape)
return data
def preprocess_image(hr_image):
""" Loads image from path and preprocesses to make it model ready
Args:
image_path: Path to the image file
"""
# hr_image = tf.image.decode_image(tf.io.read_file(image_path))
# If PNG, remove the alpha channel. The model only supports
# images with 3 color channels.
if hr_image.shape[-1] == 4:
hr_image = hr_image[..., :-1]
elif len(hr_image.shape) == 2:
tmp_image = np.zeros([hr_image.shape[0], hr_image.shape[1], 3])
tmp_image[:, :, 0] = hr_image
tmp_image[:, :, 1] = hr_image
tmp_image[:, :, 2] = hr_image
hr_image = tmp_image
hr_image *= 255.0
hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
hr_image = tf.cast(hr_image, tf.float32)
return tf.expand_dims(hr_image, 0)
def run(in_file):
# model = hub.load(SAVED_MODEL_PATH)
hr_image = get_image(in_file)
hr_image = preprocess_image(hr_image)
t0 = time.time()
fake_image = model(hr_image)
t1 = time.time()
print('inference time: ', (t1-t0))
fake_image = tf.squeeze(fake_image)
return fake_image
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment