diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 9a4681b2b17d86d188abbe3a91b4afec5f1c4c11..905e39589902ebabced5ed18901f54c238b4765b 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -231,18 +231,17 @@ class Trainer(object): l1_loss = utils.pixel_loss(image_hr, fake) logging.debug("Calculated Pixel Loss") - # loss_RaG = ra_gen(image_hr, fake) - # logging.debug("Calculated Relativistic" - # "Average (RA) Loss for Generator") - # - # disc_loss = ra_disc(image_hr, fake) - # logging.debug("Calculated RA Loss Discriminator") - disc_loss = l1_loss + loss_RaG = ra_gen(image_hr, fake) + logging.debug("Calculated Relativistic" + "Average (RA) Loss for Generator") + + disc_loss = ra_disc(image_hr, fake) + logging.debug("Calculated RA Loss Discriminator") + # TDR, we don't have percep_loss # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss - # gen_loss = lambda_ * loss_RaG + eta * l1_loss - gen_loss = l1_loss + gen_loss = lambda_ * loss_RaG + eta * l1_loss logging.debug("Calculated Generator Loss") disc_metric(disc_loss)