diff --git a/modules/GSOC/E2_ESRGAN/lib/utils.py b/modules/GSOC/E2_ESRGAN/lib/utils.py index 954910abae10c3bc831ca5bf274f776573570dea..db1d4b2b4d674233bd1477784d620789b5a293b9 100644 --- a/modules/GSOC/E2_ESRGAN/lib/utils.py +++ b/modules/GSOC/E2_ESRGAN/lib/utils.py @@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings # Checkpoint Utilities +num_chans = 2 + def save_checkpoint(checkpoint, training_phase, basepath=""): """ Saves checkpoint. @@ -77,9 +79,9 @@ def interpolate_generator( psnr_generator = generator_fn() # building generators gan_generator(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) psnr_generator(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) phase_1_ckpt = tf.train.Checkpoint( G=psnr_generator, G_optimizer=optimizer()) @@ -96,8 +98,8 @@ def interpolate_generator( # Consuming Checkpoint logging.debug("Consuming Variables: Adervsarial generator") gan_generator.unsigned_call(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) - input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0") + [1, size[0] // factor, size[1] // factor, num_chans])) + input_layer = tf.keras.Input(shape=[None, None, num_chans], name="input_0") output = gan_generator(input_layer) gan_generator = tf.keras.Model( inputs=[input_layer], @@ -105,7 +107,7 @@ def interpolate_generator( logging.debug("Consuming Variables: PSNR generator") psnr_generator.unsigned_call(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) for variables_1, variables_2 in zip( gan_generator.trainable_variables, psnr_generator.trainable_variables):