from collections import OrderedDict
from functools import partial
import tensorflow as tf
from util import esrgan_utils as utils

""" Keras Models for ESRGAN
Classes:
    RRDBNet: Generator of ESRGAN (Residual in Residual network).
    VGGArch: VGG28 architecture forming the discriminator of ESRGAN.
"""


class RRDBNet(tf.keras.Model):
    """ Residual in Residual Network consisting of:
    - Convolution layers
    - Residual in Residual Blocks forming the trunk of the model
    - Upsampling transposed-convolution layers (tf.keras.layers.Conv2DTranspose)
    - Final convolution layers
    Args:
        out_channel: number of channels of the fake output image.
        num_features (default: 32): number of filters to use in the convolutional layers.
        trunk_size (default: 11): number of Residual in Residual Blocks forming the trunk.
        growth_channel (default: 32): number of filters to use in the internal convolutional layers.
        use_bias (default: True): boolean indicating whether to use bias in the conv layers.
        first_call (default: True): forwarded to the RRDB blocks from util.esrgan_utils.
    """
    def __init__(
            self,
            out_channel,
            num_features=32,
            trunk_size=11,
            growth_channel=32,
            use_bias=True,
            first_call=True):
        super(RRDBNet, self).__init__()
        self.rrdb_block = partial(utils.RRDB, growth_channel, first_call=first_call)
        conv = partial(
            tf.keras.layers.Conv2D,
            kernel_size=[3, 3],
            strides=[1, 1],
            padding="same",
            use_bias=use_bias)
        conv_transpose = partial(
            tf.keras.layers.Conv2DTranspose,
            kernel_size=[3, 3],
            strides=[2, 2],
            padding="same",
            use_bias=use_bias)
        self.conv_first = conv(filters=num_features)
        self.rdb_trunk = tf.keras.Sequential(
            [self.rrdb_block() for _ in range(trunk_size)])
        self.conv_trunk = conv(filters=num_features)
        # Upsample: each stride-2 transposed convolution doubles the spatial size (4x total).
        self.upsample1 = conv_transpose(num_features)
        self.upsample2 = conv_transpose(num_features)
        self.conv_last_1 = conv(num_features)
        self.conv_last_2 = conv(out_channel)
        self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
    # @tf.function(
    #     input_signature=[
    #         tf.TensorSpec(shape=[None, None, None, 3],
    #                       dtype=tf.float32)])
    def call(self, inputs):
        return self.unsigned_call(inputs)

    def unsigned_call(self, input_):
        feature = self.lrelu(self.conv_first(input_))
        trunk = self.conv_trunk(self.rdb_trunk(feature))
        feature = trunk + feature
        feature = self.lrelu(self.upsample1(feature))
        feature = self.lrelu(self.upsample2(feature))
        feature = self.lrelu(self.conv_last_1(feature))
        out = self.conv_last_2(feature)
        return out
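

# Illustrative usage sketch (not part of the original module). It assumes the
# RRDB blocks from util.esrgan_utils preserve the spatial size and the
# num_features channel count, so the two stride-2 Conv2DTranspose layers yield
# a 4x super-resolved output. The helper name below is hypothetical.
def _demo_rrdbnet():
    generator = RRDBNet(out_channel=3)
    lr_batch = tf.random.uniform([2, 32, 32, 3])  # dummy low-resolution batch
    sr_batch = generator(lr_batch)                # expected shape: (2, 128, 128, 3)
    return sr_batch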


class VGGArch(tf.keras.Model):
    """ Keras model for the VGG28 architecture used as the
    discriminator of ESRGAN.
    Args:
        batch_size (default: 8): currently unused by the model.
        output_shape (default: 1): number of units in the final dense layer (the discriminator score).
        num_features (default: 64): base number of filters for the convolutional layers;
            multiplied by a factor of 2**i at successive downsampling stages.
        use_bias (default: False): boolean indicating whether to use bias in the convolution layers.
    """
    def __init__(
            self,
            batch_size=8,
            output_shape=1,
            num_features=64,
            use_bias=False):
        super(VGGArch, self).__init__()
        conv = partial(
            tf.keras.layers.Conv2D,
            kernel_size=[3, 3], use_bias=use_bias, padding="same")
        batch_norm = partial(tf.keras.layers.BatchNormalization)

        def no_batch_norm(x):  # identity stand-in for batch norm; currently unused
            return x

        self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
        self._dense_1 = tf.keras.layers.Dense(100)
        self._dense_2 = tf.keras.layers.Dense(output_shape)
        self._conv_layers = OrderedDict()
        self._batch_norm = OrderedDict()
        self._conv_layers["conv_0_0"] = conv(filters=num_features, strides=1)
        self._conv_layers["conv_0_1"] = conv(filters=num_features, strides=2)
        self._batch_norm["bn_0_1"] = batch_norm()
        # Three stages of a stride-1 conv followed by a stride-2 conv, doubling the
        # number of filters at each stage (128, 256, 512 for num_features=64).
        for i in range(1, 4):
            for j in range(1, 3):
                self._conv_layers["conv_%d_%d" % (i, j)] = conv(
                    filters=num_features * (2**i), strides=j)
                self._batch_norm["bn_%d_%d" % (i, j)] = batch_norm()
    def call(self, inputs):
        return self.unsigned_call(inputs)

    def unsigned_call(self, input_):
        features = self._lrelu(self._conv_layers["conv_0_0"](input_))
        features = self._lrelu(
            self._batch_norm["bn_0_1"](
                self._conv_layers["conv_0_1"](features)))
        # VGG trunk
        for i in range(1, 4):
            for j in range(1, 3):
                features = self._lrelu(
                    self._batch_norm["bn_%d_%d" % (i, j)](
                        self._conv_layers["conv_%d_%d" % (i, j)](features)))
        flattened = tf.keras.layers.Flatten()(features)
        dense = self._lrelu(self._dense_1(flattened))
        out = self._dense_2(dense)
        return out
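

# Illustrative usage sketch (not part of the original module). The discriminator
# halves the spatial size four times (one stride-2 conv per stage), so a
# 128x128 input is flattened from an 8x8x512 feature map before the dense
# layers. The helper name below is hypothetical.
def _demo_vggarch():
    discriminator = VGGArch(output_shape=1, num_features=64)
    hr_batch = tf.random.uniform([2, 128, 128, 3])  # dummy high-resolution batch
    logits = discriminator(hr_batch)                # shape (2, 1): realness score per image
    return logits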