From 0b81822768c8e920aab6a043f0d2506f1cdc1a99 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 5 Oct 2023 11:14:09 -0500
Subject: [PATCH] snapshot...

---
 modules/GSOC/E2_ESRGAN/lib/train.py | 35 ++++++++++++++++-------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py
index d3706f0f..173841ab 100644
--- a/modules/GSOC/E2_ESRGAN/lib/train.py
+++ b/modules/GSOC/E2_ESRGAN/lib/train.py
@@ -228,21 +228,26 @@ class Trainer(object):
         # TDR, not using perceptual loss with CLD OPD
         # percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
         # logging.debug("Calculated Perceptual Loss")
-
-        l1_loss = utils.pixel_loss(image_hr, fake)
-        logging.debug("Calculated Pixel Loss")
-
-        loss_RaG = ra_gen(image_hr, fake)
-        logging.debug("Calculated Relativistic"
-                      "Average (RA) Loss for Generator")
-
-        disc_loss = ra_disc(image_hr, fake)
-        logging.debug("Calculated RA Loss Discriminator")
-
-
-        # TDR, we don't have percep_loss
-        # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
-        gen_loss = lambda_ * loss_RaG + eta * l1_loss
+        # --------------------------------------------------------------
+
+        # # substitute standard losses below...
+        # l1_loss = utils.pixel_loss(image_hr, fake)
+        # logging.debug("Calculated Pixel Loss")
+        #
+        # loss_RaG = ra_gen(image_hr, fake)
+        # logging.debug("Calculated Relativistic"
+        #               "Average (RA) Loss for Generator")
+        #
+        # disc_loss = ra_disc(image_hr, fake)
+        # logging.debug("Calculated RA Loss Discriminator")
+        #
+        # # TDR, we don't have percep_loss
+        # # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
+        # gen_loss = lambda_ * loss_RaG + eta * l1_loss
+        # # -----------------------------------------------------------
+
+        gen_loss = utils.std_loss_G(discriminator, fake)
+        disc_loss = utils.std_loss_D(discriminator, image_hr, fake)
         logging.debug("Calculated Generator Loss")
 
         disc_metric(disc_loss)
-- 
GitLab