Skip to content
Snippets Groups Projects
Commit 48211ef0 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 49b452c6
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings ...@@ -9,6 +9,8 @@ from GSOC.E2_ESRGAN.lib import settings
# Checkpoint Utilities # Checkpoint Utilities
num_chans = 2
def save_checkpoint(checkpoint, training_phase, basepath=""): def save_checkpoint(checkpoint, training_phase, basepath=""):
""" Saves checkpoint. """ Saves checkpoint.
...@@ -77,9 +79,9 @@ def interpolate_generator( ...@@ -77,9 +79,9 @@ def interpolate_generator(
psnr_generator = generator_fn() psnr_generator = generator_fn()
# building generators # building generators
gan_generator(tf.random.normal( gan_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1])) [1, size[0] // factor, size[1] // factor, num_chans]))
psnr_generator(tf.random.normal( psnr_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1])) [1, size[0] // factor, size[1] // factor, num_chans]))
phase_1_ckpt = tf.train.Checkpoint( phase_1_ckpt = tf.train.Checkpoint(
G=psnr_generator, G_optimizer=optimizer()) G=psnr_generator, G_optimizer=optimizer())
...@@ -96,8 +98,8 @@ def interpolate_generator( ...@@ -96,8 +98,8 @@ def interpolate_generator(
# Consuming Checkpoint # Consuming Checkpoint
logging.debug("Consuming Variables: Adervsarial generator") logging.debug("Consuming Variables: Adervsarial generator")
gan_generator.unsigned_call(tf.random.normal( gan_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1])) [1, size[0] // factor, size[1] // factor, num_chans]))
input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0") input_layer = tf.keras.Input(shape=[None, None, num_chans], name="input_0")
output = gan_generator(input_layer) output = gan_generator(input_layer)
gan_generator = tf.keras.Model( gan_generator = tf.keras.Model(
inputs=[input_layer], inputs=[input_layer],
...@@ -105,7 +107,7 @@ def interpolate_generator( ...@@ -105,7 +107,7 @@ def interpolate_generator(
logging.debug("Consuming Variables: PSNR generator") logging.debug("Consuming Variables: PSNR generator")
psnr_generator.unsigned_call(tf.random.normal( psnr_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1])) [1, size[0] // factor, size[1] // factor, num_chans]))
for variables_1, variables_2 in zip( for variables_1, variables_2 in zip(
gan_generator.trainable_variables, psnr_generator.trainable_variables): gan_generator.trainable_variables, psnr_generator.trainable_variables):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment