Skip to content
Snippets Groups Projects
Select Git revision
  • 94c03213e38a9678f5861fe1d7b7d6ac91d88253
  • master default protected
  • use_flight_altitude
  • distribute
4 results

__init__.py

Blame
  • esrgan_exp.py 2.83 KiB
    import os
    import time
    from PIL import Image
    import numpy as np
    import tensorflow as tf
    #import tensorflow_hub as hub
    import matplotlib.pyplot as plt
    import h5py
    from util.util import descale2, get_grid_values_all
    from util.setup import home_dir
    
    # Name of the product read from each input file; it becomes channel index 1 of
    # the image built by get_image_from_file.
    target_param = 'cld_opd_dcomp'

    # Location of the ESRGAN SavedModel. The two commented-out alternatives are the
    # TF-Hub release and a copy under home_dir; the active value is a site-specific
    # absolute path.
    # SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
    # SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1'
    SAVED_MODEL_PATH = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan/esrgan'

    # NOTE: the model is loaded at import time (module-level side effect), so
    # importing this module fails if SAVED_MODEL_PATH is not reachable.
    # model = hub.load(SAVED_MODEL_PATH)
    model = tf.saved_model.load(SAVED_MODEL_PATH)
    
    
    def _scale_to_unit(values, lo, hi):
        """Return a new array with *values* mapped linearly from [lo, hi] to [0, 1].

        NaNs are replaced with 0 after scaling. Values outside [lo, hi] are not
        clipped. The input array is left untouched.
        """
        scaled = (values - lo) / (hi - lo)  # out of place: allocates a fresh array
        scaled[np.isnan(scaled)] = 0
        return scaled


    def get_image_from_file(in_file):
        """Read a two-channel, unit-scaled image from an HDF5 file.

        Channel 0 is 'refl_0_65um_nom' scaled with lo=-2.0, hi=120.0; channel 1 is
        the module-level `target_param` scaled with lo=0.0, hi=160.0. NaNs in either
        channel become 0.

        Args:
            in_file: path (or anything h5py.File accepts) of the input file.

        Returns:
            Array of shape (rows, cols, 2) covering the fixed window below.
        """
        # Fixed window of up to 1536 x 1536 pixels. Other windows that have been
        # used with this script: slice(2622, 3134) and slice(1854, 3902).
        s_x = slice(2110, 3646)
        s_y = slice(2110, 3646)

        # The context manager guarantees the file is closed even if a read fails;
        # everything derived from the file is materialized before it closes.
        with h5py.File(in_file, 'r') as h5f:
            refl = get_grid_values_all(h5f, 'refl_0_65um_nom')[s_y, s_x]
            cld_opd = get_grid_values_all(h5f, target_param)[s_y, s_x]

            refl = _scale_to_unit(refl, -2.0, 120.0)
            cld_opd = _scale_to_unit(cld_opd, 0.0, 160.0)

        return np.stack([refl, cld_opd], axis=2)
    
    
    def preprocess_image(hr_image):
        """Scale an image to the 0-255 range and shape it for the model.

        Args:
            hr_image: image array with the channel axis last; presumably
                (height, width, channels) with values in [0, 1], as built by
                get_image_from_file. It is not modified.

        Returns:
            A float32 tensor with a leading batch axis of size 1, cropped from the
            top-left corner so that both spatial dimensions are multiples of 4.
        """
        # Multiply out of place: an in-place `*=` would also rescale the caller's
        # array, and callers keep using that array after this call.
        hr_image = hr_image * 255.0

        # Largest multiple-of-4 extent that fits in each leading dimension.
        hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
        hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
        hr_image = tf.cast(hr_image, tf.float32)

        return tf.expand_dims(hr_image, 0)
    
    
    def run(in_file, out_file):
        """Run the super-resolution model on one file and save or return the result.

        Args:
            in_file: input file passed to get_image_from_file.
            out_file: destination path, or None. When given, three arrays are
                written back-to-back with np.save in this order: the model output,
                the preprocessed model input (channel 1), and the original image
                (channel 1).

        Returns:
            None when out_file is given; otherwise the tuple
            (sres_image, hr_image) as NumPy arrays.
        """
        t0 = time.time()
        image = get_image_from_file(in_file)
        # Hand preprocessing its own copy so `image` is guaranteed to still be in
        # normalized units when it is descaled below, whether or not preprocessing
        # rescales its argument in place.
        hr_image = preprocess_image(image.copy())
        t1 = time.time()
        print('processing time: ', (t1-t0))
        print('start inference: ', hr_image.shape)

        t0 = time.time()
        sres_image = model(hr_image)
        t1 = time.time()
        print('inference time: ', (t1-t0))

        # Undo the 0-255 scaling applied for the model, then convert back to the
        # target parameter's range (descale2 presumably maps [0, 1] to [lo, hi]).
        # NOTE(review): the whole model output is descaled with the 0-160 range;
        # confirm that is right for every output channel.
        sres_image = tf.squeeze(sres_image) / 255.0
        sres_image = descale2(sres_image.numpy(), 0.0, 160.0)

        # Channel 1 is the target parameter; the batch axis is kept on hr_image.
        hr_image = (hr_image / 255.0).numpy()
        hr_image = descale2(hr_image[:, :, :, 1], 0.0, 160.0)
        image = descale2(image[:, :, 1], 0.0, 160.0)

        if out_file is None:
            return sres_image, hr_image

        with open(out_file, 'wb') as f:
            np.save(f, sres_image)
            np.save(f, hr_image)
            np.save(f, image)