From 303187b5e117c3eec65f62c94092cd4dd8bf895b Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 12 Oct 2023 10:39:15 -0500
Subject: [PATCH] snapshot...

---
 modules/GSOC/E2_ESRGAN/lib/train.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py
index f8ff4150..25d727cb 100644
--- a/modules/GSOC/E2_ESRGAN/lib/train.py
+++ b/modules/GSOC/E2_ESRGAN/lib/train.py
@@ -93,8 +93,9 @@ class Trainer(object):
       with tf.GradientTape() as tape:
         fake = generator.unsigned_call(image_lr)
         loss_mae = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
-        # loss_mse = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
-        loss = loss_mae
+        loss_mse = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
+        # loss = loss_mae
+        loss = loss_mse
       mean_loss = metric(loss_mae)
       psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
       # gen_vars = list(set(generator.trainable_variables))
@@ -114,6 +115,7 @@ class Trainer(object):
 
     for epoch in range(NUM_WU_EPOCHS):
         print('start epoch #: ', epoch)
+        metric.reset_states()
         for image_lr, image_hr in self.dataset:
           num_steps = train_step(image_lr, image_hr)
 
-- 
GitLab