""" Evaluates the SavedModel of ESRGAN """
import argparse
import functools
import itertools
import os
import sys
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
# Opt in to TF2 semantics through the compat module; main() iterates the
# dataset with a plain Python for-loop, which requires eager execution.
tf.enable_v2_behavior()

def build_dataset(
    lr_glob,
    hr_glob,
    lr_crop_size=(128, 128),
    scale=4,
    stride=40):
  """
    Builds a tf.data.Dataset of aligned (low res, high res) sub-image pairs.

    Both globs are expanded and sorted, and files are paired by position:
    the i-th low resolution file is matched with the i-th high resolution
    file. Each pair is tiled into (possibly overlapping) crops.
    Args:
      lr_glob: Pattern to match Low Resolution images.
      hr_glob: Pattern to match High resolution images.
      lr_crop_size: (height, width) of the Low Resolution crops. The High
                    Resolution crops are `scale` times larger.
      scale: Scaling factor between the low and high resolution images.
      stride: Step, in low resolution pixels, between neighbouring crops.
    Returns:
      A tf.data.Dataset yielding (lr_crop, hr_crop) float32 tensors of
      shape [height, width, 3].
  """
  crop_height, crop_width = lr_crop_size

  def _decode(image_path):
    # Force 3 channels and a single frame so every image matches the
    # [None, None, 3] signature declared for the dataset below; grayscale,
    # RGBA or animated inputs would otherwise fail that shape check.
    return tf.image.decode_image(
        tf.io.read_file(image_path), channels=3, expand_animations=False)

  def _read_images_fn(low_res_images, high_res_images):
    for lr_image_path, hr_image_path in zip(low_res_images, high_res_images):
      lr_image = _decode(lr_image_path)
      hr_image = _decode(hr_image_path)
      for height in range(0, lr_image.shape[0] - crop_height + 1, stride):
        for width in range(0, lr_image.shape[1] - crop_width + 1, stride):
          lr_sub_image = tf.image.crop_to_bounding_box(
              lr_image,
              height, width,
              crop_height, crop_width)
          # The matching high resolution crop sits at the scaled offset.
          hr_sub_image = tf.image.crop_to_bounding_box(
              hr_image,
              height * scale, width * scale,
              crop_height * scale, crop_width * scale)
          yield (tf.cast(lr_sub_image, tf.float32),
                 tf.cast(hr_sub_image, tf.float32))

  hr_images = sorted(tf.io.gfile.glob(hr_glob))
  lr_images = sorted(tf.io.gfile.glob(lr_glob))
  if not lr_images or not hr_images:
    logging.warning(
        "No images matched (low res: %d, high res: %d); dataset is empty.",
        len(lr_images), len(hr_images))
  elif len(lr_images) != len(hr_images):
    # zip() below silently drops the unmatched tail, so make it visible.
    logging.warning(
        "Found %d low resolution and %d high resolution images; "
        "only the first %d pairs will be used.",
        len(lr_images), len(hr_images),
        min(len(lr_images), len(hr_images)))
  dataset = tf.data.Dataset.from_generator(
      functools.partial(_read_images_fn, lr_images, hr_images),
      (tf.float32, tf.float32),
      (tf.TensorShape([None, None, 3]), tf.TensorShape([None, None, 3])))
  return dataset


def main(**kwargs):
  """
    Evaluates the SavedModel and reports its mean PSNR over the sub images.

    Expected keyword arguments (as produced by the command line parser):
      total: Number of sub images to evaluate before stopping.
      batch_size: Number of sub images per inference batch.
      lr_files: Glob pattern for the Low Resolution images.
      hr_files: Glob pattern for the High Resolution images.
      model: URL or path of the SavedModel to evaluate.
    The final result is printed and written to "PSNR_result.txt".
  """
  # Accept both int and numeric-string values for robustness.
  total = int(kwargs["total"])
  dataset = build_dataset(kwargs["lr_files"], kwargs["hr_files"])
  dataset = dataset.batch(kwargs["batch_size"])
  os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
  metrics = tf.keras.metrics.Mean()
  # Number of sub images actually evaluated. The last batch may be smaller
  # than batch_size, so the real batch sizes are summed.
  processed = 0
  for lr_image, hr_image in dataset:
    # Loading the model multiple times is the only
    # way to preserve the quality, since the quality
    # degrades at every inference for no obvious reasons
    model = hub.load(kwargs["model"])
    super_res_image = model.call(lr_image)
    super_res_image = tf.clip_by_value(super_res_image, 0, 255)
    # Pixel values lie in [0, 255], so the dynamic range is 255. The
    # per-image PSNR values are fed to the Mean metric directly so that
    # every image carries equal weight, regardless of its batch size.
    metrics(
        tf.image.psnr(
            super_res_image,
            hr_image,
            max_val=255))
    previous = processed
    processed += int(lr_image.shape[0])
    if processed >= total:
      break
    # Report progress once each time another 1000 images have been done.
    if processed // 1000 > previous // 1000:
      print(
          "%d Images Processed. Mean PSNR yet: %f" %
          (processed, metrics.result().numpy()))
  # Report the number of images actually evaluated, which may be less
  # than `total` if the dataset ran out first.
  summary = "%d Images processed. Mean PSNR: %f" % (
      processed, metrics.result().numpy())
  print(summary)
  with tf.io.gfile.GFile("PSNR_result.txt", "w") as f:
    f.write(summary)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # type=int is required: without it a command-line value arrives as a str
  # and the comparison against the processed-image count raises TypeError.
  parser.add_argument(
      "--total",
      default=10000,
      type=int,
      help="Total number of sub images to work on. (default: 10000)")
  parser.add_argument(
      "--batch_size",
      default=16,
      type=int,
      help="Number of images per batch (default: 16)")
  parser.add_argument(
      "--lr_files",
      default=None,
      help="Pattern to match low resolution files")
  parser.add_argument(
      "--hr_files",
      default=None,
      help="Pattern to match High resolution images")
  parser.add_argument(
      "--model",
      default="https://github.com/captain-pool/GSOC/"
              "releases/download/1.0.0/esrgan.tar.gz",
      help="URL or Path to the SavedModel")
  parser.add_argument(
      "--verbose", "-v",
      default=0,
      action="count",
      help="Increase Verbosity of logging")

  # Unrecognised arguments are tolerated and ignored.
  flags, _ = parser.parse_known_args()
  if not (flags.lr_files and flags.hr_files):
    logging.error("Must set flag --lr_files and --hr_files")
    sys.exit(1)
  # Each -v lowers the threshold one step, saturating at DEBUG.
  log_levels = [logging.FATAL, logging.WARN, logging.INFO, logging.DEBUG]
  log_level = log_levels[min(flags.verbose, len(log_levels) - 1)]
  logging.set_verbosity(log_level)
  main(**vars(flags))