Skip to content
Snippets Groups Projects
Commit 48211ef0 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 49b452c6
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings
# Checkpoint Utilities
num_chans = 2
def save_checkpoint(checkpoint, training_phase, basepath=""):
""" Saves checkpoint.
......@@ -77,9 +79,9 @@ def interpolate_generator(
psnr_generator = generator_fn()
# building generators
gan_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
[1, size[0] // factor, size[1] // factor, num_chans]))
psnr_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
[1, size[0] // factor, size[1] // factor, num_chans]))
phase_1_ckpt = tf.train.Checkpoint(
G=psnr_generator, G_optimizer=optimizer())
......@@ -96,8 +98,8 @@ def interpolate_generator(
# Consuming Checkpoint
logging.debug("Consuming Variables: Adervsarial generator")
gan_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0")
[1, size[0] // factor, size[1] // factor, num_chans]))
input_layer = tf.keras.Input(shape=[None, None, num_chans], name="input_0")
output = gan_generator(input_layer)
gan_generator = tf.keras.Model(
inputs=[input_layer],
......@@ -105,7 +107,7 @@ def interpolate_generator(
logging.debug("Consuming Variables: PSNR generator")
psnr_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
[1, size[0] // factor, size[1] // factor, num_chans]))
for variables_1, variables_2 in zip(
gan_generator.trainable_variables, psnr_generator.trainable_variables):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment