diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index a09007e109548a29d22568ba020ba82508d80cc1..f665a5cf4980e3732eca7b8f0a48a8660859c0f3 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -210,12 +210,14 @@ class ESPCN: self.n_chans = 1 - self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) + #self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) # self.X_img = tf.keras.Input(shape=(36, 36, self.n_chans)) + self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans)) self.inputs.append(self.X_img) - self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans))) + #self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans))) # self.inputs.append(tf.keras.Input(shape=(36, 36, self.n_chans))) + self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans))) self.DISK_CACHE = False @@ -420,20 +422,19 @@ class ESPCN: print('input: ', input_2d.shape) # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding='VALID', activation=None)(input_2d) conv = input_2d - print('Contracting Branch') print('input: ', conv.shape) skip = conv if NOISE_TRAINING: conv = tf.keras.layers.GaussianNoise(stddev=NOISE_STDDEV)(conv) - conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.BatchNormalization()(conv) print(conv.shape) - # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) - # conv = tf.keras.layers.BatchNormalization()(conv) - # print(conv.shape) + conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + conv = tf.keras.layers.BatchNormalization()(conv) + print(conv.shape) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(conv) conv = tf.keras.layers.BatchNormalization()(conv) @@ -447,10 +448,15 @@ class ESPCN: conv = tf.keras.layers.BatchNormalization()(conv) print(conv.shape) + conv = tf.keras.layers.Conv2D(num_filters/2, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + conv = tf.keras.layers.BatchNormalization()(conv) + print(conv.shape) + conv = tf.keras.layers.Conv2D(4, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) print(conv.shape) conv = tf.nn.depth_to_space(conv, block_size=2) + conv = tf.keras.layers.Activation(activation=activation)(conv) print(conv.shape) if NumClasses == 2: