# train.py
import os
import time
from functools import partial
# from absl import logging
import logging
import tensorflow as tf
# from lib import utils, dataset
from GSOC.E2_ESRGAN.lib import utils, dataset
import glob
from pathlib import Path
logging.basicConfig(filename=str(Path.home()) + '/esrgan_log.txt', level=logging.DEBUG)
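# Module-level training constants:
#   NUM_WU_EPOCHS - number of warmup (Phase 1, L1-only) epochs
#   NUM_EPOCHS    - number of adversarial (Phase 2, GAN) epochs
#   num_chans     - presumably the number of data channels (not referenced in this file)
#   PSNR_MAX      - max_val passed to tf.image.psnr when computing the PSNR metric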
NUM_WU_EPOCHS = 14
NUM_EPOCHS = 20
num_chans = 2
PSNR_MAX = 256.0
class Trainer(object):
""" Trainer class for ESRGAN """
def __init__(
self,
summary_writer,
summary_writer_2,
settings,
model_dir='/Users/tomrink/tf_model_esgan/',
data_dir=None,
manual=False,
strategy=None):
""" Setup the values and variables for Training.
Args:
summary_writer: tf.summary.SummaryWriter object to write summaries for Tensorboard.
settings: settings object for fetching data from config files.
data_dir (default: None): path where the data downloaded should be stored / accessed.
manual (default: False): boolean to represent if data_dir is a manual dir.
"""
self.settings = settings
self.model_dir = model_dir
self.summary_writer = summary_writer
self.summary_writer_2 = summary_writer_2
self.strategy = strategy
dataset_args = self.settings["dataset"]
self.batch_size = self.settings["batch_size"]
hr_size = dataset_args['hr_dimension']
lr_size = int(hr_size // 4)
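# Low-resolution patches are 1/4 the size of the high-resolution targets (ESRGAN's default x4 scale).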
filenames = glob.glob('/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/clavrx_surfrad_abi_cld_opd_1_512x512/train_ires_????.npy')
self.dataset = dataset.OpdNpyDataset(filenames, hr_size, lr_size, batch_size=self.batch_size).dataset
print('OPD dataset initialized...')
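# The dataset yields (image_lr, image_hr) batches of cloud optical depth (OPD) patches;
# both training phases below iterate over it directly.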
def warmup_generator(self, generator):
""" Training on L1 Loss to warmup the Generator.
Minimizing the L1 Loss will reduce the Peak Signal to Noise Ratio (PSNR)
of the generated image from the generator.
This trained generator is then used to bootstrap the training of the
GAN, creating better image inputs instead of random noises.
Args:
generator: Model Object for the Generator
"""
# Loading up phase parameters
print('begin warmup_generator...')
warmup_num_iter = self.settings.get("warmup_num_iter", None)
phase_args = self.settings["train_psnr"]
decay_params = phase_args["adam"]["decay"]
decay_step = decay_params["step"]
decay_factor = decay_params["factor"]
total_steps = phase_args["num_steps"]
metric = tf.keras.metrics.Mean()
psnr_metric = tf.keras.metrics.Mean()
# Generator Optimizer
G_optimizer = tf.optimizers.Adam(
learning_rate=phase_args["adam"]["initial_lr"],
beta_1=phase_args["adam"]["beta_1"],
beta_2=phase_args["adam"]["beta_2"])
checkpoint = tf.train.Checkpoint(
G=generator,
G_optimizer=G_optimizer)
ckpt_manager = tf.train.CheckpointManager(checkpoint, self.model_dir, max_to_keep=7)
status = utils.load_checkpoint(checkpoint, "phase_1", self.model_dir)
logging.debug("phase_1 status object: {}".format(status))
previous_psnr = 0
start_time = time.time()
# Training starts
def _step_fn(image_lr, image_hr):
logging.debug("Starting Distributed Step")
with tf.GradientTape() as tape:
fake = generator.unsigned_call(image_lr)
loss = utils.pixel_loss(image_hr, fake) * (1.0 / self.batch_size)
# loss = utils.pixel_loss_mse(image_hr, fake) * (1.0 / self.batch_size)
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
# gen_vars = list(set(generator.trainable_variables))
gen_vars = generator.trainable_variables
gradient = tape.gradient(loss, gen_vars)
G_optimizer.apply_gradients(zip(gradient, gen_vars))
mean_loss = metric(loss)
logging.debug("Ending Distributed Step")
return tf.cast(G_optimizer.iterations, tf.float32)
@tf.function
def train_step(image_lr, image_hr):
# distributed_metric = self.strategy.experimental_run_v2(_step_fn, args=[image_lr, image_hr])
distributed_metric = _step_fn(image_lr, image_hr)
# mean_metric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
# return mean_metric
return distributed_metric
for epoch in range(NUM_WU_EPOCHS):
print('start epoch #: ', epoch)
for image_lr, image_hr in self.dataset:
num_steps = train_step(image_lr, image_hr)
if status:
status.assert_consumed()
logging.info(
"consumed checkpoint for phase_1 successfully")
status = None
if not num_steps % decay_step: # Decay Learning Rate
logging.debug(
"Learning Rate: %s" %
G_optimizer.learning_rate.numpy())
G_optimizer.learning_rate.assign(
G_optimizer.learning_rate * decay_factor)
logging.debug(
"Decayed Learning Rate by %f."
"Current Learning Rate %s" % (
decay_factor, G_optimizer.learning_rate))
with self.summary_writer.as_default():
tf.summary.scalar(
"warmup_loss", metric.result(), step=G_optimizer.iterations)
tf.summary.scalar("mean_psnr", psnr_metric.result(), G_optimizer.iterations)
if not num_steps % self.settings["print_step"]:
logging.info(
"[WARMUP] Step: {}\tGenerator Loss: {}"
"\tPSNR: {}\tTime Taken: {} sec".format(
num_steps,
metric.result(),
psnr_metric.result(),
time.time() -
start_time))
if psnr_metric.result() > previous_psnr:
utils.save_checkpoint(checkpoint, "phase_1", self.model_dir)
previous_psnr = psnr_metric.result()
start_time = time.time()
def train_gan(self, generator, discriminator):
""" Implements Training routine for ESRGAN
Args:
generator: Model object for the Generator
discriminator: Model object for the Discriminator
"""
print('begin train_gan...')
phase_args = self.settings["train_combined"]
decay_args = phase_args["adam"]["decay"]
decay_factor = decay_args["factor"]
decay_steps = decay_args["step"]
lambda_ = phase_args["lambda"]
hr_dimension = self.settings["dataset"]["hr_dimension"]
eta = phase_args["eta"]
total_steps = phase_args["num_steps"]
optimizer = partial(
tf.optimizers.Adam,
learning_rate=phase_args["adam"]["initial_lr"],
beta_1=phase_args["adam"]["beta_1"],
beta_2=phase_args["adam"]["beta_2"])
G_optimizer = optimizer()
D_optimizer = optimizer()
ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")
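# Relativistic Average GAN (RaGAN) losses from the ESRGAN paper: the "G" and "D" variants
# score real and fake logits relative to the mean of the opposite class.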
# The weights of the generator trained during Phase #1
# are used to initialize or "hot start" the generator
# for Phase #2 of training
status = None
checkpoint = tf.train.Checkpoint(
G=generator,
G_optimizer=G_optimizer,
D=discriminator,
D_optimizer=D_optimizer)
ckpt_manager = tf.train.CheckpointManager(checkpoint, self.model_dir, max_to_keep=7)
if not tf.io.gfile.exists(
os.path.join(
self.model_dir,
self.settings["checkpoint_path"]["phase_2"],
"checkpoint")):
hot_start = tf.train.Checkpoint(
G=generator,
G_optimizer=G_optimizer)
status = utils.load_checkpoint(hot_start, "phase_1", self.model_dir)
# the Phase 1 optimizer state was just consumed; reset the learning rate to Phase 2's initial value
G_optimizer.learning_rate.assign(phase_args["adam"]["initial_lr"])
else:
status = utils.load_checkpoint(checkpoint, "phase_2", self.model_dir)
logging.debug("phase status object: {}".format(status))
gen_metric = tf.keras.metrics.Mean()
disc_metric = tf.keras.metrics.Mean()
psnr_metric = tf.keras.metrics.Mean()
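# Running means of generator loss, discriminator loss, and PSNR, written to TensorBoard below.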
# TDR: May need to create a Perceptual Model for Optical Depth?
# logging.debug("Loading Perceptual Model")
# perceptual_loss = utils.PerceptualLoss(
# weights="imagenet",
# input_shape=[hr_dimension, hr_dimension, 1],
# loss_type=phase_args["perceptual_loss_type"])
# logging.debug("Loaded Model")
def _step_fn(image_lr, image_hr):
logging.debug("Starting Distributed Step")
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
fake = generator.unsigned_call(image_lr)
logging.debug("Fetched Generator Fake")
# TDR, not sure about these...
# fake = utils.preprocess_input(fake)
# image_lr = utils.preprocess_input(image_lr)
# image_hr = utils.preprocess_input(image_hr)
# ------------------------------------------------
# TDR, not using perceptual loss with CLD OPD
# percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
# logging.debug("Calculated Perceptual Loss")
# --------------------------------------------------------------
# # substitute standard losses below...
# l1_loss = utils.pixel_loss(image_hr, fake)
# logging.debug("Calculated Pixel Loss")
#
# loss_RaG = ra_gen(image_hr, fake)
# logging.debug("Calculated Relativistic"
# "Average (RA) Loss for Generator")
#
# disc_loss = ra_disc(image_hr, fake)
# logging.debug("Calculated RA Loss Discriminator")
#
# # TDR, we don't have percep_loss
# # gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
# gen_loss = lambda_ * loss_RaG + eta * l1_loss
# # -----------------------------------------------------------
gen_loss = utils.std_loss_G(discriminator, fake)
disc_loss = utils.std_loss_D(discriminator, image_hr, fake)
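# TDR substitution: utils.std_loss_G / utils.std_loss_D (presumably standard, non-relativistic
# GAN losses) replace the perceptual + RaGAN combination of the original ESRGAN recipe.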
logging.debug("Calculated Generator Loss")
disc_metric(disc_loss)
gen_metric(gen_loss)
gen_loss = gen_loss * (1.0 / self.batch_size)
disc_loss = disc_loss * (1.0 / self.batch_size)
psnr_metric(tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=PSNR_MAX)))
disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
logging.debug("Calculated gradient for Discriminator")
D_optimizer.apply_gradients(zip(disc_grad, discriminator.trainable_variables))
logging.debug("Applied gradients to Discriminator")
gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
logging.debug("Calculated gradient for Generator")
G_optimizer.apply_gradients(zip(gen_grad, generator.trainable_variables))
logging.debug("Applied gradients to Generator")
return tf.cast(D_optimizer.iterations, tf.float32)
@tf.function
def train_step(image_lr, image_hr):
# distributed_iterations = self.strategy.experimental_run_v2(step_fn, args=(image_lr, image_hr))
distributed_iterations = _step_fn(image_lr, image_hr)
# num_steps = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, distributed_iterations, axis=None)
# return num_steps
return distributed_iterations
start = time.time()
last_psnr = 0
for epoch in range(NUM_EPOCHS):
print('start epoch #: ', epoch)
for image_lr, image_hr in self.dataset:
num_step = train_step(image_lr, image_hr)
if status:
status.assert_consumed()
logging.info("consumed checkpoint successfully!")
status = None
# Decaying Learning Rate
for _step in decay_steps.copy():
if num_step >= _step:
decay_steps.pop(0)
# TDR: don't print this out; it causes an error in the next line
# g_current_lr = self.strategy.reduce(
# tf.distribute.ReduceOp.MEAN,
# G_optimizer.learning_rate, axis=None)
#
# d_current_lr = self.strategy.reduce(
# tf.distribute.ReduceOp.MEAN,
# D_optimizer.learning_rate, axis=None)
#
# logging.debug(
# "Current LR: G = %s, D = %s" %
# (g_current_lr, d_current_lr))
# logging.debug(
# "[Phase 2] Decayed Learing Rate by %f." % decay_factor)
G_optimizer.learning_rate.assign(G_optimizer.learning_rate * decay_factor)
D_optimizer.learning_rate.assign(D_optimizer.learning_rate * decay_factor)
# Writing Summary
with self.summary_writer_2.as_default():
tf.summary.scalar(
"gen_loss", gen_metric.result(), step=D_optimizer.iterations)
tf.summary.scalar(
"disc_loss", disc_metric.result(), step=D_optimizer.iterations)
tf.summary.scalar("mean_psnr", psnr_metric.result(), step=D_optimizer.iterations)
# Logging and Checkpointing
if not num_step % self.settings["print_step"]:
logging.info(
"Step: {}\tGen Loss: {}\tDisc Loss: {}"
"\tPSNR: {}\tTime Taken: {} sec".format(
num_step,
gen_metric.result(),
disc_metric.result(),
psnr_metric.result(),
time.time() - start))
if psnr_metric.result() > last_psnr:
last_psnr = psnr_metric.result()
utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
start = time.time()
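# ---------------------------------------------------------------------------
# Example driver (hypothetical sketch, not part of the original file). The settings
# loader and model constructors below are assumptions about the surrounding GSOC
# E2_ESRGAN lib and may differ from the actual API; only the Trainer calls are
# taken from this file.
#
#   import tensorflow as tf
#   from GSOC.E2_ESRGAN.lib import settings, model
#
#   sett = settings.Settings("config.yaml")                # assumed settings loader
#   writer_1 = tf.summary.create_file_writer("logs/phase_1")
#   writer_2 = tf.summary.create_file_writer("logs/phase_2")
#   trainer = Trainer(writer_1, writer_2, sett)
#   generator = model.RRDBNet(out_channel=1)               # assumed generator constructor
#   discriminator = model.VGGArch()                        # assumed discriminator constructor
#   trainer.warmup_generator(generator)                    # Phase 1: L1 warmup
#   trainer.train_gan(generator, discriminator)            # Phase 2: adversarial training
# ---------------------------------------------------------------------------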