From 42bfd0a50ea85f625b236154849f49719726558f Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sat, 7 Oct 2023 11:02:12 -0500
Subject: [PATCH] snapshot...

---
 modules/GSOC/E2_ESRGAN/lib/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py
index 7467b257..e164a74d 100644
--- a/modules/GSOC/E2_ESRGAN/lib/train.py
+++ b/modules/GSOC/E2_ESRGAN/lib/train.py
@@ -14,7 +14,7 @@ logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging
 NUM_WU_EPOCHS = 1
 NUM_EPOCHS = 20
 num_chans = 2
-
+PSNR_MAX = 1.0
 
 class Trainer(object):
   """ Trainer class for ESRGAN """
@@ -93,7 +93,7 @@ class Trainer(object):
         fake = generator.unsigned_call(image_lr)
         loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
         # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
-      psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
+      psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
       # gen_vars = list(set(generator.trainable_variables))
       gen_vars = generator.trainable_variables
       gradient = tape.gradient(loss, gen_vars)
@@ -255,7 +255,7 @@ class Trainer(object):
         gen_metric(gen_loss)
         gen_loss = gen_loss * (1.0 / self.batch_size)
         disc_loss = disc_loss * (1.0 / self.batch_size)
-        psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
+        psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
 
       disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
       logging.debug("Calculated gradient for Discriminator")
-- 
GitLab