diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 1044e416ec9d66b4a44d2be94f237bc85c10fa9b..5bd6ebeaaae3c350d664087953c9e2445641a487 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -376,11 +376,13 @@ class UNET: activation = tf.nn.leaky_relu momentum = 0.99 - num_filters = self.n_chans * 4 + num_filters = self.n_chans * 8 input_2d = self.inputs[0] - conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d) - print('Contracting Branch') + print('input: ', input_2d.shape) + conv = tf.keras.layers.Conv2D(num_filters, kernel_size=7, strides=1, padding='VALID', activation=None)(input_2d) + conv = conv[:, 6:70, 6:70, :] + print('Contracting Branch -----------') print('input: ', conv.shape) skip = conv @@ -451,7 +453,7 @@ class UNET: print('4d: ', conv.shape) # Expanding (Decoding) branch ------------------------------------------------------------------------------- - print('expanding branch') + print('expanding branch --------------') num_filters /= 2 conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) @@ -834,8 +836,8 @@ class UNET: # f.close() def build_model(self): - # self.build_unet() - self.build_upsample() + self.build_unet() + # self.build_upsample() self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir):