diff --git a/modules/GSOC/E2_ESRGAN/lib/utils.py b/modules/GSOC/E2_ESRGAN/lib/utils.py
index 954910abae10c3bc831ca5bf274f776573570dea..db1d4b2b4d674233bd1477784d620789b5a293b9 100644
--- a/modules/GSOC/E2_ESRGAN/lib/utils.py
+++ b/modules/GSOC/E2_ESRGAN/lib/utils.py
@@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings
 
 # Checkpoint Utilities
 
+num_chans = 2
+
 
 def save_checkpoint(checkpoint, training_phase, basepath=""):
   """ Saves checkpoint.
@@ -77,9 +79,9 @@ def interpolate_generator(
   psnr_generator = generator_fn()
   # building generators
   gan_generator(tf.random.normal(
-      [1, size[0] // factor, size[1] // factor, 1]))
+      [1, size[0] // factor, size[1] // factor, num_chans]))
   psnr_generator(tf.random.normal(
-      [1, size[0] // factor, size[1] // factor, 1]))
+      [1, size[0] // factor, size[1] // factor, num_chans]))
 
   phase_1_ckpt = tf.train.Checkpoint(
       G=psnr_generator, G_optimizer=optimizer())
@@ -96,8 +98,8 @@ def interpolate_generator(
   # Consuming Checkpoint
   logging.debug("Consuming Variables: Adervsarial generator") 
   gan_generator.unsigned_call(tf.random.normal(
-      [1, size[0] // factor, size[1] // factor, 1]))
-  input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0")
+      [1, size[0] // factor, size[1] // factor, num_chans]))
+  input_layer = tf.keras.Input(shape=[None, None, num_chans], name="input_0")
   output = gan_generator(input_layer)
   gan_generator = tf.keras.Model(
       inputs=[input_layer],
@@ -105,7 +107,7 @@ def interpolate_generator(
 
   logging.debug("Consuming Variables: PSNR generator")
   psnr_generator.unsigned_call(tf.random.normal(
-      [1, size[0] // factor, size[1] // factor, 1]))
+      [1, size[0] // factor, size[1] // factor, num_chans]))
 
   for variables_1, variables_2 in zip(
           gan_generator.trainable_variables, psnr_generator.trainable_variables):