From e93081c7be608364025fabd872dcbaf21794b71b Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 22 Aug 2023 10:39:40 -0500
Subject: [PATCH] snapshot...

---
 modules/GSOC/E2_ESRGAN/lib/train.py | 186 ++++++++++++++--------------
 1 file changed, 94 insertions(+), 92 deletions(-)

diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py
index fc992c7c..a0498295 100644
--- a/modules/GSOC/E2_ESRGAN/lib/train.py
+++ b/modules/GSOC/E2_ESRGAN/lib/train.py
@@ -11,6 +11,9 @@ from pathlib import Path
 
 logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging.DEBUG)
 
+# Fixed epoch counts: phase 1 (warm-up) runs for NUM_WU_EPOCHS, phase 2 (GAN) for NUM_EPOCHS.
+NUM_WU_EPOCHS = 10
+NUM_EPOCHS = 20
+
 
 class Trainer(object):
   """ Trainer class for ESRGAN """
@@ -105,46 +108,45 @@ class Trainer(object):
       # return mean_metric
       return distributed_metric
 
-    for image_lr, image_hr in self.dataset:
-      num_steps = train_step(image_lr, image_hr)
-
-      if num_steps >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info(
-            "consumed checkpoint for phase_1 successfully")
-        status = None
-
-      if not num_steps % decay_step:  # Decay Learning Rate
-        logging.debug(
-            "Learning Rate: %s" %
-            G_optimizer.learning_rate.numpy)
-        G_optimizer.learning_rate.assign(
-            G_optimizer.learning_rate * decay_factor)
-        logging.debug(
-            "Decayed Learning Rate by %f."
-            "Current Learning Rate %s" % (
-                decay_factor, G_optimizer.learning_rate))
-      with self.summary_writer.as_default():
-        tf.summary.scalar(
-            "warmup_loss", metric.result(), step=G_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
-
-      # if not num_steps % self.settings["print_step"]: # test
-      if True:
-        logging.info(
-            "[WARMUP] Step: {}\tGenerator Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_steps,
-                metric.result(),
-                psnr_metric.result(),
-                time.time() -
-                start_time))
-        if psnr_metric.result() > previous_loss:
-          utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
-        previous_loss = psnr_metric.result()
-        start_time = time.time()
+    for epoch in range(NUM_WU_EPOCHS):
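+        # One full pass over the training dataset per warm-up epoch.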
+        for image_lr, image_hr in self.dataset:
+          num_steps = train_step(image_lr, image_hr)
+
+          if status:
+            status.assert_consumed()
+            logging.info(
+                "consumed checkpoint for phase_1 successfully")
+            status = None
+
+          if not num_steps % decay_step:  # Decay Learning Rate
+            logging.debug(
+                "Learning Rate: %s" %
+                G_optimizer.learning_rate.numpy())
+            G_optimizer.learning_rate.assign(
+                G_optimizer.learning_rate * decay_factor)
+            logging.debug(
+                "Decayed Learning Rate by %f."
+                "Current Learning Rate %s" % (
+                    decay_factor, G_optimizer.learning_rate))
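+          # Write warm-up loss and mean PSNR summaries for TensorBoard.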
+          with self.summary_writer.as_default():
+            tf.summary.scalar(
+                "warmup_loss", metric.result(), step=G_optimizer.iterations)
+            tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
+
+          # if not num_steps % self.settings["print_step"]: # test
+          if True:
+            logging.info(
+                "[WARMUP] Step: {}\tGenerator Loss: {}"
+                "\tPSNR: {}\tTime Taken: {} sec".format(
+                    num_steps,
+                    metric.result(),
+                    psnr_metric.result(),
+                    time.time() -
+                    start_time))
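+            # Save a phase-1 checkpoint whenever mean PSNR improves on the last logged value.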
+            if psnr_metric.result() > previous_loss:
+              utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
+            previous_loss = psnr_metric.result()
+            start_time = time.time()
 
   def train_gan(self, generator, discriminator):
     """ Implements Training routine for ESRGAN
@@ -260,55 +262,55 @@ class Trainer(object):
     start = time.time()
     last_psnr = 0
 
-    for image_lr, image_hr in self.dataset:
-      num_step = train_step(image_lr, image_hr)
-      if num_step >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info("consumed checkpoint successfully!")
-        status = None
-
-      # Decaying Learning Rate
-      for _step in decay_steps.copy():
-        if num_step >= _step:
-          decay_steps.pop(0)
-          g_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              G_optimizer.learning_rate, axis=None)
-
-          d_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              D_optimizer.learning_rate, axis=None)
-
-          logging.debug(
-              "Current LR: G = %s, D = %s" %
-              (g_current_lr, d_current_lr))
-          logging.debug(
-              "[Phase 2] Decayed Learing Rate by %f." % decay_factor)
-          G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
-          D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
-
-      # Writing Summary
-      with self.summary_writer_2.as_default():
-        tf.summary.scalar(
-            "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar(
-            "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
-
-      # Logging and Checkpointing
-      # if not num_step % self.settings["print_step"]: # testing
-      if True:
-        logging.info(
-            "Step: {}\tGen Loss: {}\tDisc Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_step,
-                gen_metric.result(),
-                disc_metric.result(),
-                psnr_metric.result(),
-                time.time() - start))
-        # if psnr_metric.result() > last_psnr:
-        last_psnr = psnr_metric.result()
-        utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
-        start = time.time()
+    for epoch in range(NUM_EPOCHS):
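+        # One full pass over the training dataset per GAN (phase-2) epoch.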
+        for image_lr, image_hr in self.dataset:
+          num_step = train_step(image_lr, image_hr)
+
+          if status:
+            status.assert_consumed()
+            logging.info("consumed checkpoint successfully!")
+            status = None
+
+          # Decaying Learning Rate
+          for _step in decay_steps.copy():
+            if num_step >= _step:
+              decay_steps.pop(0)
+              g_current_lr = self.strategy.reduce(
+                  tf.distribute.ReduceOp.MEAN,
+                  G_optimizer.learning_rate, axis=None)
+
+              d_current_lr = self.strategy.reduce(
+                  tf.distribute.ReduceOp.MEAN,
+                  D_optimizer.learning_rate, axis=None)
+
+              logging.debug(
+                  "Current LR: G = %s, D = %s" %
+                  (g_current_lr, d_current_lr))
+              logging.debug(
+                  "[Phase 2] Decayed Learing Rate by %f." % decay_factor)
+              G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
+              D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
+
+          # Writing Summary
+          with self.summary_writer_2.as_default():
+            tf.summary.scalar(
+                "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
+            tf.summary.scalar(
+                "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
+            tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
+
+          # Logging and Checkpointing
+          # if not num_step % self.settings["print_step"]: # testing
+          if True:
+            logging.info(
+                "Step: {}\tGen Loss: {}\tDisc Loss: {}"
+                "\tPSNR: {}\tTime Taken: {} sec".format(
+                    num_step,
+                    gen_metric.result(),
+                    disc_metric.result(),
+                    psnr_metric.result(),
+                    time.time() - start))
+            # if psnr_metric.result() > last_psnr:
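+            # Checkpoint every log step; the PSNR-improvement gate above is currently disabled.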
+            last_psnr = psnr_metric.result()
+            utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
+            start = time.time()
-- 
GitLab