""" Evaluates the SavedModel of ESRGAN """
import os
import itertools
import functools
import argparse
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
tf.enable_v2_behavior()


def build_dataset(
    lr_glob,
    hr_glob,
    lr_crop_size=(128, 128),
    scale=4):
  """Builds a tf.data.Dataset of (LR, HR) image patch pairs.

  Args:
    lr_glob: Glob pattern matching the low-resolution images.
    hr_glob: Glob pattern matching the corresponding high-resolution images.
    lr_crop_size: Spatial size of the low-resolution patches to extract.
    scale: Upscaling factor between the LR and HR images.

  Returns:
    A tf.data.Dataset yielding (lr_patch, hr_patch) float32 tensor pairs.
  """
  def _read_images_fn(low_res_images, high_res_images):
    for lr_image_path, hr_image_path in zip(low_res_images, high_res_images):
      lr_image = tf.image.decode_image(tf.io.read_file(lr_image_path))
      hr_image = tf.image.decode_image(tf.io.read_file(hr_image_path))
      # Slide a window over the LR image with a stride of 40 pixels and
      # crop the matching (scaled) region from the HR image.
      for height in range(0, lr_image.shape[0] - lr_crop_size[0] + 1, 40):
        for width in range(0, lr_image.shape[1] - lr_crop_size[1] + 1, 40):
          lr_sub_image = tf.image.crop_to_bounding_box(
              lr_image,
              height, width,
              lr_crop_size[0], lr_crop_size[1])
          hr_sub_image = tf.image.crop_to_bounding_box(
              hr_image,
              height * scale, width * scale,
              lr_crop_size[0] * scale, lr_crop_size[1] * scale)
          yield (tf.cast(lr_sub_image, tf.float32),
                 tf.cast(hr_sub_image, tf.float32))
  # Sort both lists so that LR/HR files pair up by name.
  lr_images = sorted(tf.io.gfile.glob(lr_glob))
  hr_images = sorted(tf.io.gfile.glob(hr_glob))
  dataset = tf.data.Dataset.from_generator(
      functools.partial(_read_images_fn, lr_images, hr_images),
      (tf.float32, tf.float32),
      (tf.TensorShape([None, None, 3]), tf.TensorShape([None, None, 3])))
  return dataset
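
# A minimal usage sketch for build_dataset (the glob patterns are
# hypothetical placeholders; any paired LR/HR layout whose filenames
# sort into matching order works):
#
#   dataset = build_dataset(
#       lr_glob="datasets/DIV2K_valid_LR_bicubic/X4/*.png",
#       hr_glob="datasets/DIV2K_valid_HR/*.png",
#       lr_crop_size=(128, 128),
#       scale=4)
#   for lr_patch, hr_patch in dataset.take(1):
#       print(lr_patch.shape, hr_patch.shape)  # (128, 128, 3) (512, 512, 3)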


def main(**kwargs):
  total = kwargs["total"]
  dataset = build_dataset(kwargs["lr_files"], kwargs["hr_files"])
  dataset = dataset.batch(kwargs["batch_size"])
  # Counts images processed: one step of `batch_size` per batch.
  count = itertools.count(
      start=kwargs["batch_size"],
      step=kwargs["batch_size"])
  # Show download progress when fetching the model from TF Hub.
  os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
  mean_psnr = tf.keras.metrics.Mean()
  for lr_image, hr_image in dataset:
    # Re-loading the model for every batch is the only way found to
    # preserve output quality; the quality degrades with each
    # successive inference call for no obvious reason.
    model = hub.load(kwargs["model"])
    super_res_image = model.call(lr_image)
    super_res_image = tf.clip_by_value(super_res_image, 0, 255)
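    # PSNR = 10 * log10(max_val**2 / MSE); the outputs are clipped to
    # [0, 255] above, so the correct max_val for 8-bit images is 255.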
    mean_psnr(
        tf.reduce_mean(
            tf.image.psnr(
                super_res_image,
                hr_image,
                max_val=255)))
    c = next(count)
    if c >= total:
      break
    # Log progress roughly once every 1000 images.
    if c % 1000 < kwargs["batch_size"]:
      print(
          "%d Images Processed. Mean PSNR so far: %f" %
          (c, mean_psnr.result().numpy()))
  print(
      "%d Images processed. Mean PSNR: %f" %
      (c, mean_psnr.result().numpy()))
  with tf.io.gfile.GFile("PSNR_result.txt", "w") as f:
    f.write("%d Images processed. Mean PSNR: %f" %
            (c, mean_psnr.result().numpy()))
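

# Example invocation (the file patterns are illustrative):
#   python evaluate_psnr.py \
#       --lr_files "datasets/DIV2K_valid_LR_bicubic/X4/*.png" \
#       --hr_files "datasets/DIV2K_valid_HR/*.png" \
#       --batch_size 16 --total 10000 -vv
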
if __name__ == "__main__":
parser = argparse.ArgumentParser()
  parser.add_argument(
      "--total",
      default=10000,
      type=int,
      help="Total number of sub-images to work on. (default: 10000)")
parser.add_argument(
"--batch_size",
default=16,
type=int,
help="Number of images per batch (default: 16)")
  parser.add_argument(
      "--lr_files",
      default=None,
      help="Glob pattern to match low-resolution images")
  parser.add_argument(
      "--hr_files",
      default=None,
      help="Glob pattern to match high-resolution images")
parser.add_argument(
"--model",
default="https://github.com/captain-pool/GSOC/"
"releases/download/1.0.0/esrgan.tar.gz",
help="URL or Path to the SavedModel")
  parser.add_argument(
      "--verbose", "-v",
      default=0,
      action="count",
      help="Increase verbosity of logging")
flags, unknown = parser.parse_known_args()
  if not (flags.lr_files and flags.hr_files):
    logging.error("Must set flags --lr_files and --hr_files")
    sys.exit(1)
  # Map the -v count to an absl logging verbosity level.
  log_levels = [logging.FATAL, logging.WARNING, logging.INFO, logging.DEBUG]
  log_level = log_levels[min(flags.verbose, len(log_levels) - 1)]
  logging.set_verbosity(log_level)
main(**vars(flags))