Skip to content
Snippets Groups Projects
Commit 42bfd0a5 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent bcdb8f8b
No related branches found
No related tags found
No related merge requests found
@@ -14,7 +14,8 @@ logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging
 NUM_WU_EPOCHS = 1
 NUM_EPOCHS = 20
 num_chans = 2
+PSNR_MAX = 1.0

 class Trainer(object):
     """ Trainer class for ESRGAN """
@@ -93,7 +94,7 @@ class Trainer(object):
                 fake = generator.unsigned_call(image_lr)
                 loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
                 # loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
-                psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
+                psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
                 # gen_vars = list(set(generator.trainable_variables))
                 gen_vars = generator.trainable_variables
                 gradient = tape.gradient(loss, gen_vars)
@@ -255,7 +256,7 @@ class Trainer(object):
                 gen_metric(gen_loss)
                 gen_loss = gen_loss * (1.0 / self.batch_size)
                 disc_loss = disc_loss * (1.0 / self.batch_size)
-                psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
+                psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
                 disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
                 logging.debug("Calculated gradient for Discriminator")
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment