Skip to content
Snippets Groups Projects
Select Git revision
  • e7ebb3bfadae5c3c27d4436fd822b2f8fba548e9
  • master default protected
  • use_flight_altitude
  • distribute
4 results

__init__.py

Blame
  • esrgan_exp.py 2.73 KiB
    import os
    import time
    from PIL import Image
    import numpy as np
    import tensorflow as tf
    import tensorflow_hub as hub
    import matplotlib.pyplot as plt
    import h5py
    from util.util import scale, descale, get_grid_values_all
    from util.setup import home_dir
    
    # Name of the HDF5 variable that get_image_from_file() reads and crops.
    target_param = 'cld_opd_dcomp'
    
    # Path to the saved ESRGAN model. Presumably a local copy of the TF-Hub
    # model "https://tfhub.dev/captain-pool/esrgan-tf2/1" (the remote source
    # that was used before) -- TODO confirm.
    SAVED_MODEL_PATH = home_dir + '/esrgan-tf2_1'
    
    # NOTE: the model is loaded once, at import time (importing this module has
    # that side effect); run() uses this module-level instance for inference.
    model = hub.load(SAVED_MODEL_PATH)
    
    
    def get_image_from_file(in_file, s_x=slice(2110, 3646), s_y=slice(2110, 3646)):
        """Read, crop and normalize the ``target_param`` field from an HDF5 file.

        The grid is read with ``get_grid_values_all`` and cropped to
        ``[s_y, s_x]``. The default window is 1536 x 1536 samples. The crop is
        then normalized by ``get_image``: values are mapped from 0..160 to
        0..1, and NaNs become 0.

        Args:
            in_file: path to the HDF5 input file.
            s_x: column (second-axis) slice of the crop window.
            s_y: row (first-axis) slice of the crop window.
                Other windows used previously: ``slice(2622, 3134)`` (512 px)
                and ``slice(1854, 3902)`` (2048 px).

        Returns:
            ndarray of the cropped, normalized values (2-D if the stored grid
            is 2-D).
        """
        # Context manager so the HDF5 handle is always released, even if
        # reading fails.
        with h5py.File(in_file, 'r') as h5f:
            cld_opd = get_grid_values_all(h5f, target_param)
            # Crop while the file is still open, in case the helper returns a
            # lazily-read h5py dataset instead of an in-memory array.
            cld_opd = cld_opd[s_y, s_x]

        return get_image(cld_opd)
    
    
    def get_image(data, lo=0.0, hi=160.0):
        """Linearly rescale ``data`` so ``lo`` maps to 0 and ``hi`` to 1; NaN -> 0.

        Computes ``(data - lo) / (hi - lo)`` element-wise and replaces NaN
        results with 0. Values outside ``[lo, hi]`` are NOT clipped.

        Args:
            data: array-like of numeric values, any shape. Not modified.
            lo: value mapped to 0.0 (default 0.0).
            hi: value mapped to 1.0 (default 160.0).

        Returns:
            A new floating-point ndarray with the same shape as ``data``.
            Float input keeps its dtype; integer input is promoted to float.

        Raises:
            ValueError: if ``hi == lo`` (the scale would divide by zero).
        """
        if hi == lo:
            raise ValueError('hi and lo must differ, got %r for both' % (hi,))

        data = np.asarray(data)
        # Out-of-place arithmetic: never touches the caller's array, and
        # promotes integer input to float instead of failing on in-place `/=`.
        scaled = np.asarray((data - lo) / (hi - lo))
        scaled[np.isnan(scaled)] = 0
        return scaled
    
    
    def preprocess_image(hr_image):
        """Turn an image array into a batched float32 tensor for the ESRGAN model.

        - 2-D (single-channel) input is replicated into 3 identical channels
          and multiplied by 255. This assumes the input is normalized to
          [0, 1], as produced by ``get_image`` -- TODO confirm for other callers.
        - Input whose last axis has length 4 has its last (alpha) channel
          dropped, because the model only supports 3 color channels.
        - Height and width are cropped down to multiples of 4, anchored at the
          top-left corner. This mirrors the TF-Hub ESRGAN example and is
          presumably a model requirement.

        Args:
            hr_image: numpy array or tensor of shape (H, W), (H, W, 3) or
                (H, W, 4).

        Returns:
            float32 tensor of shape (1, H', W', C), where H' and W' are H and W
            rounded down to multiples of 4.
        """
        # Check rank first: a 2-D image that happens to be 4 pixels wide must
        # not be mistaken for an RGBA image.
        if len(hr_image.shape) == 2:
            gray = np.asarray(hr_image, dtype=np.float64)
            hr_image = np.repeat(gray[..., np.newaxis], 3, axis=-1) * 255.0
        elif hr_image.shape[-1] == 4:
            hr_image = hr_image[..., :-1]

        hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
        hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
        hr_image = tf.cast(hr_image, tf.float32)

        # Add the leading batch axis expected by the model.
        return tf.expand_dims(hr_image, 0)
    
    
    def run(in_file, out_file=None):
        """Super-resolve the cropped ``target_param`` field of *in_file* with ESRGAN.

        Prints the preprocessing time, the model input shape and the inference
        time.

        Args:
            in_file: path to the HDF5 input file, read by
                ``get_image_from_file``.
            out_file: if not None, the model output is written to this path
                with ``np.save`` and nothing is returned (default None).

        Returns:
            None when *out_file* is given. Otherwise a tuple
            ``(fake_image, hr_image)``: the model output with size-1 axes
            squeezed out, and the preprocessed (batched) model input.
        """
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # so it is the right clock for measuring intervals.
        t0 = time.perf_counter()
        hr_image = preprocess_image(get_image_from_file(in_file))
        t1 = time.perf_counter()
        print('processing time: ', (t1-t0))
        print('start inference: ', hr_image.shape)

        t0 = time.perf_counter()
        fake_image = model(hr_image)
        t1 = time.perf_counter()
        print('inference time: ', (t1-t0))

        # Remove size-1 axes, e.g. the batch axis added by preprocess_image.
        fake_image = tf.squeeze(fake_image)

        if out_file is not None:
            np.save(out_file, fake_image)
            return None
        return fake_image, hr_image