From 48211ef022a298e93f9a5e8783744c5688275ad8 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Wed, 6 Sep 2023 11:58:13 -0500 Subject: [PATCH] snapshot... --- modules/GSOC/E2_ESRGAN/lib/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/modules/GSOC/E2_ESRGAN/lib/utils.py b/modules/GSOC/E2_ESRGAN/lib/utils.py index 954910ab..db1d4b2b 100644 --- a/modules/GSOC/E2_ESRGAN/lib/utils.py +++ b/modules/GSOC/E2_ESRGAN/lib/utils.py @@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings # Checkpoint Utilities +num_chans = 2 + def save_checkpoint(checkpoint, training_phase, basepath=""): """ Saves checkpoint. @@ -77,9 +79,9 @@ def interpolate_generator( psnr_generator = generator_fn() # building generators gan_generator(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) psnr_generator(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) phase_1_ckpt = tf.train.Checkpoint( G=psnr_generator, G_optimizer=optimizer()) @@ -96,8 +98,8 @@ def interpolate_generator( # Consuming Checkpoint logging.debug("Consuming Variables: Adervsarial generator") gan_generator.unsigned_call(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) - input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0") + [1, size[0] // factor, size[1] // factor, num_chans])) + input_layer = tf.keras.Input(shape=[None, None, num_chans], name="input_0") output = gan_generator(input_layer) gan_generator = tf.keras.Model( inputs=[input_layer], @@ -105,7 +107,7 @@ def interpolate_generator( logging.debug("Consuming Variables: PSNR generator") psnr_generator.unsigned_call(tf.random.normal( - [1, size[0] // factor, size[1] // factor, 1])) + [1, size[0] // factor, size[1] // factor, num_chans])) for variables_1, variables_2 in zip( gan_generator.trainable_variables, psnr_generator.trainable_variables): -- GitLab