""" Utility functions needed for training the ESRGAN model. """
import os
from functools import partial
import tensorflow as tf
# from absl import logging
import logging
from GSOC.E2_ESRGAN.lib import settings
# Checkpoint Utilities
def save_checkpoint(checkpoint, training_phase, basepath=""):
  """ Saves a checkpoint for the given training phase.
      Args:
        checkpoint: tf.train.Checkpoint object to save.
        training_phase: Training phase the checkpoint belongs to;
                        either "phase_1" or "phase_2".
        basepath: Base path to save the checkpoint under.
  """
dir_ = settings.Settings()["checkpoint_path"][training_phase]
if basepath:
dir_ = os.path.join(basepath, dir_)
dir_ = os.path.join(dir_, os.path.basename(dir_))
checkpoint.save(file_prefix=dir_)
  logging.debug("Prefix: %s. Checkpoint saved successfully!", dir_)
def load_checkpoint(checkpoint, training_phase, basepath=""):
  """ Loads a checkpoint for the given training phase, if one exists.
      Args:
        checkpoint: tf.train.Checkpoint object to restore into.
        training_phase: Training phase the checkpoint belongs to;
                        either "phase_1" or "phase_2".
        basepath: Base path to load the checkpoint from.
  """
  logging.info("Loading checkpoint for: %s", training_phase)
dir_ = settings.Settings()["checkpoint_path"][training_phase]
if basepath:
dir_ = os.path.join(basepath, dir_)
if tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
    logging.info("Found checkpoint at: %s", dir_)
status = checkpoint.restore(tf.train.latest_checkpoint(dir_))
return status
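# Example usage (illustrative sketch; assumes the settings file maps
# "checkpoint_path" to per-phase directories, and the bucket path below is a
# placeholder):
#
#   checkpoint = tf.train.Checkpoint(
#       G=generator, G_optimizer=tf.keras.optimizers.Adam())
#   load_checkpoint(checkpoint, "phase_1", basepath="gs://my-bucket/esrgan")
#   ...  # train for a while
#   save_checkpoint(checkpoint, "phase_1", basepath="gs://my-bucket/esrgan")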
# Network Interpolation utility
def interpolate_generator(
    generator_fn,
    discriminator,
    alpha,
    dimension,
    factor=4,
    basepath=""):
  """ Interpolates between the weights of the PSNR model and the GAN model.
      See Section 3.4 of https://arxiv.org/pdf/1809.00219.pdf (Wang et al.).
      Args:
        generator_fn: Function returning the Keras model used as the generator.
        discriminator: Keras model of the discriminator.
        alpha: Interpolation coefficient between the weights of the two models;
               0 yields the pure PSNR model, 1 the pure GAN model.
        dimension: Dimension of the high resolution image.
        factor: Scale factor of the model.
        basepath: Base directory to load checkpoints from.
      Returns:
        Keras generator with weights interpolated between the PSNR
        and GAN models.
  """
assert 0 <= alpha <= 1
size = dimension
if not tf.nest.is_nested(dimension):
size = [dimension, dimension]
  logging.debug("Interpolating generator. Alpha: %f", alpha)
  optimizer = partial(tf.keras.optimizers.Adam)  # a fresh optimizer per checkpoint
gan_generator = generator_fn()
psnr_generator = generator_fn()
  # Build both generators so their variables exist before restoring checkpoints.
gan_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
psnr_generator(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
phase_1_ckpt = tf.train.Checkpoint(
G=psnr_generator, G_optimizer=optimizer())
phase_2_ckpt = tf.train.Checkpoint(
G=gan_generator,
G_optimizer=optimizer(),
D=discriminator,
D_optimizer=optimizer())
load_checkpoint(phase_1_ckpt, "phase_1", basepath)
load_checkpoint(phase_2_ckpt, "phase_2", basepath)
  # Consuming checkpoint variables
  logging.debug("Consuming Variables: Adversarial generator")
gan_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
input_layer = tf.keras.Input(shape=[None, None, 1], name="input_0")
output = gan_generator(input_layer)
gan_generator = tf.keras.Model(
inputs=[input_layer],
outputs=[output])
logging.debug("Consuming Variables: PSNR generator")
psnr_generator.unsigned_call(tf.random.normal(
[1, size[0] // factor, size[1] // factor, 1]))
  # theta_interp = (1 - alpha) * theta_PSNR + alpha * theta_GAN
  for variables_1, variables_2 in zip(
      gan_generator.trainable_variables, psnr_generator.trainable_variables):
    variables_1.assign((1 - alpha) * variables_2 + alpha * variables_1)
return gan_generator
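# Example usage (sketch; `build_generator` and the 512 x 512 target size are
# placeholders for this illustration):
#
#   interpolated = interpolate_generator(
#       generator_fn=build_generator,
#       discriminator=discriminator,
#       alpha=0.8,  # 0 -> pure PSNR weights, 1 -> pure GAN weights
#       dimension=512,
#       basepath="gs://my-bucket/esrgan")
#   tf.saved_model.save(interpolated, "esrgan_interpolated")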
# Losses
def preprocess_input(image):
  """ Flips channels RGB -> BGR for VGG-style (Caffe) preprocessing.
      The channel-mean subtraction is currently disabled (zero mean). """
  image = image[..., ::-1]
# mean = -tf.constant([103.939, 116.779, 123.68])
mean = -tf.constant([0.0])
return tf.nn.bias_add(image, mean)
def PerceptualLoss(weights=None, input_shape=None, loss_type="L1"):
  """ Perceptual Loss using VGG19.
      Args:
        weights: Weights for VGG19; None, "imagenet" or a path to a weights file.
        input_shape: Shape of the input image.
        loss_type: Loss type for the features, "L1" or "L2".
  """
vgg_model = tf.keras.applications.vgg19.VGG19(
input_shape=input_shape, weights=weights, include_top=False)
for layer in vgg_model.layers:
layer.trainable = False
  # Use VGG features before activation, as proposed in the ESRGAN paper.
vgg_model.get_layer("block5_conv4").activation = lambda x: x
phi = tf.keras.Model(
inputs=[vgg_model.input],
outputs=[
vgg_model.get_layer("block5_conv4").output])
def loss(y_true, y_pred):
if loss_type.lower() == "l1":
return tf.compat.v1.losses.absolute_difference(phi(y_true), phi(y_pred))
if loss_type.lower() == "l2":
return tf.reduce_mean(
tf.reduce_mean(
(phi(y_true) - phi(y_pred))**2,
axis=0))
raise ValueError(
"Loss Function: \"%s\" not defined for Perceptual Loss" %
loss_type)
return loss
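# Example usage (sketch; assumes 3-channel float inputs compatible with the
# VGG19 feature extractor):
#
#   perceptual_loss = PerceptualLoss(
#       weights="imagenet", input_shape=[128, 128, 3], loss_type="L1")
#   loss_value = perceptual_loss(hr_batch, generated_batch)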
def pixel_loss(y_true, y_pred):
  """ Mean absolute (L1) error between the true and predicted images. """
  y_true = tf.cast(y_true, tf.float32)
y_pred = tf.cast(y_pred, tf.float32)
return tf.reduce_mean(tf.reduce_mean(tf.abs(y_true - y_pred), axis=0))
def RelativisticAverageLoss(non_transformed_disc, type_="G"):
  """ Relativistic Average Loss based on RaGAN.
      Args:
        non_transformed_disc: Discriminator model that outputs raw
                              (non-sigmoid) logits.
        type_: Which relativistic average loss to produce.
               'G': relativistic average loss for the generator.
               'D': relativistic average loss for the discriminator.
  """
  loss = None
  # D_Ra(x, y) = C(x) - E[C(y)], where C is the raw discriminator output.
  def D_Ra(x, y):
return non_transformed_disc(
x) - tf.reduce_mean(non_transformed_disc(y))
def loss_D(y_true, y_pred):
"""
Relativistic Average Loss for Discriminator
Args:
y_true: Real Image
y_pred: Generated Image
"""
real_logits = D_Ra(y_true, y_pred)
fake_logits = D_Ra(y_pred, y_true)
real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.ones_like(real_logits), logits=real_logits))
fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.zeros_like(fake_logits), logits=fake_logits))
return real_loss + fake_loss
def loss_G(y_true, y_pred):
"""
Relativistic Average Loss for Generator
Args:
y_true: Real Image
y_pred: Generated Image
"""
real_logits = D_Ra(y_true, y_pred)
fake_logits = D_Ra(y_pred, y_true)
real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.zeros_like(real_logits), logits=real_logits)
fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.ones_like(fake_logits), logits=fake_logits)
return real_loss + fake_loss
if type_ == "G":
loss = loss_G
elif type_ == "D":
loss = loss_D
return loss
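# Example usage (sketch; `discriminator` must output raw, non-sigmoid logits
# as the docstring above requires):
#
#   ra_gen_loss = RelativisticAverageLoss(discriminator, type_="G")
#   ra_disc_loss = RelativisticAverageLoss(discriminator, type_="D")
#   g_loss = ra_gen_loss(real_images, generated_images)
#   d_loss = ra_disc_loss(real_images, generated_images)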
# Strategy Utils
def assign_to_worker(use_tpu):
  """ Returns the device job to assign ops to when running on a TPU worker. """
  if use_tpu:
    return "/job:worker"
  return ""
class SingleDeviceStrategy(object):
  """ No-op stand-in for a tf.distribute strategy when not running on TPU. """
  def __enter__(self, *args, **kwargs):
    pass
  def __exit__(self, *args, **kwargs):
    pass
  def experimental_distribute_dataset(self, dataset, *args, **kwargs):
    return dataset
  def experimental_run_v2(self, fn, args, kwargs):
    return fn(*args, **kwargs)
  def reduce(self, reduction_type, distributed_data, axis):
    return distributed_data
  def scope(self):
    return self
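# Example usage (sketch): pick the real TPU strategy when available, and fall
# back to this no-op stand-in otherwise. `tpu_name` is a placeholder.
#
#   if use_tpu:
#     resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_name)
#     tf.config.experimental_connect_to_cluster(resolver)
#     tf.tpu.experimental.initialize_tpu_system(resolver)
#     strategy = tf.distribute.experimental.TPUStrategy(resolver)
#   else:
#     strategy = SingleDeviceStrategy()
#   with strategy.scope():
#     model = generator_fn()  # build variables under the strategy's scope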
# Model Utils
class RDB(tf.keras.layers.Layer):
""" Residual Dense Block Layer """
def __init__(self, out_features=32, bias=True, first_call=True):
super(RDB, self).__init__()
_create_conv2d = partial(
tf.keras.layers.Conv2D,
out_features,
kernel_size=[3, 3],
kernel_initializer="he_normal",
bias_initializer="zeros",
strides=[1, 1], padding="same", use_bias=bias)
self._conv2d_layers = {
"conv_1": _create_conv2d(),
"conv_2": _create_conv2d(),
"conv_3": _create_conv2d(),
"conv_4": _create_conv2d(),
"conv_5": _create_conv2d()}
self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
self._beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
self._first_call = first_call
def call(self, input_):
x1 = self._lrelu(self._conv2d_layers["conv_1"](input_))
x2 = self._lrelu(self._conv2d_layers["conv_2"](
tf.concat([input_, x1], -1)))
x3 = self._lrelu(self._conv2d_layers["conv_3"](
tf.concat([input_, x1, x2], -1)))
x4 = self._lrelu(self._conv2d_layers["conv_4"](
tf.concat([input_, x1, x2, x3], -1)))
x5 = self._conv2d_layers["conv_5"](tf.concat([input_, x1, x2, x3, x4], -1))
    if self._first_call:
      logging.debug("Initializing with MSRA")
      # The conv weights above are MSRA ("he_normal") initialized; scale them
      # by 0.1, the smaller-initialization trick from the ESRGAN paper, to
      # stabilize early training of the residual branches.
      for _, layer in self._conv2d_layers.items():
        for variable in layer.trainable_variables:
          variable.assign(0.1 * variable)
      self._first_call = False
return input_ + self._beta * x5
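# Example usage (sketch): a single RDB preserves the spatial and channel
# shape of a 32-channel feature map.
#
#   features = tf.random.normal([1, 64, 64, 32])
#   out = RDB(out_features=32)(features)  # shape: [1, 64, 64, 32]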
class RRDB(tf.keras.layers.Layer):
  """ Residual in Residual Dense Block Layer """
def __init__(self, out_features=32, first_call=True):
super(RRDB, self).__init__()
self.RDB1 = RDB(out_features, first_call=first_call)
self.RDB2 = RDB(out_features, first_call=first_call)
self.RDB3 = RDB(out_features, first_call=first_call)
self.beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
def call(self, input_):
out = self.RDB1(input_)
out = self.RDB2(out)
out = self.RDB3(out)
return input_ + self.beta * out
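# Example usage (sketch): stacking RRDB blocks around a global residual
# connection, roughly the way the ESRGAN generator trunk does.
#
#   inputs = tf.keras.Input(shape=[None, None, 32])
#   x = inputs
#   for _ in range(3):  # the paper stacks many more blocks
#     x = RRDB(out_features=32)(x)
#   trunk = tf.keras.Model(inputs, inputs + x)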