# main.py (5.65 KiB)
import os
from pathlib import Path
from functools import partial
import argparse
# from absl import logging
import logging
from GSOC.E2_ESRGAN.lib import settings, train, model, utils
# from tensorflow.python.eager import profiler
# import tensorflow.compat.v2 as tf
# tf.enable_v2_behavior()
import tensorflow as tf
# NOTE(review): this line only references the function object -- there are no
# call parentheses -- so it has no effect as written. Confirm whether a call
# was intended before changing it; the rest of this file uses TF2-style APIs.
tf.compat.v1.disable_v2_behavior
# Configure the root logger to write records of level DEBUG and above to a
# fixed file path.
logging.basicConfig(filename='/Users/tomrink/esrgan_log.txt', level=logging.DEBUG)
""" Enhanced Super Resolution GAN.
Citation:
@article{DBLP:journals/corr/abs-1809-00219,
author = {Xintao Wang and
Ke Yu and
Shixiang Wu and
Jinjin Gu and
Yihao Liu and
Chao Dong and
Chen Change Loy and
Yu Qiao and
Xiaoou Tang},
title = {{ESRGAN:} Enhanced Super-Resolution Generative Adversarial Networks},
journal = {CoRR},
volume = {abs/1809.00219},
year = {2018},
url = {http://arxiv.org/abs/1809.00219},
archivePrefix = {arXiv},
eprint = {1809.00219},
timestamp = {Fri, 05 Oct 2018 11:34:52 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1809-00219},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
#def main(**kwargs):
def run(kwargs):
    """Train the ESRGAN model and export it as a SavedModel 2.0.

    Args:
        kwargs: mapping of run options. The keys read by this function are:
            config: path to config yaml file.
            log_dir: directory to store summary for tensorboard.
            data_dir: directory to store / access the dataset.
            manual: boolean to denote if data_dir is a manual directory.
            model_dir: directory to store the model into.
            phases: phase names joined by "_"; "phase1" and "phase2" are
                looked for after lower-casing and stripping each name.
            export_only: when truthy, the training steps are skipped.
    """
    # Enable on-demand memory growth for every GPU TensorFlow can see.
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)

    strategy = utils.SingleDeviceStrategy()
    cfg = settings.Settings(kwargs["config"])
    # "stats.yaml" next to the settings holds the per-phase completion flags.
    progress = settings.Stats(os.path.join(cfg.path, "stats.yaml"))
    tf.random.set_seed(10)

    with strategy.scope():
        phase1_writer = tf.summary.create_file_writer(
            os.path.join(kwargs["log_dir"], "phase1"))
        phase2_writer = tf.summary.create_file_writer(
            os.path.join(kwargs["log_dir"], "phase2"))
        # profiler.start_profiler_server(6009)
        discriminator = model.VGGArch(
            batch_size=cfg["batch_size"], num_features=64)

        if not kwargs["export_only"]:
            generator = model.RRDBNet(out_channel=1)
            logging.debug("Initiating Convolutions")
            # Run one random single-channel 128x128 input through the
            # generator before training starts.
            generator.unsigned_call(tf.random.normal([1, 128, 128, 1]))
            trainer = train.Trainer(
                summary_writer=phase1_writer,
                summary_writer_2=phase2_writer,
                settings=cfg,
                model_dir=kwargs["model_dir"],
                data_dir=kwargs["data_dir"],
                manual=kwargs["manual"],
                strategy=strategy)
            requested = [
                name.strip()
                for name in kwargs["phases"].lower().split("_")]

            # A phase runs only if it is not yet marked done and was asked for.
            if not progress["train_step_1"] and "phase1" in requested:
                logging.info("starting phase 1")
                trainer.warmup_generator(generator)
                progress["train_step_1"] = True

            if not progress["train_step_2"] and "phase2" in requested:
                logging.info("starting phase 2")
                trainer.train_gan(generator, discriminator)
                progress["train_step_2"] = True

    if progress["train_step_1"] and progress["train_step_2"]:
        # Attempting to save "Interpolated" Model as SavedModel2.0
        generator_factory = partial(
            model.RRDBNet, out_channel=1, first_call=False)
        interpolated = utils.interpolate_generator(
            generator_factory,
            discriminator,
            cfg["interpolation_parameter"],
            [512, 512],
            basepath=kwargs["model_dir"])
        export_path = os.path.join(kwargs["model_dir"], "esrgan")
        tf.saved_model.save(interpolated, export_path)
#if __name__ == '__main__':
# Command-line options for the ESRGAN training / export script.
# Each option's destination name matches a key that the hard-coded FLAGS
# dictionary further down in this file also provides.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    default="config/config.yaml",
    help="path to configuration file. (default: %(default)s)")
parser.add_argument(
    "--data_dir",
    default=None,
    help="directory to put the data. (default: %(default)s)")
# Boolean switch: present -> True, absent -> False.
parser.add_argument(
    "--manual",
    default=False,
    help="specify if data_dir is a manual directory. (default: %(default)s)",
    action="store_true")
parser.add_argument(
    "--model_dir",
    default='/Users/tomrink/tf_model_esrgan',
    help="directory to put the model in.")
parser.add_argument(
    "--log_dir",
    default='/Users/tomrink/esrgan_log',
    help="directory to store summaries for tensorboard.")
# Phase names are joined with "_", e.g. "phase1_phase2".
parser.add_argument(
    "--phases",
    default="phase1_phase2",
    help="phases to train for separated by '_'")
# Boolean switch: present -> True, absent -> False.
parser.add_argument(
    "--export_only",
    default=False,
    action="store_true",
    help="Exports to SavedModel")
parser.add_argument(
    "--tpu",
    default="",
    help="Name of the TPU to be used")
# Repeatable flag: "-v" yields 1, "-vv" yields 2, and so on.
parser.add_argument(
    "-v",
    "--verbose",
    action="count",
    default=0,
    help="each 'v' increases verbosity of logging.")
# Command-line parsing is disabled; the options are pinned in code instead.
# FLAGS, unparsed = parser.parse_known_args()
FLAGS = dict(
    config='{}/dev/python/modules/GSOC/E2_ESRGAN/config/config.yaml'.format(
        Path.home()),
    data_dir=None,
    manual=False,
    # Home-relative alternative: str(Path.home()) + '/tf_model_esrgan/'
    model_dir='/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan/',
    log_dir='{}/tf_logs_esrgan/'.format(Path.home()),
    phases='phase1_phase2',
    export_only=False,
    tpu='',
    verbose=0,
)

# Verbosity count -> logging level; counts past the end of the table are
# clamped to its last entry.
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG]
log_level = log_levels[min(len(log_levels) - 1, FLAGS['verbose'])]
# The selected level is only computed here; the calls below are disabled.
# logging.set_verbosity(log_level)
# FLAGS = vars(FLAGS)
# main(**vars(FLAGS))