# Skip to content
# Snippets Groups Projects
# esrgan_exp.py 2.83 KiB
import os
import time
from PIL import Image
import numpy as np
import tensorflow as tf
#import tensorflow_hub as hub
import matplotlib.pyplot as plt
import h5py
from util.util import descale2, get_grid_values_all
from util.setup import home_dir

# Name of the grid read from the input file as the second image channel
# (see get_image_from_file).
target_param = 'cld_opd_dcomp'

# Location of the ESRGAN SavedModel. Earlier locations are kept for reference:
# SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
# SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1'
SAVED_MODEL_PATH = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan/esrgan'

# The model is loaded once, at import time, so importing this module requires
# SAVED_MODEL_PATH to be readable.
# model = hub.load(SAVED_MODEL_PATH)
model = tf.saved_model.load(SAVED_MODEL_PATH)


def _scale_to_unit(values, lo, hi):
    """Return a copy of *values* mapped linearly so lo -> 0 and hi -> 1, with NaNs set to 0."""
    shape = values.shape
    # flatten() always copies, so the in-place arithmetic below never touches
    # the array that was passed in.
    flat = values.flatten()
    flat -= lo
    flat /= (hi - lo)
    flat[np.isnan(flat)] = 0
    return np.reshape(flat, shape)


def get_image_from_file(in_file):
    """Read a fixed window of reflectance and ``target_param`` from an HDF5 file.

    Both grids are cut to the same window, scaled linearly with fixed
    bounds (reflectance: -2.0 .. 120.0, ``target_param``: 0.0 .. 160.0) and
    have their NaN entries replaced by 0.

    Args:
        in_file: Path of the HDF5 file to read.

    Returns:
        A NumPy array with the two scaled grids stacked along the last axis:
        channel 0 is ``refl_0_65um_nom``, channel 1 is ``target_param``.
    """
    # Window to extract. Other windows that have been tried:
    # s_x = slice(2622, 3134)
    # s_y = slice(2622, 3134)
    # s_x = slice(1854, 3902)
    # s_y = slice(1854, 3902)
    s_x = slice(2110, 3646)
    s_y = slice(2110, 3646)

    # The context manager guarantees the file handle is released even if a
    # read fails; all access to file-backed data stays inside the block.
    with h5py.File(in_file, 'r') as h5f:
        refl = get_grid_values_all(h5f, 'refl_0_65um_nom')[s_y, s_x]
        cld_opd = get_grid_values_all(h5f, target_param)[s_y, s_x]

        refl = _scale_to_unit(refl, -2.0, 120.0)
        cld_opd = _scale_to_unit(cld_opd, 0.0, 160.0)

    return np.stack([refl, cld_opd], axis=2)


def preprocess_image(hr_image):
    """Scale an image by 255 and shape it into a single-image float32 batch.

    The input is left unmodified.

    Args:
        hr_image: Image array whose last axis is the channel axis and whose
            leading two axes are height and width. Values are presumably
            normalized to [0, 1] by the caller -- TODO confirm.

    Returns:
        A ``tf.float32`` tensor with a leading batch axis of size 1, cropped
        from the top-left corner so that height and width are multiples of 4.
    """
    # Multiply out of place: ``hr_image *= 255.0`` would also rescale the
    # caller's NumPy array.
    hr_image = hr_image * 255.0

    # Round height and width down to the nearest multiple of 4.
    hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
    hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
    hr_image = tf.cast(hr_image, tf.float32)

    return tf.expand_dims(hr_image, 0)


def run(in_file, out_file):
    """Run the super-resolution model on one input file.

    Args:
        in_file: Path of the HDF5 file passed to ``get_image_from_file``.
        out_file: Path to write results to, or ``None`` to return them.

    Returns:
        ``None`` when ``out_file`` is given. In that case three arrays are
        written back to back with ``np.save`` into the one file, in the order
        model output, model input (channel 1), original image (channel 1);
        read them back with three successive ``np.load`` calls on the same
        open file object.
        When ``out_file`` is ``None``, the tuple ``(sres_image, hr_image)``.
    """
    t0 = time.time()
    image = get_image_from_file(in_file)
    # Pass a copy so that any in-place scaling inside preprocess_image cannot
    # leak into `image`, which is descaled below on the assumption that it is
    # still in its normalized units.
    hr_image = preprocess_image(image.copy())
    t1 = time.time()
    print('processing time: ', (t1-t0))
    print('start inference: ', hr_image.shape)
    t0 = time.time()
    sres_image = model(hr_image)
    t1 = time.time()
    print('inference time: ', (t1-t0))
    sres_image = tf.squeeze(sres_image)

    # Undo the 255 scaling applied for the model before descaling.
    sres_image /= 255.0
    hr_image /= 255.0
    hr_image = hr_image.numpy()
    # Channel 1 is the target parameter; 0.0 / 160.0 are the bounds used when
    # it was normalized in get_image_from_file.
    hr_image = descale2(hr_image[:, :, :, 1], 0.0, 160.0)
    image = descale2(image[:, :, 1], 0.0, 160.0)

    sres_image = sres_image.numpy()
    sres_image = descale2(sres_image, 0.0, 160.0)

    if out_file is None:
        return sres_image, hr_image

    with open(out_file, 'wb') as f:
        np.save(f, sres_image)
        np.save(f, hr_image)
        np.save(f, image)