Skip to content
Snippets Groups Projects
Select Git revision
  • efa258f4b174629d00d494f0b998fc14fa4d622b
  • master default protected
  • use_flight_altitude
  • distribute
4 results

cnn_l1b_l2_16.py

Blame
  • evaluate_psnr.py 4.24 KiB
    """ Evaluates the SavedModel of ESRGAN """
    import argparse
    import functools
    import itertools
    import os
    import sys

    from absl import logging
    import tensorflow.compat.v2 as tf
    import tensorflow_hub as hub
    tf.enable_v2_behavior()
    
    def build_dataset(
        lr_glob,
        hr_glob,
        lr_crop_size=(128, 128),
        scale=4,
        stride=40):
      """
        Builds a tf.data.Dataset of aligned (LR, HR) sub-image pairs.

        The files matched by both patterns are sorted and paired in order, so
        the n-th low resolution file must correspond to the n-th high
        resolution file. Each low resolution image is cut into `lr_crop_size`
        windows whose top-left corners lie on a grid with spacing `stride`;
        the matching high resolution window is the same region scaled by
        `scale`.

        Args:
          lr_glob: Pattern to match Low Resolution images.
          hr_glob: Pattern to match High resolution images.
          lr_crop_size: (height, width) of the Low Resolution crops.
          scale: Scaling factor between the Low and High resolution images.
          stride: Distance in Low Resolution pixels between neighbouring
            crops (default: 40).
        Returns:
          A tf.data.Dataset yielding (lr_crop, hr_crop) pairs of float32
          tensors with 3 channels each.
        Raises:
          ValueError: If the two patterns match a different number of files,
            in which case the sorted pairing cannot be aligned.
      """
      crop_height, crop_width = lr_crop_size

      def _read_images_fn(low_res_images, high_res_images):
        for lr_image_path, hr_image_path in zip(low_res_images, high_res_images):
          # channels=3 converts grayscale / RGBA inputs to RGB so every crop
          # matches the [None, None, 3] shape declared for the dataset below.
          lr_image = tf.image.decode_image(
              tf.io.read_file(lr_image_path), channels=3)
          hr_image = tf.image.decode_image(
              tf.io.read_file(hr_image_path), channels=3)
          for height in range(0, lr_image.shape[0] - crop_height + 1, stride):
            for width in range(0, lr_image.shape[1] - crop_width + 1, stride):
              lr_sub_image = tf.image.crop_to_bounding_box(
                  lr_image,
                  height, width,
                  crop_height, crop_width)
              # The HR crop covers the same region, scaled by `scale`.
              hr_sub_image = tf.image.crop_to_bounding_box(
                  hr_image,
                  height * scale, width * scale,
                  crop_height * scale, crop_width * scale)
              yield (tf.cast(lr_sub_image, tf.float32),
                     tf.cast(hr_sub_image, tf.float32))

      hr_images = sorted(tf.io.gfile.glob(hr_glob))
      lr_images = sorted(tf.io.gfile.glob(lr_glob))
      # Pairing is purely positional; unequal counts would silently pair
      # unrelated images (or drop the surplus), corrupting the PSNR.
      if len(lr_images) != len(hr_images):
        raise ValueError(
            "Found %d low resolution and %d high resolution files; "
            "the patterns must match the same images."
            % (len(lr_images), len(hr_images)))
      dataset = tf.data.Dataset.from_generator(
          functools.partial(_read_images_fn, lr_images, hr_images),
          (tf.float32, tf.float32),
          (tf.TensorShape([None, None, 3]), tf.TensorShape([None, None, 3])))
      return dataset
    
    
    def main(**kwargs):
      """
        Evaluates a super resolution SavedModel by its mean PSNR.

        Sub-image pairs from build_dataset are run through the model in
        batches until the dataset is exhausted or at least `total` images
        have been processed. Progress is printed each time another 1000
        images have been processed; the final mean PSNR is printed and
        written to PSNR_result.txt.

        Args:
          **kwargs: Parsed command line flags. Reads the keys "total",
            "batch_size", "lr_files", "hr_files" and "model".
      """
      # int() guards against the flag arriving as a string.
      total = int(kwargs["total"])
      batch_size = kwargs["batch_size"]
      dataset = build_dataset(kwargs["lr_files"], kwargs["hr_files"])
      dataset = dataset.batch(batch_size)
      os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
      metrics = tf.keras.metrics.Mean()
      processed = 0
      for lr_image, hr_image in dataset:
        # Loading the model multiple times is the only
        # way to preserve the quality, since the quality
        # degrades at every inference for no obvious reasons.
        model = hub.load(kwargs["model"])
        super_res_image = model.call(lr_image)
        super_res_image = tf.clip_by_value(super_res_image, 0, 255)
        # tf.image.psnr returns one value per image. Feeding the whole vector
        # to the metric weights every image equally, even when the last batch
        # is smaller. Pixels are clipped to [0, 255], so the range is 255.
        metrics.update_state(
            tf.image.psnr(super_res_image, hr_image, max_val=255))
        previously_processed = processed
        # Count the real batch size: the final batch may be smaller.
        processed += int(lr_image.shape[0])
        if processed >= total:
          break
        if processed // 1000 != previously_processed // 1000:
          print(
              "%d Images Processed. Mean PSNR yet: %f" %
              (processed, metrics.result().numpy()))
      # `processed` is the number of images actually evaluated: 0 for an
      # empty dataset, and it may exceed `total` by less than one batch.
      result = ("%d Images processed. Mean PSNR: %f" %
                (processed, metrics.result().numpy()))
      print(result)
      with tf.io.gfile.GFile("PSNR_result.txt", "w") as f:
        f.write(result)
    
    
    if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      # type=int so a value given on the command line can be compared with
      # the integer image counter in main().
      parser.add_argument(
          "--total",
          default=10000,
          type=int,
          help="Total number of sub images to work on. (default: 10000)")
      parser.add_argument(
          "--batch_size",
          default=16,
          type=int,
          help="Number of images per batch (default: 16)")
      parser.add_argument(
          "--lr_files",
          default=None,
          help="Pattern to match low resolution files")
      parser.add_argument(
          "--hr_files",
          default=None,
          help="Pattern to match High resolution images")
      parser.add_argument(
          "--model",
          default="https://github.com/captain-pool/GSOC/"
                  "releases/download/1.0.0/esrgan.tar.gz",
          help="URL or Path to the SavedModel")
      parser.add_argument(
          "--verbose", "-v",
          default=0,
          action="count",
          help="Increase Verbosity of logging")

      # Unrecognized command line arguments are accepted and ignored.
      flags, _ = parser.parse_known_args()
      if not (flags.lr_files and flags.hr_files):
        logging.error("Must set flag --lr_files and --hr_files")
        sys.exit(1)
      # No -v logs only FATAL; each extra -v lowers the threshold by one
      # level, capped at DEBUG.
      log_levels = [logging.FATAL, logging.WARN, logging.INFO, logging.DEBUG]
      log_level = log_levels[min(flags.verbose, len(log_levels) - 1)]
      logging.set_verbosity(log_level)
      main(**vars(flags))