diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index d1d1fbcf75ec4b0c30b86c6987c3a69de46e640f..43394b5b3acd84cfa17abe21941baa1b3c274e22 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -99,6 +99,7 @@ class Trainer(object): # loss = loss_mae loss = loss_mse mean_loss = metric(loss_mae) + mse_metric(loss) psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX))) # gen_vars = list(set(generator.trainable_variables)) gen_vars = generator.trainable_variables @@ -119,6 +120,7 @@ class Trainer(object): print('start epoch #: ', epoch) metric.reset_states() psnr_metric.reset_states() + mse_metric.reset_states() for image_lr, image_hr in self.dataset: num_steps = train_step(image_lr, image_hr) @@ -149,6 +151,7 @@ class Trainer(object): "\tPSNR: {}\tTime Taken: {} sec".format( num_steps, metric.result(), + mse_metric.result(), psnr_metric.result(), time.time() - start_time))