From 070fd9fa2b7dd83485b9bafe5dbde6ad8c56ac3b Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Thu, 12 Oct 2023 10:31:24 -0500 Subject: [PATCH] snapshot... --- modules/GSOC/E2_ESRGAN/lib/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 7d234fcd..f8ff4150 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -92,14 +92,15 @@ class Trainer(object): logging.debug("Starting Distributed Step") with tf.GradientTape() as tape: fake = generator.unsigned_call(image_lr) - loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size) - # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) + loss_mae = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size) + # loss_mse = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) + loss = loss_mae + mean_loss = metric(loss_mae) psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX))) # gen_vars = list(set(generator.trainable_variables)) gen_vars = generator.trainable_variables gradient = tape.gradient(loss, gen_vars) G_optimizer.apply_gradients(zip(gradient, gen_vars)) - mean_loss = metric(loss) logging.debug("Ending Distributed Step") return tf.cast(G_optimizer.iterations, tf.float32) -- GitLab