From b6e20077fb9b5409e570fc50109f4e9c4e923c7c Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 29 Jan 2024 16:53:04 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/esrgan.py | 136 +++++++++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)
 create mode 100644 modules/deeplearning/esrgan.py

diff --git a/modules/deeplearning/esrgan.py b/modules/deeplearning/esrgan.py
new file mode 100644
index 00000000..808503c1
--- /dev/null
+++ b/modules/deeplearning/esrgan.py
@@ -0,0 +1,136 @@
+from collections import OrderedDict
+from functools import partial
+import tensorflow as tf
+from util import esrgan_utils as utils
+
+""" Keras Models for ESRGAN
+    Classes:
+      RRDBNet: Generator of ESRGAN. (Residual in Residual Network)
+      VGGArch: VGG28 Architecture making the Discriminator ESRGAN
+"""
+
+
+class RRDBNet(tf.keras.Model):
+  """ Residual in Residual Network consisting of:
+      - Convolution Layers
+      - Residual in Residual Block as the trunk of the model
+      - Pixel Shuffler layers (tf.nn.depth_to_space)
+      - Upscaling Convolutional Layers
+
+      Args:
+        out_channel: number of channels of the fake output image.
+        num_features (default: 32): number of filters to use in the convolutional layers.
+        trunk_size (default: 3): number of Residual in Residual Blocks to form the trunk.
+        growth_channel (default: 32): number of filters to use in the internal convolutional layers.
+        use_bias (default: True): boolean to indicate if bias is to be used in the conv layers.
+  """
+
+  def __init__(
+          self,
+          out_channel,
+          num_features=32,
+          trunk_size=11,
+          growth_channel=32,
+          use_bias=True,
+          first_call=True):
+    super(RRDBNet, self).__init__()
+    self.rrdb_block = partial(utils.RRDB, growth_channel, first_call=first_call)
+    conv = partial(
+        tf.keras.layers.Conv2D,
+        kernel_size=[3, 3],
+        strides=[1, 1],
+        padding="same",
+        use_bias=use_bias)
+    conv_transpose = partial(
+        tf.keras.layers.Conv2DTranspose,
+        kernel_size=[3, 3],
+        strides=[2, 2],
+        padding="same",
+        use_bias=use_bias)
+    self.conv_first = conv(filters=num_features)
+    self.rdb_trunk = tf.keras.Sequential(
+        [self.rrdb_block() for _ in range(trunk_size)])
+    self.conv_trunk = conv(filters=num_features)
+    # Upsample
+    self.upsample1 = conv_transpose(num_features)
+    self.upsample2 = conv_transpose(num_features)
+    self.conv_last_1 = conv(num_features)
+    self.conv_last_2 = conv(out_channel)
+    self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
+
+  # @tf.function(
+  #    input_signature=[
+  #        tf.TensorSpec(shape=[None, None, None, 3],
+  #                      dtype=tf.float32)])
+  def call(self, inputs):
+    return self.unsigned_call(inputs)
+
+  def unsigned_call(self, input_):
+    feature = self.lrelu(self.conv_first(input_))
+    trunk = self.conv_trunk(self.rdb_trunk(feature))
+    feature = trunk + feature
+    feature = self.lrelu(
+            self.upsample1(feature))
+    feature = self.lrelu(
+          self.upsample2(feature))
+    feature = self.lrelu(self.conv_last_1(feature))
+    out = self.conv_last_2(feature)
+    return out
+
+
+class VGGArch(tf.keras.Model):
+  """ Keras Model for VGG28 Architecture needed to form
+      the discriminator of the architecture.
+      Args:
+        output_shape (default: 1): output_shape of the generator
+        num_features (default: 64): number of features to be used in the convolutional layers
+                                    a factor of 2**i will be multiplied as per the need
+        use_bias (default: True): Boolean to indicate whether to use biases for convolution layers
+  """
+
+  def __init__(
+          self,
+          batch_size=8,
+          output_shape=1,
+          num_features=64,
+          use_bias=False):
+
+    super(VGGArch, self).__init__()
+    conv = partial(
+        tf.keras.layers.Conv2D,
+        kernel_size=[3, 3], use_bias=use_bias, padding="same")
+    batch_norm = partial(tf.keras.layers.BatchNormalization)
+    def no_batch_norm(x): return x
+    self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
+    self._dense_1 = tf.keras.layers.Dense(100)
+    self._dense_2 = tf.keras.layers.Dense(output_shape)
+    self._conv_layers = OrderedDict()
+    self._batch_norm = OrderedDict()
+    self._conv_layers["conv_0_0"] = conv(filters=num_features, strides=1)
+    self._conv_layers["conv_0_1"] = conv(filters=num_features, strides=2)
+    self._batch_norm["bn_0_1"] = batch_norm()
+    for i in range(1, 4):
+      for j in range(1, 3):
+        self._conv_layers["conv_%d_%d" % (i, j)] = conv(
+            filters=num_features * (2**i), strides=j)
+        self._batch_norm["bn_%d_%d" % (i, j)] = batch_norm()
+
+  def call(self, inputs):
+    return self.unsigned_call(inputs)
+  def unsigned_call(self, input_):
+
+    features = self._lrelu(self._conv_layers["conv_0_0"](input_))
+    features = self._lrelu(
+        self._batch_norm["bn_0_1"](
+            self._conv_layers["conv_0_1"](features)))
+    # VGG Trunk
+    for i in range(1, 4):
+      for j in range(1, 3):
+        features = self._lrelu(
+            self._batch_norm["bn_%d_%d" % (i, j)](
+                self._conv_layers["conv_%d_%d" % (i, j)](features)))
+
+    flattened = tf.keras.layers.Flatten()(features)
+    dense = self._lrelu(self._dense_1(flattened))
+    out = self._dense_2(dense)
+    return out
-- 
GitLab