Skip to content
Snippets Groups Projects
Commit 42bfd0a5 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent bcdb8f8b
Branches
No related tags found
No related merge requests found
......@@ -14,7 +14,7 @@ logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging
NUM_WU_EPOCHS = 1
NUM_EPOCHS = 20
num_chans = 2
PSNR_MAX = 1.0
class Trainer(object):
""" Trainer class for ESRGAN """
......@@ -93,7 +93,7 @@ class Trainer(object):
fake = generator.unsigned_call(image_lr)
loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
# loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
# gen_vars = list(set(generator.trainable_variables))
gen_vars = generator.trainable_variables
gradient = tape.gradient(loss, gen_vars)
......@@ -255,7 +255,7 @@ class Trainer(object):
gen_metric(gen_loss)
gen_loss = gen_loss * (1.0 / self.batch_size)
disc_loss = disc_loss * (1.0 / self.batch_size)
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
logging.debug("Calculated gradient for Discriminator")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment