import glob
import logging
import os
import time
from functools import partial
from pathlib import Path

# from absl import logging
import tensorflow as tf

# from lib import utils, dataset
from GSOC.E2_ESRGAN.lib import utils, dataset

logging.basicConfig(filename=str(Path.home() / 'esrgan_log.txt'), level=logging.DEBUG)

NUM_WU_EPOCHS = 14  # passes over the dataset in the warmup (phase 1) loop
NUM_EPOCHS = 20  # passes over the dataset in the GAN (phase 2) loop
num_chans = 2  # NOTE(review): not referenced in this file; may be used by importers
PSNR_MAX = 256.0  # max_val passed to tf.image.psnr for the PSNR metric


class Trainer(object):
    """Trainer class for ESRGAN on cloud optical depth (OPD) data.

    Phase 1 (``warmup_generator``) trains the generator alone on a pixel loss.
    Phase 2 (``train_gan``) trains generator and discriminator adversarially,
    hot-starting the generator from the phase 1 checkpoint.
    """

    def __init__(
            self,
            summary_writer,
            summary_writer_2,
            settings,
            model_dir='/Users/tomrink/tf_model_esgan/',
            data_dir=None,
            manual=False,
            strategy=None):
        """Set up training values and the OPD training dataset.

        Args:
            summary_writer: tf.summary.SummaryWriter for the warmup (phase 1)
                TensorBoard summaries.
            summary_writer_2: tf.summary.SummaryWriter for the GAN (phase 2)
                TensorBoard summaries.
            settings: configuration mapping. This class reads "dataset",
                "batch_size", "train_psnr", "train_combined",
                "checkpoint_path" and "print_step".
            model_dir: directory passed to the checkpoint load/save helpers.
            data_dir: currently unused; the training files come from the
                hard-coded glob pattern below.
            manual: currently unused.
            strategy: tf.distribute strategy, stored on the instance only. The
                distributed code paths are commented out, so steps run without it.
        """
        self.settings = settings
        self.model_dir = model_dir
        self.summary_writer = summary_writer
        self.summary_writer_2 = summary_writer_2
        self.strategy = strategy

        dataset_args = self.settings["dataset"]
        self.batch_size = self.settings["batch_size"]

        hr_size = dataset_args['hr_dimension']
        # Low-resolution inputs are 4x smaller than the high-resolution targets.
        lr_size = int(hr_size // 4)

        filenames = glob.glob('/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/clavrx_surfrad_abi_cld_opd_1_512x512/train_ires_????.npy')
        self.dataset = dataset.OpdNpyDataset(filenames, hr_size, lr_size, batch_size=self.batch_size).dataset
        print('OPD dataset initialized...')

    def warmup_generator(self, generator):
        """Warm up the generator on a pixel-wise loss (``utils.pixel_loss``).

        Lowering the pixel reconstruction error should raise the PSNR of the
        generated images. The resulting generator is then used to hot-start
        the GAN phase, so the adversarial training begins from sensible outputs
        rather than noise.

        The loop:
            * resumes from the "phase_1" checkpoint via ``utils.load_checkpoint``;
            * multiplies the learning rate by ``decay.factor`` every
              ``decay.step`` optimizer steps;
            * writes "warmup_loss" and "mean_psnr" summaries after every step;
            * every ``print_step`` steps, logs progress and saves a "phase_1"
              checkpoint if the running-mean PSNR is the best seen so far.

        Args:
            generator: model object for the generator; must provide
                ``unsigned_call`` and ``trainable_variables``.
        """
        # Loading up phase parameters
        print('begin warmup_generator...')
        phase_args = self.settings["train_psnr"]
        decay_params = phase_args["adam"]["decay"]
        decay_step = decay_params["step"]
        decay_factor = decay_params["factor"]

        # These are never reset, so they are running means since the start of
        # this call.
        metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()

        # Generator Optimizer
        G_optimizer = tf.optimizers.Adam(
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer)

        # NOTE(review): ckpt_manager is not used; checkpoints are written via
        # utils.save_checkpoint below.
        ckpt_manager = tf.train.CheckpointManager(checkpoint, self.model_dir, max_to_keep=7)
        status = utils.load_checkpoint(checkpoint, "phase_1", self.model_dir)
        logging.debug("phase_1 status object: {}".format(status))
        best_psnr = 0
        start_time = time.time()

        # Training starts
        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as tape:
                fake = generator.unsigned_call(image_lr)
                loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
                # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
            psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
            gen_vars = generator.trainable_variables
            gradient = tape.gradient(loss, gen_vars)
            G_optimizer.apply_gradients(zip(gradient, gen_vars))
            metric(loss)
            logging.debug("Ending Distributed Step")
            return tf.cast(G_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            # distributed_metric = self.strategy.experimental_run_v2(_step_fn, args=[image_lr, image_hr])
            # mean_metric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
            return _step_fn(image_lr, image_hr)

        for epoch in range(NUM_WU_EPOCHS):
            print('start epoch #: ', epoch)
            for image_lr, image_hr in self.dataset:
                num_steps = train_step(image_lr, image_hr)

                # Checkpoint variables are only fully created after the first
                # step, so the restore is verified here.
                if status:
                    status.assert_consumed()
                    logging.info(
                        "consumed checkpoint for phase_1 successfully")
                    status = None

                if not num_steps % decay_step:  # Decay Learning Rate
                    logging.debug(
                        "Learning Rate: %s", G_optimizer.learning_rate.numpy())
                    G_optimizer.learning_rate.assign(
                        G_optimizer.learning_rate * decay_factor)
                    logging.debug(
                        "Decayed Learning Rate by %f. Current Learning Rate %s",
                        decay_factor, G_optimizer.learning_rate.numpy())

                with self.summary_writer.as_default():
                    tf.summary.scalar(
                        "warmup_loss", metric.result(), step=G_optimizer.iterations)
                    tf.summary.scalar("mean_psnr", psnr_metric.result(), step=G_optimizer.iterations)

                if not num_steps % self.settings["print_step"]:
                    logging.info(
                        "[WARMUP] Step: {}\tGenerator Loss: {}"
                        "\tPSNR: {}\tTime Taken: {} sec".format(
                            num_steps,
                            metric.result(),
                            psnr_metric.result(), time.time() - start_time))
                    # Keep only checkpoints that improve the best PSNR so far
                    # (same policy as train_gan).
                    if psnr_metric.result() > best_psnr:
                        utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
                        best_psnr = psnr_metric.result()
                    start_time = time.time()

    def train_gan(self, generator, discriminator):
        """Train the generator and discriminator adversarially (phase 2).

        If no "phase_2" checkpoint index exists under ``model_dir``, the
        generator and its optimizer are hot-started from the "phase_1"
        checkpoint, and the generator learning rate is reset to
        ``initial_lr``. Otherwise training resumes from the "phase_2" checkpoint.

        Both learning rates are multiplied by ``decay.factor`` once for every
        threshold in ``settings["train_combined"]["adam"]["decay"]["step"]``
        that the step count reaches. The settings list itself is left unmodified.

        Args:
            generator: model object for the generator.
            discriminator: model object for the discriminator.
        """
        print('begin train_gan...')
        phase_args = self.settings["train_combined"]
        decay_args = phase_args["adam"]["decay"]
        decay_factor = decay_args["factor"]
        # Work on a sorted private copy so the caller's settings are not
        # consumed by the pops below.
        decay_steps = sorted(decay_args["step"])

        # lambda_, eta, hr_dimension, ra_gen and ra_disc are only used by the
        # commented-out relativistic/perceptual loss alternative below.
        lambda_ = phase_args["lambda"]
        hr_dimension = self.settings["dataset"]["hr_dimension"]
        eta = phase_args["eta"]

        optimizer = partial(
            tf.optimizers.Adam,
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])

        G_optimizer = optimizer()
        D_optimizer = optimizer()

        ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
        ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")

        # The weights of generator trained during Phase #1
        # is used to initialize or "hot start" the generator
        # for phase #2 of training
        status = None
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer,
            D=discriminator,
            D_optimizer=D_optimizer)

        # NOTE(review): ckpt_manager is not used; checkpoints are written via
        # utils.save_checkpoint below.
        ckpt_manager = tf.train.CheckpointManager(checkpoint, self.model_dir, max_to_keep=7)

        if not tf.io.gfile.exists(
                os.path.join(
                    self.model_dir,
                    self.settings["checkpoint_path"]["phase_2"],
                    "checkpoint")):
            hot_start = tf.train.Checkpoint(
                G=generator,
                G_optimizer=G_optimizer)
            status = utils.load_checkpoint(hot_start, "phase_1", self.model_dir)
            # The restored phase 1 optimizer may carry a decayed learning rate;
            # start phase 2 from its own initial value.
            G_optimizer.learning_rate.assign(phase_args["adam"]["initial_lr"])
        else:
            status = utils.load_checkpoint(checkpoint, "phase_2", self.model_dir)

        logging.debug("phase status object: {}".format(status))

        gen_metric = tf.keras.metrics.Mean()
        disc_metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()

        # TDR: May need to create a Perceptual Model for Optical Depth?
        # perceptual_loss = utils.PerceptualLoss(
        #     weights="imagenet",
        #     input_shape=[hr_dimension, hr_dimension, 1],
        #     loss_type=phase_args["perceptual_loss_type"])

        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                fake = generator.unsigned_call(image_lr)
                logging.debug("Fetched Generator Fake")
                # TDR, not sure whether utils.preprocess_input should be applied
                # to fake / image_lr / image_hr here.

                # TDR, the original ESRGAN losses are not used with CLD OPD:
                #   percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
                #   l1_loss = utils.pixel_loss(image_hr, fake)
                #   loss_RaG = ra_gen(image_hr, fake)
                #   disc_loss = ra_disc(image_hr, fake)
                #   gen_loss = lambda_ * loss_RaG + eta * l1_loss  (+ percep_loss)
                gen_loss = utils.std_loss_G(discriminator, fake)
                disc_loss = utils.std_loss_D(discriminator, image_hr, fake)
                logging.debug("Calculated Generator Loss")
                # Metrics track the unscaled losses.
                disc_metric(disc_loss)
                gen_metric(gen_loss)
                # Scaling must happen inside the tapes so it is differentiated.
                gen_loss = gen_loss * (1.0 / self.batch_size)
                disc_loss = disc_loss * (1.0 / self.batch_size)
            psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))

            disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
            logging.debug("Calculated gradient for Discriminator")
            D_optimizer.apply_gradients(zip(disc_grad, discriminator.trainable_variables))
            logging.debug("Applied gradients to Discriminator")

            gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
            logging.debug("Calculated gradient for Generator")
            G_optimizer.apply_gradients(zip(gen_grad, generator.trainable_variables))
            logging.debug("Applied gradients to Generator")

            return tf.cast(D_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            # distributed_iterations = self.strategy.experimental_run_v2(step_fn, args=(image_lr, image_hr))
            # num_steps = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, distributed_iterations, axis=None)
            return _step_fn(image_lr, image_hr)

        start = time.time()
        last_psnr = 0
        for epoch in range(NUM_EPOCHS):
            print('start epoch #: ', epoch)
            for image_lr, image_hr in self.dataset:
                num_step = train_step(image_lr, image_hr)

                if status:
                    status.assert_consumed()
                    logging.info("consumed checkpoint successfully!")
                    status = None

                # Decaying Learning Rate: once per scheduled threshold reached.
                # TDR, reading the LR back through self.strategy.reduce caused an
                # error, so it is not logged here.
                while decay_steps and num_step >= decay_steps[0]:
                    decay_steps.pop(0)
                    G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
                    D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)

                # Writing Summary
                with self.summary_writer_2.as_default():
                    tf.summary.scalar(
                        "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
                    tf.summary.scalar(
                        "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
                    tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)

                # Logging and Checkpointing
                if not num_step % self.settings["print_step"]:
                    logging.info(
                        "Step: {}\tGen Loss: {}\tDisc Loss: {}"
                        "\tPSNR: {}\tTime Taken: {} sec".format(
                            num_step,
                            gen_metric.result(),
                            disc_metric.result(),
                            psnr_metric.result(),
                            time.time() - start))
                    # Keep only checkpoints that improve the best PSNR so far.
                    if psnr_metric.result() > last_psnr:
                        last_psnr = psnr_metric.result()
                        utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
                    start = time.time()