esrgan.py 4.80 KiB
from collections import OrderedDict
from functools import partial
import tensorflow as tf
from util import esrgan_utils as utils
""" Keras models for ESRGAN.
Classes:
    RRDBNet: generator of ESRGAN (Residual in Residual Network).
    VGGArch: VGG28 architecture forming the ESRGAN discriminator.
"""
class RRDBNet(tf.keras.Model):
    """ Residual in Residual Network: the ESRGAN generator.

    Layout (as built in ``__init__``):
      - a 3x3 convolution followed by LeakyReLU(0.2),
      - a trunk of ``trunk_size`` Residual in Residual Blocks
        (``utils.RRDB``) followed by a 3x3 convolution, with a global
        skip connection adding the trunk output back to its input,
      - two stride-2 transposed convolutions, each followed by LeakyReLU;
        with "same" padding each doubles height and width, so the output
        is 4x the input's spatial size,
      - a 3x3 convolution + LeakyReLU, then a final 3x3 convolution with
        ``out_channel`` filters and no activation.

    Args:
        out_channel: number of channels of the generated (fake) output image.
        num_features (default: 32): number of filters of the outer
            convolutional and upsampling layers.
        trunk_size (default: 11): number of Residual in Residual Blocks
            forming the trunk.
        growth_channel (default: 32): passed as the first positional
            argument of ``utils.RRDB``; presumably the number of filters of
            the convolutions inside each block (TODO confirm in esrgan_utils).
        use_bias (default: True): whether the convolution layers use biases.
        first_call (default: True): forwarded unchanged to ``utils.RRDB``;
            its meaning is defined there.
    """

    def __init__(
            self,
            out_channel,
            num_features=32,
            trunk_size=11,
            growth_channel=32,
            use_bias=True,
            first_call=True):
        super().__init__()
        # Factory for the trunk blocks; every block shares these arguments.
        self.rrdb_block = partial(utils.RRDB, growth_channel, first_call=first_call)
        conv = partial(
            tf.keras.layers.Conv2D,
            kernel_size=[3, 3],
            strides=[1, 1],
            padding="same",
            use_bias=use_bias)
        # With padding="same", a stride-2 transposed convolution outputs
        # exactly twice the input height and width.
        conv_transpose = partial(
            tf.keras.layers.Conv2DTranspose,
            kernel_size=[3, 3],
            strides=[2, 2],
            padding="same",
            use_bias=use_bias)
        # NOTE: layers are created in this fixed order; keep it stable so
        # the model's weight ordering (and saved weights) stay compatible.
        self.conv_first = conv(filters=num_features)
        self.rdb_trunk = tf.keras.Sequential(
            [self.rrdb_block() for _ in range(trunk_size)])
        self.conv_trunk = conv(filters=num_features)
        # Upsample: 2x per layer, 4x in total.
        self.upsample1 = conv_transpose(num_features)
        self.upsample2 = conv_transpose(num_features)
        self.conv_last_1 = conv(num_features)
        self.conv_last_2 = conv(out_channel)
        self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)

    # NOTE: for export with a fixed serving signature, ``call`` can be
    # wrapped in tf.function(input_signature=[tf.TensorSpec(
    #     shape=[None, None, None, 3], dtype=tf.float32)]).
    def call(self, inputs):
        """Run the generator on ``inputs`` (delegates to ``unsigned_call``)."""
        return self.unsigned_call(inputs)

    def unsigned_call(self, input_):
        """Forward pass of the generator.

        Args:
            input_: image batch; presumably a float32 NHWC tensor with 3
                channels (TODO confirm against callers).

        Returns:
            Tensor with ``out_channel`` channels and 4x the input's height
            and width; no activation is applied to the final convolution.
        """
        feature = self.lrelu(self.conv_first(input_))
        trunk = self.conv_trunk(self.rdb_trunk(feature))
        # Global residual connection around the whole trunk.
        feature = trunk + feature
        feature = self.lrelu(self.upsample1(feature))
        feature = self.lrelu(self.upsample2(feature))
        feature = self.lrelu(self.conv_last_1(feature))
        out = self.conv_last_2(feature)
        return out
class VGGArch(tf.keras.Model):
    """ Keras model for the VGG28-style architecture used as the ESRGAN
    discriminator.

    Layout (as built in ``__init__``):
      - ``conv_0_0`` (stride 1) and ``conv_0_1`` (stride 2, batch-normed),
        both with ``num_features`` filters,
      - three stages i = 1..3, each a stride-1 then a stride-2 convolution
        with ``num_features * 2**i`` filters, each followed by batch
        normalization,
      - LeakyReLU(0.2) after every convolution,
      - Flatten -> Dense(100) + LeakyReLU -> Dense(output_shape).

    All convolutions use 3x3 kernels and "same" padding, so each stride-2
    convolution halves height and width (rounding up). Because of the
    Flatten -> Dense head, the input's spatial size must be known when the
    model is built.

    Args:
        batch_size (default: 8): unused; kept so existing callers of the
            constructor keep working.
        output_shape (default: 1): number of units of the final Dense layer,
            i.e. the size of the discriminator's output.
        num_features (default: 64): base number of filters; stage i uses
            ``num_features * 2**i``.
        use_bias (default: False): whether the convolution layers use biases.
    """

    def __init__(
            self,
            batch_size=8,
            output_shape=1,
            num_features=64,
            use_bias=False):
        super().__init__()
        del batch_size  # Unused; see class docstring.
        conv = partial(
            tf.keras.layers.Conv2D,
            kernel_size=[3, 3], use_bias=use_bias, padding="same")
        batch_norm = tf.keras.layers.BatchNormalization
        self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
        self._dense_1 = tf.keras.layers.Dense(100)
        self._dense_2 = tf.keras.layers.Dense(output_shape)
        self._conv_layers = OrderedDict()
        self._batch_norm = OrderedDict()
        self._conv_layers["conv_0_0"] = conv(filters=num_features, strides=1)
        self._conv_layers["conv_0_1"] = conv(filters=num_features, strides=2)
        self._batch_norm["bn_0_1"] = batch_norm()
        for i in range(1, 4):
            for j in range(1, 3):
                # j == 1 -> stride 1; j == 2 -> stride 2 (downsampling).
                self._conv_layers["conv_%d_%d" % (i, j)] = conv(
                    filters=num_features * (2**i), strides=j)
                self._batch_norm["bn_%d_%d" % (i, j)] = batch_norm()
        # Built once here instead of constructing a new layer on every
        # forward pass; it holds no weights.
        self._flatten = tf.keras.layers.Flatten()

    def call(self, inputs, training=None):
        """Run the discriminator on ``inputs`` (delegates to ``unsigned_call``).

        ``training`` is forwarded to the batch-normalization layers; when it
        is None they fall back to Keras' default training-mode resolution.
        """
        return self.unsigned_call(inputs, training=training)

    def unsigned_call(self, input_, training=None):
        """Forward pass of the discriminator.

        Args:
            input_: image batch; presumably an NHWC tensor (TODO confirm
                against callers).
            training: optional bool forwarded to every BatchNormalization
                layer; None keeps Keras' default behaviour.

        Returns:
            Tensor of shape (batch, output_shape) from the final Dense layer,
            with no activation applied.
        """
        features = self._lrelu(self._conv_layers["conv_0_0"](input_))
        features = self._lrelu(
            self._batch_norm["bn_0_1"](
                self._conv_layers["conv_0_1"](features), training=training))
        # VGG trunk: three stages of (stride-1 conv, stride-2 conv).
        for i in range(1, 4):
            for j in range(1, 3):
                features = self._lrelu(
                    self._batch_norm["bn_%d_%d" % (i, j)](
                        self._conv_layers["conv_%d_%d" % (i, j)](features),
                        training=training))
        flattened = self._flatten(features)
        dense = self._lrelu(self._dense_1(flattened))
        out = self._dense_2(dense)
        return out