diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index d4b827d6fc8544ee2d744bc9af9cb369726a19fd..8a625e9253b61d1fce036ec2017ac2c54a3adc1f 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -154,8 +154,6 @@ class Trainer(object): utils.save_checkpoint(checkpoint, "phase_1", self.model_dir) previous_loss = psnr_metric.result() start_time = time.time() - if GENERATOR_ONLY: - return def train_gan(self, generator, discriminator): """ Implements Training routine for ESRGAN @@ -342,3 +340,5 @@ class Trainer(object): last_psnr = psnr_metric.result() utils.save_checkpoint(checkpoint, "phase_2", self.model_dir) start = time.time() + if GENERATOR_ONLY: + return