Commit e93081c7 authored by tomrink

snapshot...

parent 385635fe
@@ -11,6 +11,9 @@ from pathlib import Path
 logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging.DEBUG)
+
+NUM_WU_EPOCHS = 10
+NUM_EPOCHS = 20
 
 class Trainer(object):
   """ Trainer class for ESRGAN """
@@ -105,46 +108,45 @@ class Trainer(object):
       # return mean_metric
       return distributed_metric
 
-    for image_lr, image_hr in self.dataset:
-      num_steps = train_step(image_lr, image_hr)
-      if num_steps >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info(
-            "consumed checkpoint for phase_1 successfully")
-        status = None
-      if not num_steps % decay_step:  # Decay Learning Rate
-        logging.debug(
-            "Learning Rate: %s" %
-            G_optimizer.learning_rate.numpy())
-        G_optimizer.learning_rate.assign(
-            G_optimizer.learning_rate * decay_factor)
-        logging.debug(
-            "Decayed Learning Rate by %f. "
-            "Current Learning Rate %s" % (
-                decay_factor, G_optimizer.learning_rate))
-      with self.summary_writer.as_default():
-        tf.summary.scalar(
-            "warmup_loss", metric.result(), step=G_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), step=G_optimizer.iterations)
-      # if not num_steps % self.settings["print_step"]:  # test
-      if True:
-        logging.info(
-            "[WARMUP] Step: {}\tGenerator Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_steps,
-                metric.result(),
-                psnr_metric.result(),
-                time.time() - start_time))
-        if psnr_metric.result() > previous_loss:
-          utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
-          previous_loss = psnr_metric.result()
-        start_time = time.time()
+    for epoch in range(NUM_WU_EPOCHS):
+      for image_lr, image_hr in self.dataset:
+        num_steps = train_step(image_lr, image_hr)
+        if status:
+          status.assert_consumed()
+          logging.info(
+              "consumed checkpoint for phase_1 successfully")
+          status = None
+        if not num_steps % decay_step:  # Decay Learning Rate
+          logging.debug(
+              "Learning Rate: %s" %
+              G_optimizer.learning_rate.numpy())
+          G_optimizer.learning_rate.assign(
+              G_optimizer.learning_rate * decay_factor)
+          logging.debug(
+              "Decayed Learning Rate by %f. "
+              "Current Learning Rate %s" % (
+                  decay_factor, G_optimizer.learning_rate))
+        with self.summary_writer.as_default():
+          tf.summary.scalar(
+              "warmup_loss", metric.result(), step=G_optimizer.iterations)
+          tf.summary.scalar("mean_psnr", psnr_metric.result(), step=G_optimizer.iterations)
+        # if not num_steps % self.settings["print_step"]:  # test
+        if True:
+          logging.info(
+              "[WARMUP] Step: {}\tGenerator Loss: {}"
+              "\tPSNR: {}\tTime Taken: {} sec".format(
+                  num_steps,
+                  metric.result(),
+                  psnr_metric.result(),
+                  time.time() - start_time))
+          if psnr_metric.result() > previous_loss:
+            utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
+            previous_loss = psnr_metric.result()
+          start_time = time.time()
 
   def train_gan(self, generator, discriminator):
     """ Implements Training routine for ESRGAN
@@ -260,55 +262,55 @@ class Trainer(object):
     start = time.time()
     last_psnr = 0
 
-    for image_lr, image_hr in self.dataset:
-      num_step = train_step(image_lr, image_hr)
-      if num_step >= total_steps:
-        return
-      if status:
-        status.assert_consumed()
-        logging.info("consumed checkpoint successfully!")
-        status = None
-      # Decaying Learning Rate
-      for _step in decay_steps.copy():
-        if num_step >= _step:
-          decay_steps.pop(0)
-          g_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              G_optimizer.learning_rate, axis=None)
-          d_current_lr = self.strategy.reduce(
-              tf.distribute.ReduceOp.MEAN,
-              D_optimizer.learning_rate, axis=None)
-          logging.debug(
-              "Current LR: G = %s, D = %s" %
-              (g_current_lr, d_current_lr))
-          logging.debug(
-              "[Phase 2] Decayed Learning Rate by %f." % decay_factor)
-          G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
-          D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
-      # Writing Summary
-      with self.summary_writer_2.as_default():
-        tf.summary.scalar(
-            "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar(
-            "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
-        tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
-      # Logging and Checkpointing
-      # if not num_step % self.settings["print_step"]:  # testing
-      if True:
-        logging.info(
-            "Step: {}\tGen Loss: {}\tDisc Loss: {}"
-            "\tPSNR: {}\tTime Taken: {} sec".format(
-                num_step,
-                gen_metric.result(),
-                disc_metric.result(),
-                psnr_metric.result(),
-                time.time() - start))
-        # if psnr_metric.result() > last_psnr:
-        last_psnr = psnr_metric.result()
-        utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
-        start = time.time()
+    for epoch in range(NUM_EPOCHS):
+      for image_lr, image_hr in self.dataset:
+        num_step = train_step(image_lr, image_hr)
+        if status:
+          status.assert_consumed()
+          logging.info("consumed checkpoint successfully!")
+          status = None
+        # Decaying Learning Rate
+        for _step in decay_steps.copy():
+          if num_step >= _step:
+            decay_steps.pop(0)
+            g_current_lr = self.strategy.reduce(
+                tf.distribute.ReduceOp.MEAN,
+                G_optimizer.learning_rate, axis=None)
+            d_current_lr = self.strategy.reduce(
+                tf.distribute.ReduceOp.MEAN,
+                D_optimizer.learning_rate, axis=None)
+            logging.debug(
+                "Current LR: G = %s, D = %s" %
+                (g_current_lr, d_current_lr))
+            logging.debug(
+                "[Phase 2] Decayed Learning Rate by %f." % decay_factor)
+            G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
+            D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
+        # Writing Summary
+        with self.summary_writer_2.as_default():
+          tf.summary.scalar(
+              "gen_loss", gen_metric.result(), step=D_optimizer.iterations)
+          tf.summary.scalar(
+              "disc_loss", disc_metric.result(), step=D_optimizer.iterations)
+          tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
+        # Logging and Checkpointing
+        # if not num_step % self.settings["print_step"]:  # testing
+        if True:
+          logging.info(
+              "Step: {}\tGen Loss: {}\tDisc Loss: {}"
+              "\tPSNR: {}\tTime Taken: {} sec".format(
+                  num_step,
+                  gen_metric.result(),
+                  disc_metric.result(),
+                  psnr_metric.result(),
+                  time.time() - start))
+          # if psnr_metric.result() > last_psnr:
+          last_psnr = psnr_metric.result()
+          utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
+          start = time.time()
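Phase 2 swaps the fixed-interval decay for a list of step thresholds consumed from the front: for each threshold the current step has passed, both optimizers' learning rates are scaled by decay_factor. An equivalent standalone sketch (threshold values and rates are assumptions):

import tensorflow as tf

decay_factor = 0.5                         # assumed
decay_steps = [50000, 100000, 200000]      # assumed thresholds, ascending
g_lr = tf.Variable(1e-4, trainable=False)  # generator LR, assumed
d_lr = tf.Variable(1e-4, trainable=False)  # discriminator LR, assumed

def apply_phase2_decay(num_step):
    # Pop every threshold already passed and scale both rates once per
    # pop, mirroring the for/if/pop(0) loop in the hunk above.
    while decay_steps and num_step >= decay_steps[0]:
        decay_steps.pop(0)
        g_lr.assign(g_lr * decay_factor)
        d_lr.assign(d_lr * decay_factor)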