diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 1372e3802073bedf074299b56e756707640c2dfc..9a4681b2b17d86d188abbe3a91b4afec5f1c4c11 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -231,15 +231,18 @@ class Trainer(object): l1_loss = utils.pixel_loss(image_hr, fake) logging.debug("Calculated Pixel Loss") - loss_RaG = ra_gen(image_hr, fake) - logging.debug("Calculated Relativistic" - "Average (RA) Loss for Generator") + # loss_RaG = ra_gen(image_hr, fake) + # logging.debug("Calculated Relativistic" + # "Average (RA) Loss for Generator") + # + # disc_loss = ra_disc(image_hr, fake) + # logging.debug("Calculated RA Loss Discriminator") + disc_loss = l1_loss - disc_loss = ra_disc(image_hr, fake) - logging.debug("Calculated RA Loss Discriminator") # TDR, we don't have percep_loss # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss - gen_loss = lambda_ * loss_RaG + eta * l1_loss + # gen_loss = lambda_ * loss_RaG + eta * l1_loss + gen_loss = l1_loss logging.debug("Calculated Generator Loss") disc_metric(disc_loss)