from collections import OrderedDict
from functools import partial
import tensorflow as tf
from util import esrgan_utils as utils

""" Keras Models for ESRGAN
    Classes:
      RRDBNet: Generator of ESRGAN (Residual in Residual Network).
      VGGArch: VGG28 architecture that forms the Discriminator of ESRGAN.
"""


class RRDBNet(tf.keras.Model):
  """ Generator of ESRGAN, built around a Residual in Residual trunk.

      Layers, in the order they are applied:
      - An input convolution followed by LeakyReLU
      - A trunk of `utils.RRDB` blocks followed by a convolution; the result
        is added back to the output of the input convolution
      - Two stride-2 transposed convolutions, each followed by LeakyReLU
      - Two output convolutions, the first one followed by LeakyReLU

      Args:
        out_channel: number of channels of the fake output image.
        num_features (default: 32): number of filters of every convolution
                                    except the last one.
        trunk_size (default: 11): number of Residual in Residual Blocks to form the trunk.
        growth_channel (default: 32): passed as the first positional argument to
                                      every `utils.RRDB` block (number of filters
                                      to use in the internal convolutional layers).
        use_bias (default: True): boolean to indicate if bias is to be used in the conv layers.
        first_call (default: True): forwarded unchanged to every `utils.RRDB` block.
  """

  def __init__(
          self,
          out_channel,
          num_features=32,
          trunk_size=11,
          growth_channel=32,
          use_bias=True,
          first_call=True):
    super(RRDBNet, self).__init__()
    # Factory for trunk blocks; each call builds a fresh utils.RRDB.
    self.rrdb_block = partial(
        utils.RRDB, growth_channel, first_call=first_call)

    # Options common to every (transposed) convolution of the generator.
    shared_options = dict(
        kernel_size=[3, 3],
        padding="same",
        use_bias=use_bias)
    make_conv = partial(
        tf.keras.layers.Conv2D, strides=[1, 1], **shared_options)
    make_upscale = partial(
        tf.keras.layers.Conv2DTranspose, strides=[2, 2], **shared_options)

    # NOTE: layers are created in the same order as they were historically;
    # Keras auto-generated layer names depend on creation order.
    self.conv_first = make_conv(filters=num_features)
    trunk_blocks = [self.rrdb_block() for _ in range(trunk_size)]
    self.rdb_trunk = tf.keras.Sequential(trunk_blocks)
    self.conv_trunk = make_conv(filters=num_features)
    # Upsample
    self.upsample1 = make_upscale(filters=num_features)
    self.upsample2 = make_upscale(filters=num_features)
    self.conv_last_1 = make_conv(filters=num_features)
    self.conv_last_2 = make_conv(filters=out_channel)
    self.lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)

  # Disabled input signature, kept for reference:
  # @tf.function(
  #    input_signature=[
  #        tf.TensorSpec(shape=[None, None, None, 3],
  #                      dtype=tf.float32)])
  def call(self, inputs):
    """ Keras entry point; delegates to `unsigned_call`. """
    return self.unsigned_call(inputs)

  def unsigned_call(self, input_):
    """ Runs the generator on `input_` and returns the output of the
        last convolution (no activation is applied to it).
    """
    shallow = self.lrelu(self.conv_first(input_))
    deep = self.conv_trunk(self.rdb_trunk(shallow))
    # Skip connection around the whole trunk.
    upscaled = deep + shallow
    for upsample in (self.upsample1, self.upsample2):
      upscaled = self.lrelu(upsample(upscaled))
    refined = self.lrelu(self.conv_last_1(upscaled))
    return self.conv_last_2(refined)


class VGGArch(tf.keras.Model):
  """ Keras Model for the VGG28 Architecture needed to form
      the discriminator of ESRGAN.

      Layers, in the order they are applied:
      - conv_0_0 (stride 1) followed by LeakyReLU
      - conv_0_1 (stride 2), batch normalization, LeakyReLU
      - for i in 1..3: conv_i_1 (stride 1) and conv_i_2 (stride 2), each with
        num_features * 2**i filters, batch normalization and LeakyReLU
      - Flatten, Dense(100) with LeakyReLU, Dense(output_shape)

      Args:
        batch_size (default: 8): not read by this class; accepted so that
                                 callers passing it keep working.
        output_shape (default: 1): number of units of the final dense layer,
                                   i.e. the size of the discriminator output.
        num_features (default: 64): number of features to be used in the convolutional layers
                                    a factor of 2**i will be multiplied as per the need
        use_bias (default: False): Boolean to indicate whether to use biases for convolution layers
  """

  def __init__(
          self,
          batch_size=8,
          output_shape=1,
          num_features=64,
          use_bias=False):

    super(VGGArch, self).__init__()
    conv = partial(
        tf.keras.layers.Conv2D,
        kernel_size=[3, 3], use_bias=use_bias, padding="same")
    batch_norm = tf.keras.layers.BatchNormalization
    self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
    self._dense_1 = tf.keras.layers.Dense(100)
    self._dense_2 = tf.keras.layers.Dense(output_shape)
    # Layers are stored by name ("conv_<i>_<j>" / "bn_<i>_<j>") and looked up
    # with the same keys in unsigned_call.
    self._conv_layers = OrderedDict()
    self._batch_norm = OrderedDict()
    # Block 0: only the second convolution is batch-normalized.
    self._conv_layers["conv_0_0"] = conv(filters=num_features, strides=1)
    self._conv_layers["conv_0_1"] = conv(filters=num_features, strides=2)
    self._batch_norm["bn_0_1"] = batch_norm()
    for i in range(1, 4):
      for j in range(1, 3):
        # j doubles as the stride: the first conv of a block keeps the
        # resolution (stride 1), the second one uses stride 2.
        self._conv_layers["conv_%d_%d" % (i, j)] = conv(
            filters=num_features * (2**i), strides=j)
        self._batch_norm["bn_%d_%d" % (i, j)] = batch_norm()

  def call(self, inputs):
    """ Keras entry point; delegates to `unsigned_call`. """
    return self.unsigned_call(inputs)

  def unsigned_call(self, input_):
    """ Runs the discriminator on `input_` and returns the output of the
        final dense layer (no activation is applied to it).
    """
    features = self._lrelu(self._conv_layers["conv_0_0"](input_))
    features = self._lrelu(
        self._batch_norm["bn_0_1"](
            self._conv_layers["conv_0_1"](features)))
    # VGG Trunk
    for i in range(1, 4):
      for j in range(1, 3):
        features = self._lrelu(
            self._batch_norm["bn_%d_%d" % (i, j)](
                self._conv_layers["conv_%d_%d" % (i, j)](features)))

    # Flatten holds no weights; it is intentionally created here rather than
    # stored as an attribute so the set of tracked layers stays unchanged.
    flattened = tf.keras.layers.Flatten()(features)
    dense = self._lrelu(self._dense_1(flattened))
    out = self._dense_2(dense)
    return out