Commit e93081c7 authored by tomrink

snapshot...

parent 385635fe
@@ -11,6 +11,9 @@ from pathlib import Path
 logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging.DEBUG)
+NUM_WU_EPOCHS = 10
+NUM_EPOCHS = 20
 class Trainer(object):
     """ Trainer class for ESRGAN """
@@ -105,46 +108,45 @@ class Trainer(object):
             # return mean_metric
             return distributed_metric
-        for image_lr, image_hr in self.dataset:
-            num_steps = train_step(image_lr, image_hr)
-            if num_steps >= total_steps:
-                return
-            if status:
-                status.assert_consumed()
-                logging.info(
-                    "consumed checkpoint for phase_1 successfully")
-                status = None
-            if not num_steps % decay_step:  # Decay Learning Rate
-                logging.debug(
-                    "Learning Rate: %s" %
-                    G_optimizer.learning_rate.numpy)
-                G_optimizer.learning_rate.assign(
-                    G_optimizer.learning_rate * decay_factor)
-                logging.debug(
-                    "Decayed Learning Rate by %f."
-                    "Current Learning Rate %s" % (
-                        decay_factor, G_optimizer.learning_rate))
-            with self.summary_writer.as_default():
-                tf.summary.scalar(
-                    "warmup_loss", metric.result(), step=G_optimizer.iterations)
-                tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
-            # if not num_steps % self.settings["print_step"]:  # test
-            if True:
-                logging.info(
-                    "[WARMUP] Step: {}\tGenerator Loss: {}"
-                    "\tPSNR: {}\tTime Taken: {} sec".format(
-                        num_steps,
-                        metric.result(),
-                        psnr_metric.result(),
-                        time.time() -
-                        start_time))
-                if psnr_metric.result() > previous_loss:
-                    utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
-                previous_loss = psnr_metric.result()
-                start_time = time.time()
+        for epoch in range(NUM_WU_EPOCHS):
+            for image_lr, image_hr in self.dataset:
+                num_steps = train_step(image_lr, image_hr)
+                if status:
+                    status.assert_consumed()
+                    logging.info(
+                        "consumed checkpoint for phase_1 successfully")
+                    status = None
+                if not num_steps % decay_step:  # Decay Learning Rate
+                    logging.debug(
+                        "Learning Rate: %s" %
+                        G_optimizer.learning_rate.numpy)
+                    G_optimizer.learning_rate.assign(
+                        G_optimizer.learning_rate * decay_factor)
+                    logging.debug(
+                        "Decayed Learning Rate by %f."
+                        "Current Learning Rate %s" % (
+                            decay_factor, G_optimizer.learning_rate))
+                with self.summary_writer.as_default():
+                    tf.summary.scalar(
+                        "warmup_loss", metric.result(), step=G_optimizer.iterations)
+                    tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
+                # if not num_steps % self.settings["print_step"]:  # test
+                if True:
+                    logging.info(
+                        "[WARMUP] Step: {}\tGenerator Loss: {}"
+                        "\tPSNR: {}\tTime Taken: {} sec".format(
+                            num_steps,
+                            metric.result(),
+                            psnr_metric.result(),
+                            time.time() -
+                            start_time))
+                    if psnr_metric.result() > previous_loss:
+                        utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
+                    previous_loss = psnr_metric.result()
+                    start_time = time.time()
 
     def train_gan(self, generator, discriminator):
         """ Implements Training routine for ESRGAN
@@ -260,55 +262,55 @@ class Trainer(object):
         start = time.time()
         last_psnr = 0
-        for image_lr, image_hr in self.dataset:
-            num_step = train_step(image_lr, image_hr)
-            if num_step >= total_steps:
-                return
-            if status:
-                status.assert_consumed()
-                logging.info("consumed checkpoint successfully!")
-                status = None
-            # Decaying Learning Rate
-            for _step in decay_steps.copy():
-                if num_step >= _step:
-                    decay_steps.pop(0)
-                    g_current_lr = self.strategy.reduce(
-                        tf.distribute.ReduceOp.MEAN,
-                        G_optimizer.learning_rate, axis=None)
-                    d_current_lr = self.strategy.reduce(
-                        tf.distribute.ReduceOp.MEAN,
-                        D_optimizer.learning_rate, axis=None)
-                    logging.debug(
-                        "Current LR: G = %s, D = %s" %
-                        (g_current_lr, d_current_lr))
-                    logging.debug(
-                        "[Phase 2] Decayed Learing Rate by %f." % decay_factor)
-                    G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
-                    D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
-            # Writing Summary
-            with self.summary_writer_2.as_default():
-                tf.summary.scalar(
-                    "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
-                tf.summary.scalar(
-                    "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
-                tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
-            # Logging and Checkpointing
-            # if not num_step % self.settings["print_step"]:  # testing
-            if True:
-                logging.info(
-                    "Step: {}\tGen Loss: {}\tDisc Loss: {}"
-                    "\tPSNR: {}\tTime Taken: {} sec".format(
-                        num_step,
-                        gen_metric.result(),
-                        disc_metric.result(),
-                        psnr_metric.result(),
-                        time.time() - start))
-                # if psnr_metric.result() > last_psnr:
-                last_psnr = psnr_metric.result()
-                utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
-                start = time.time()
+        for epoch in range(NUM_EPOCHS):
+            for image_lr, image_hr in self.dataset:
+                num_step = train_step(image_lr, image_hr)
+                if status:
+                    status.assert_consumed()
+                    logging.info("consumed checkpoint successfully!")
+                    status = None
+                # Decaying Learning Rate
+                for _step in decay_steps.copy():
+                    if num_step >= _step:
+                        decay_steps.pop(0)
+                        g_current_lr = self.strategy.reduce(
+                            tf.distribute.ReduceOp.MEAN,
+                            G_optimizer.learning_rate, axis=None)
+                        d_current_lr = self.strategy.reduce(
+                            tf.distribute.ReduceOp.MEAN,
+                            D_optimizer.learning_rate, axis=None)
+                        logging.debug(
+                            "Current LR: G = %s, D = %s" %
+                            (g_current_lr, d_current_lr))
+                        logging.debug(
+                            "[Phase 2] Decayed Learing Rate by %f." % decay_factor)
+                        G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
+                        D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
+                # Writing Summary
+                with self.summary_writer_2.as_default():
+                    tf.summary.scalar(
+                        "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
+                    tf.summary.scalar(
+                        "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
+                    tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
+                # Logging and Checkpointing
+                # if not num_step % self.settings["print_step"]:  # testing
+                if True:
+                    logging.info(
+                        "Step: {}\tGen Loss: {}\tDisc Loss: {}"
+                        "\tPSNR: {}\tTime Taken: {} sec".format(
+                            num_step,
+                            gen_metric.result(),
+                            disc_metric.result(),
+                            psnr_metric.result(),
+                            time.time() - start))
+                    # if psnr_metric.result() > last_psnr:
+                    last_psnr = psnr_metric.result()
+                    utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
+                    start = time.time()
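
Note on the hunk above: phase 2 gets the same restructuring (a fixed NUM_EPOCHS outer loop, with the total_steps early return removed) and keeps the boundary-based learning-rate decay: whenever the global step passes the next entry in decay_steps, both optimizers' learning rates are multiplied by decay_factor. A small self-contained sketch of that schedule, using illustrative names rather than the module's API:

# Illustrative sketch only (not the trainer's API): boundary-based LR decay,
# applied once per crossed boundary during phase-2 training.
def apply_boundary_decay(num_step, decay_steps, decay_factor, g_lr, d_lr):
    while decay_steps and num_step >= decay_steps[0]:
        decay_steps.pop(0)       # consume the boundary that was just crossed
        g_lr *= decay_factor     # generator learning rate
        d_lr *= decay_factor     # discriminator learning rate
    return g_lr, d_lr

# Example: boundaries at 50k and 100k steps, halving the rate at each.
g_lr, d_lr = apply_boundary_decay(60000, [50000, 100000], 0.5, 1e-4, 1e-4)
# g_lr == d_lr == 5e-05; the 100k boundary remains for a later step.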