diff --git a/modules/GSOC/E2_ESRGAN/lib/train.py b/modules/GSOC/E2_ESRGAN/lib/train.py index 25d727cb6f954b38260a90e1b9f4d553820bb6bb..d4b827d6fc8544ee2d744bc9af9cb369726a19fd 100644 --- a/modules/GSOC/E2_ESRGAN/lib/train.py +++ b/modules/GSOC/E2_ESRGAN/lib/train.py @@ -15,6 +15,7 @@ NUM_WU_EPOCHS = 14 NUM_EPOCHS = 20 num_chans = 2 PSNR_MAX = 256.0 +GENERATOR_ONLY = True class Trainer(object): @@ -153,6 +154,8 @@ class Trainer(object): utils.save_checkpoint(checkpoint, "phase_1", self.model_dir) previous_loss = psnr_metric.result() start_time = time.time() + if GENERATOR_ONLY: + return def train_gan(self, generator, discriminator): """ Implements Training routine for ESRGAN