diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 173841ab5afcba67af067e43ca9d0d7d7c9b200d..7467b2577236d5ac5cae6ab140b339be09e8677b 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -92,6 +92,7 @@ class Trainer(object): with tf.GradientTape() as tape: fake = generator.unsigned_call(image_lr) loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size) + # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0))) # gen_vars = list(set(generator.trainable_variables)) gen_vars = generator.trainable_variables