Skip to content
Snippets Groups Projects
Commit 71200f13 authored by tomrink's avatar tomrink
Browse files

minor

parent 1cd772ba
Branches
No related tags found
No related merge requests found
......@@ -533,7 +533,8 @@ class UNET:
input_2d = self.inputs[0]
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d)
print(conv.shape)
print('Contracting Branch')
print('input: ', conv.shape)
skip = conv
if NOISE_TRAINING:
......@@ -545,7 +546,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
......@@ -553,7 +553,7 @@ class UNET:
conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv)
print(conv.shape)
print('1d: ', conv.shape)
# -----------------------------------------------------------------------------------------------------------
conv_2 = conv
......@@ -562,7 +562,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
......@@ -570,7 +569,7 @@ class UNET:
conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv)
print(conv.shape)
print('2d: ', conv.shape)
# ----------------------------------------------------------------------------------------------------------
conv_3 = conv
......@@ -579,7 +578,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
......@@ -587,6 +585,7 @@ class UNET:
conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv)
print('3d: ', conv.shape)
# -----------------------------------------------------------------------------------------------------------
conv_4 = conv
......@@ -595,7 +594,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
......@@ -603,6 +601,7 @@ class UNET:
conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv)
print('4d: ', conv.shape)
# Expanding (Decoding) branch -------------------------------------------------------------------------------
print('expanding branch')
......@@ -612,26 +611,25 @@ class UNET:
conv = tf.keras.layers.concatenate([conv, conv_4])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
print('5: ', conv.shape)
num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_3])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
print('6: ', conv.shape)
num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_2])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
print('7: ', conv.shape)
num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_1])
print(conv.shape)
print('8: ', conv.shape)
if NumClasses == 2:
activation = tf.nn.sigmoid # For binary
......
Loading...
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment