diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 1dee62e644f8207962b74818e221955baa3cf906..2920fc3884c1b8b67e25e727841bd7320580bb45 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -217,21 +217,30 @@ class Trainer(object): with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: fake = generator.unsigned_call(image_lr) logging.debug("Fetched Generator Fake") - fake = utils.preprocess_input(fake) - image_lr = utils.preprocess_input(image_lr) - image_hr = utils.preprocess_input(image_hr) + + # TDR, not sure about these... + # fake = utils.preprocess_input(fake) + # image_lr = utils.preprocess_input(image_lr) + # image_hr = utils.preprocess_input(image_hr) + # ------------------------------------------------ + # TDR, not using perceptual loss with CLD OPD # percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake)) # logging.debug("Calculated Perceptual Loss") + l1_loss = utils.pixel_loss(image_hr, fake) logging.debug("Calculated Pixel Loss") + loss_RaG = ra_gen(image_hr, fake) logging.debug("Calculated Relativistic" "Average (RA) Loss for Generator") + disc_loss = ra_disc(image_hr, fake) logging.debug("Calculated RA Loss Discriminator") + # TDR, we don't have percep_loss # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss gen_loss = lambda_ * loss_RaG + eta * l1_loss logging.debug("Calculated Generator Loss") + disc_metric(disc_loss) gen_metric(gen_loss) gen_loss = gen_loss * (1.0 / self.batch_size)