From b6e20077fb9b5409e570fc50109f4e9c4e923c7c Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 29 Jan 2024 16:53:04 -0600 Subject: [PATCH] snapshot... --- modules/deeplearning/esrgan.py | 136 +++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 modules/deeplearning/esrgan.py diff --git a/modules/deeplearning/esrgan.py b/modules/deeplearning/esrgan.py new file mode 100644 index 00000000..808503c1 --- /dev/null +++ b/modules/deeplearning/esrgan.py @@ -0,0 +1,136 @@ +from collections import OrderedDict +from functools import partial +import tensorflow as tf +from util import esrgan_utils as utils + +""" Keras Models for ESRGAN + Classes: + RRDBNet: Generator of ESRGAN. (Residual in Residual Network) + VGGArch: VGG28 Architecture making the Discriminator ESRGAN +""" + + +class RRDBNet(tf.keras.Model): + """ Residual in Residual Network consisting of: + - Convolution Layers + - Residual in Residual Block as the trunk of the model + - Pixel Shuffler layers (tf.nn.depth_to_space) + - Upscaling Convolutional Layers + + Args: + out_channel: number of channels of the fake output image. + num_features (default: 32): number of filters to use in the convolutional layers. + trunk_size (default: 3): number of Residual in Residual Blocks to form the trunk. + growth_channel (default: 32): number of filters to use in the internal convolutional layers. + use_bias (default: True): boolean to indicate if bias is to be used in the conv layers. + """ + + def __init__( + self, + out_channel, + num_features=32, + trunk_size=11, + growth_channel=32, + use_bias=True, + first_call=True): + super(RRDBNet, self).__init__() + self.rrdb_block = partial(utils.RRDB, growth_channel, first_call=first_call) + conv = partial( + tf.keras.layers.Conv2D, + kernel_size=[3, 3], + strides=[1, 1], + padding="same", + use_bias=use_bias) + conv_transpose = partial( + tf.keras.layers.Conv2DTranspose, + kernel_size=[3, 3], + strides=[2, 2], + padding="same", + use_bias=use_bias) + self.conv_first = conv(filters=num_features) + self.rdb_trunk = tf.keras.Sequential( + [self.rrdb_block() for _ in range(trunk_size)]) + self.conv_trunk = conv(filters=num_features) + # Upsample + self.upsample1 = conv_transpose(num_features) + self.upsample2 = conv_transpose(num_features) + self.conv_last_1 = conv(num_features) + self.conv_last_2 = conv(out_channel) + self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2) + + # @tf.function( + # input_signature=[ + # tf.TensorSpec(shape=[None, None, None, 3], + # dtype=tf.float32)]) + def call(self, inputs): + return self.unsigned_call(inputs) + + def unsigned_call(self, input_): + feature = self.lrelu(self.conv_first(input_)) + trunk = self.conv_trunk(self.rdb_trunk(feature)) + feature = trunk + feature + feature = self.lrelu( + self.upsample1(feature)) + feature = self.lrelu( + self.upsample2(feature)) + feature = self.lrelu(self.conv_last_1(feature)) + out = self.conv_last_2(feature) + return out + + +class VGGArch(tf.keras.Model): + """ Keras Model for VGG28 Architecture needed to form + the discriminator of the architecture. + Args: + output_shape (default: 1): output_shape of the generator + num_features (default: 64): number of features to be used in the convolutional layers + a factor of 2**i will be multiplied as per the need + use_bias (default: True): Boolean to indicate whether to use biases for convolution layers + """ + + def __init__( + self, + batch_size=8, + output_shape=1, + num_features=64, + use_bias=False): + + super(VGGArch, self).__init__() + conv = partial( + tf.keras.layers.Conv2D, + kernel_size=[3, 3], use_bias=use_bias, padding="same") + batch_norm = partial(tf.keras.layers.BatchNormalization) + def no_batch_norm(x): return x + self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2) + self._dense_1 = tf.keras.layers.Dense(100) + self._dense_2 = tf.keras.layers.Dense(output_shape) + self._conv_layers = OrderedDict() + self._batch_norm = OrderedDict() + self._conv_layers["conv_0_0"] = conv(filters=num_features, strides=1) + self._conv_layers["conv_0_1"] = conv(filters=num_features, strides=2) + self._batch_norm["bn_0_1"] = batch_norm() + for i in range(1, 4): + for j in range(1, 3): + self._conv_layers["conv_%d_%d" % (i, j)] = conv( + filters=num_features * (2**i), strides=j) + self._batch_norm["bn_%d_%d" % (i, j)] = batch_norm() + + def call(self, inputs): + return self.unsigned_call(inputs) + def unsigned_call(self, input_): + + features = self._lrelu(self._conv_layers["conv_0_0"](input_)) + features = self._lrelu( + self._batch_norm["bn_0_1"]( + self._conv_layers["conv_0_1"](features))) + # VGG Trunk + for i in range(1, 4): + for j in range(1, 3): + features = self._lrelu( + self._batch_norm["bn_%d_%d" % (i, j)]( + self._conv_layers["conv_%d_%d" % (i, j)](features))) + + flattened = tf.keras.layers.Flatten()(features) + dense = self._lrelu(self._dense_1(flattened)) + out = self._dense_2(dense) + return out -- GitLab