diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index f8ff41507c4f273ad0bffe8554ebf4804c835f5f..25d727cb6f954b38260a90e1b9f4d553820bb6bb 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -93,8 +93,9 @@ class Trainer(object): with tf.GradientTape() as tape: fake = generator.unsigned_call(image_lr) loss_mae = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size) - # loss_mse = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) - loss = loss_mae + loss_mse = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size) + # loss = loss_mae + loss = loss_mse mean_loss = metric(loss_mae) psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX))) # gen_vars = list(set(generator.trainable_variables)) @@ -114,6 +115,7 @@ class Trainer(object): for epoch in range(NUM_WU_EPOCHS): print('start epoch #: ', epoch) + metric.reset_states() for image_lr, image_hr in self.dataset: num_steps = train_step(image_lr, image_hr)