diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 7467b2577236d5ac5cae6ab140b339be09e8677b..e164a74d21034fc723752d59d4d56e004e9bee26 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -14,7 +14,7 @@ logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging NUM_WU_EPOCHS = 1 NUM_EPOCHS = 20 num_chans = 2 - +PSNR_MAX = 1.0 class Trainer(object): """ Trainer class for ESRGAN """ @@ -93,7 +93,7 @@ class Trainer(object): fake = generator.unsigned_call(image_lr) loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size) # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) - psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0))) + psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX))) # gen_vars = list(set(generator.trainable_variables)) gen_vars = generator.trainable_variables gradient = tape.gradient(loss, gen_vars) @@ -255,7 +255,7 @@ class Trainer(object): gen_metric(gen_loss) gen_loss = gen_loss * (1.0 / self.batch_size) disc_loss = disc_loss * (1.0 / self.batch_size) - psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0))) + psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX))) disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables) logging.debug("Calculated gradient for Discriminator")