diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py
index fc992c7c463248266309aa984683479af7b93d72..a0498295fb229bb66a5288e4bc8522e1b5c49a05 100644
--- a/modules/GSOC/E2_ESRGAN/lib/train.py
+++ b/modules/GSOC/E2_ESRGAN/lib/train.py
@@ -11,6 +11,9 @@
 from pathlib import Path
 logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt',
                     level=logging.DEBUG)
+NUM_WU_EPOCHS = 10
+NUM_EPOCHS = 20
+
 class Trainer(object):
   """ Trainer class for ESRGAN """
 
@@ -105,46 +108,45 @@
 
       # return mean_metric
       return distributed_metric
-    for image_lr, image_hr in self.dataset:
-      num_steps = train_step(image_lr, image_hr)
-
-      if num_steps >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info(
-            "consumed checkpoint for phase_1 successfully")
-        status = None
-
-      if not num_steps % decay_step:  # Decay Learning Rate
-        logging.debug(
-            "Learning Rate: %s" %
-            G_optimizer.learning_rate.numpy)
-        G_optimizer.learning_rate.assign(
-            G_optimizer.learning_rate * decay_factor)
-        logging.debug(
-            "Decayed Learning Rate by %f."
-            "Current Learning Rate %s" % (
-                decay_factor, G_optimizer.learning_rate))
-      with self.summary_writer.as_default():
-        tf.summary.scalar(
-            "warmup_loss", metric.result(), step=G_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
-
-      # if not num_steps % self.settings["print_step"]:  # test
-      if True:
-        logging.info(
-            "[WARMUP] Step: {}\tGenerator Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_steps,
-                metric.result(),
-                psnr_metric.result(),
-                time.time() -
-                start_time))
-        if psnr_metric.result() > previous_loss:
-          utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
-        previous_loss = psnr_metric.result()
-        start_time = time.time()
+    for epoch in range(NUM_WU_EPOCHS):
+      for image_lr, image_hr in self.dataset:
+        num_steps = train_step(image_lr, image_hr)
+
+        if status:
+          status.assert_consumed()
+          logging.info(
+              "consumed checkpoint for phase_1 successfully")
+          status = None
+
+        if not num_steps % decay_step:  # Decay Learning Rate
+          logging.debug(
+              "Learning Rate: %s" %
+              G_optimizer.learning_rate.numpy)
+          G_optimizer.learning_rate.assign(
+              G_optimizer.learning_rate * decay_factor)
+          logging.debug(
+              "Decayed Learning Rate by %f."
+              "Current Learning Rate %s" % (
+                  decay_factor, G_optimizer.learning_rate))
+        with self.summary_writer.as_default():
+          tf.summary.scalar(
+              "warmup_loss", metric.result(), step=G_optimizer.iterations)
+          tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
+
+        # if not num_steps % self.settings["print_step"]:  # test
+        if True:
+          logging.info(
+              "[WARMUP] Step: {}\tGenerator Loss: {}"
+              "\tPSNR: {}\tTime Taken: {} sec".format(
+                  num_steps,
+                  metric.result(),
+                  psnr_metric.result(),
+                  time.time() -
+                  start_time))
+          if psnr_metric.result() > previous_loss:
+            utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
+          previous_loss = psnr_metric.result()
+          start_time = time.time()
 
   def train_gan(self, generator, discriminator):
     """ Implements Training routine for ESRGAN
@@ -260,55 +262,55 @@
     start = time.time()
     last_psnr = 0
-    for image_lr, image_hr in self.dataset:
-      num_step = train_step(image_lr, image_hr)
-      if num_step >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info("consumed checkpoint successfully!")
-        status = None
-
-      # Decaying Learning Rate
-      for _step in decay_steps.copy():
-        if num_step >= _step:
-          decay_steps.pop(0)
-          g_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              G_optimizer.learning_rate, axis=None)
-
-          d_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              D_optimizer.learning_rate, axis=None)
-
-          logging.debug(
-              "Current LR: G = %s, D = %s" %
-              (g_current_lr, d_current_lr))
-          logging.debug(
-              "[Phase 2] Decayed Learing Rate by %f."
-              % decay_factor)
-          G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
-          D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
-
-      # Writing Summary
-      with self.summary_writer_2.as_default():
-        tf.summary.scalar(
-            "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar(
-            "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
-
-      # Logging and Checkpointing
-      # if not num_step % self.settings["print_step"]:  # testing
-      if True:
-        logging.info(
-            "Step: {}\tGen Loss: {}\tDisc Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_step,
-                gen_metric.result(),
-                disc_metric.result(),
-                psnr_metric.result(),
-                time.time() - start))
-        # if psnr_metric.result() > last_psnr:
-        last_psnr = psnr_metric.result()
-        utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
-        start = time.time()
+    for epoch in range(NUM_EPOCHS):
+      for image_lr, image_hr in self.dataset:
+        num_step = train_step(image_lr, image_hr)
+
+        if status:
+          status.assert_consumed()
+          logging.info("consumed checkpoint successfully!")
+          status = None
+
+        # Decaying Learning Rate
+        for _step in decay_steps.copy():
+          if num_step >= _step:
+            decay_steps.pop(0)
+            g_current_lr = self.strategy.reduce(
+                tf.distribute.ReduceOp.MEAN,
+                G_optimizer.learning_rate, axis=None)
+
+            d_current_lr = self.strategy.reduce(
+                tf.distribute.ReduceOp.MEAN,
+                D_optimizer.learning_rate, axis=None)
+
+            logging.debug(
+                "Current LR: G = %s, D = %s" %
+                (g_current_lr, d_current_lr))
+            logging.debug(
+                "[Phase 2] Decayed Learing Rate by %f."
+                % decay_factor)
+            G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
+            D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
+
+        # Writing Summary
+        with self.summary_writer_2.as_default():
+          tf.summary.scalar(
+              "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
+          tf.summary.scalar(
+              "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
+          tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
+
+        # Logging and Checkpointing
+        # if not num_step % self.settings["print_step"]:  # testing
+        if True:
+          logging.info(
+              "Step: {}\tGen Loss: {}\tDisc Loss: {}"
+              "\tPSNR: {}\tTime Taken: {} sec".format(
+                  num_step,
+                  gen_metric.result(),
+                  disc_metric.result(),
+                  psnr_metric.result(),
+                  time.time() - start))
+          # if psnr_metric.result() > last_psnr:
+          last_psnr = psnr_metric.result()
+          utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
+          start = time.time()