import os
from pathlib import Path
from functools import partial
import argparse
# from absl import logging
import logging
from GSOC.E2_ESRGAN.lib import settings, train, model, utils
# from tensorflow.python.eager import profiler
# import tensorflow.compat.v2 as tf
# tf.enable_v2_behavior()
import tensorflow as tf
# NOTE(review): this line only references the function object; without call
# parentheses it has no effect, so TF2 behavior is NOT disabled here.
# Confirm whether disabling was actually intended before adding "()".
tf.compat.v1.disable_v2_behavior

# Route root-logger records at DEBUG level and above to a fixed log file.
# NOTE(review): the path is an absolute, user-specific location.
logging.basicConfig(filename='/Users/tomrink/esrgan_log.txt', level=logging.DEBUG)

# Size of the last dimension of the random tensor passed to
# generator.unsigned_call() in run().
num_chans = 2


""" Enhanced Super Resolution GAN.
    Citation:
      @article{DBLP:journals/corr/abs-1809-00219,
        author    = {Xintao Wang and
                     Ke Yu and
                     Shixiang Wu and
                     Jinjin Gu and
                     Yihao Liu and
                     Chao Dong and
                     Chen Change Loy and
                     Yu Qiao and
                     Xiaoou Tang},
        title     = {{ESRGAN:} Enhanced Super-Resolution Generative Adversarial Networks},
        journal   = {CoRR},
        volume    = {abs/1809.00219},
        year      = {2018},
        url       = {http://arxiv.org/abs/1809.00219},
        archivePrefix = {arXiv},
        eprint    = {1809.00219},
        timestamp = {Fri, 05 Oct 2018 11:34:52 +0200},
        biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1809-00219},
        bibsource = {dblp computer science bibliography, https://dblp.org}
      }
"""


#def main(**kwargs):
def run(kwargs):
  """ Train the ESRGAN model and export the interpolated generator as a SavedModel.

      Each training phase is skipped when its flag ("train_step_1" /
      "train_step_2") in the Stats object is already set or when the phase
      was not requested. The export runs only when both flags are set.

      Args:
        kwargs: dict of run options. Keys read by this function:
          config: path to config yaml file.
          log_dir: directory to store summaries for tensorboard
                   ("phase1" and "phase2" sub-directories are used).
          data_dir: dataset directory, forwarded to the trainer.
          manual: flag forwarded to the trainer together with data_dir.
          model_dir: directory forwarded to the trainer and used as the
                     base path for the exported "esrgan" model.
          phases: "_"-separated phase names, e.g. "phase1_phase2".
          export_only: when truthy, no generator is built and no training runs.
  """

  for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

  strategy = utils.SingleDeviceStrategy()
  config = settings.Settings(kwargs["config"])
  stats = settings.Stats(os.path.join(config.path, "stats.yaml"))
  tf.random.set_seed(10)

  with strategy.scope():
    # One summary writer per training phase.
    phase1_writer, phase2_writer = [
        tf.summary.create_file_writer(os.path.join(kwargs["log_dir"], tag))
        for tag in ("phase1", "phase2")
    ]

    discriminator = model.VGGArch(
        batch_size=config["batch_size"],
        num_features=64)

    if not kwargs["export_only"]:
      generator = model.RRDBNet(out_channel=1)
      logging.debug("Initiating Convolutions")
      # Presumably builds the generator's layers before training starts
      # (see the log message above) -- confirm against model.RRDBNet.
      dummy_input = tf.random.normal([1, 128, 128, num_chans])
      generator.unsigned_call(dummy_input)

      trainer = train.Trainer(
          summary_writer=phase1_writer,
          summary_writer_2=phase2_writer,
          settings=config,
          model_dir=kwargs["model_dir"],
          data_dir=kwargs["data_dir"],
          manual=kwargs["manual"],
          strategy=strategy)

      requested = [name.strip() for name in kwargs["phases"].lower().split("_")]

      # (stats flag, phase name, log message, action) for each phase, in order.
      schedule = (
          ("train_step_1", "phase1", "starting phase 1",
           lambda: trainer.warmup_generator(generator)),
          ("train_step_2", "phase2", "starting phase 2",
           lambda: trainer.train_gan(generator, discriminator)),
      )
      for flag, phase, message, launch in schedule:
        # Skip phases that are already marked done or were not asked for.
        if stats[flag] or phase not in requested:
          continue
        logging.info(message)
        launch()
        stats[flag] = True

  if not (stats["train_step_1"] and stats["train_step_2"]):
    return

  # Attempting to save "Interpolated" Model as SavedModel2.0
  generator_factory = partial(model.RRDBNet, out_channel=1, first_call=False)
  interpolated_generator = utils.interpolate_generator(
      generator_factory,
      discriminator,
      config["interpolation_parameter"],
      [512, 512],
      basepath=kwargs["model_dir"])

  export_path = os.path.join(kwargs["model_dir"], "esrgan")
  tf.saved_model.save(interpolated_generator, export_path)


#if __name__ == '__main__':
# Command-line interface definition. The parser is only constructed here; it
# is not invoked in this section, and the options actually in effect are the
# hard-coded FLAGS further down.
parser = argparse.ArgumentParser()
parser.add_argument(
  "--config",
  help="path to configuration file. (default: %(default)s)",
  default="config/config.yaml")
parser.add_argument(
  "--data_dir",
  help="directory to put the data. (default: %(default)s)",
  default=None)
parser.add_argument(
  "--manual",
  action="store_true",
  help="specify if data_dir is a manual directory. (default: %(default)s)",
  default=False)
parser.add_argument(
  "--model_dir",
  help="directory to put the model in.",
  default="/Users/tomrink/tf_model_esrgan")
parser.add_argument(
  "--log_dir",
  help="directory to story summaries for tensorboard.",
  default="/Users/tomrink/esrgan_log")
parser.add_argument(
  "--phases",
  help="phases to train for seperated by '_'",
  default="phase1_phase2")
parser.add_argument(
  "--export_only",
  action="store_true",
  help="Exports to SavedModel",
  default=False)
parser.add_argument(
  "--tpu",
  help="Name of the TPU to be used",
  default="")
parser.add_argument(
  "-v", "--verbose",
  action="count",
  help="each 'v' increases vebosity of logging.",
  default=0)

# Hard-coded run options, using the same keys as the parser destinations above.
# The config and log locations are resolved relative to the user's home directory.
FLAGS = dict(
  config=str(Path.home()) + "/dev/python/modules/GSOC/E2_ESRGAN/config/config.yaml",
  data_dir=None,
  manual=False,
  # A home-relative alternative would be: str(Path.home()) + '/tf_model_esrgan/'
  model_dir="/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/tf_model_esrgan/",
  log_dir=str(Path.home()) + "/tf_logs_esrgan/",
  phases="phase1_phase2",
  export_only=False,
  tpu="",
  verbose=0,
)

# Map the verbosity count onto a logging level; counts beyond the end of the
# list are clamped to the last (most verbose) entry.
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG]
log_level = log_levels[min(FLAGS["verbose"], len(log_levels) - 1)]
# logging.set_verbosity(log_level)
# FLAGS = vars(FLAGS)
# main(**vars(FLAGS))