diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index 3b6c3589f32fd77b3b0990360157545f323826c8..f372878fb1729d6aba08a992231753985ac22724 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -210,9 +210,9 @@ class ESPCN: self.n_chans = 1 - self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) + # self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) # self.X_img = tf.keras.Input(shape=(36, 36, self.n_chans)) - # self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans)) + self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans)) self.inputs.append(self.X_img) @@ -414,20 +414,25 @@ class ESPCN: conv = tf.keras.layers.BatchNormalization()(conv) print(conv.shape) - conv = tf.keras.layers.Conv2D(4, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) - print(conv.shape) + # conv = tf.keras.layers.Conv2D(4, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + # print(conv.shape) - conv = tf.nn.depth_to_space(conv, block_size=2) - conv = tf.keras.layers.Activation(activation=activation)(conv) + conv = tf.keras.layers.Conv2DTranspose(4, kernel_size=3, strides=2, padding=padding, activation=activation)(conv) print(conv.shape) - if NumClasses == 2: - activation = tf.nn.sigmoid # For binary - else: - activation = tf.nn.softmax # For multi-class + self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=tf.nn.sigmoid)(conv) - # Called logits, but these are actually probabilities, see activation - self.logits = tf.keras.layers.Activation(activation=activation)(conv) + # conv = tf.nn.depth_to_space(conv, block_size=2) + # conv = tf.keras.layers.Activation(activation=activation)(conv) + # print(conv.shape) + # + # if NumClasses == 2: + # activation = tf.nn.sigmoid # For binary + # else: + # activation = tf.nn.softmax # For multi-class + # + # # Called logits, but these are actually probabilities, see activation + # self.logits = tf.keras.layers.Activation(activation=activation)(conv) print(self.logits.shape)