Skip to content
Snippets Groups Projects
Commit 71200f13 authored by tomrink's avatar tomrink
Browse files

minor

parent 1cd772ba
No related branches found
No related tags found
No related merge requests found
...@@ -533,7 +533,8 @@ class UNET: ...@@ -533,7 +533,8 @@ class UNET:
input_2d = self.inputs[0] input_2d = self.inputs[0]
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d)
print(conv.shape) print('Contracting Branch')
print('input: ', conv.shape)
skip = conv skip = conv
if NOISE_TRAINING: if NOISE_TRAINING:
...@@ -545,7 +546,6 @@ class UNET: ...@@ -545,7 +546,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv) conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip) skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
...@@ -553,7 +553,7 @@ class UNET: ...@@ -553,7 +553,7 @@ class UNET:
conv = conv + skip conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv) conv = tf.keras.layers.LeakyReLU()(conv)
print(conv.shape) print('1d: ', conv.shape)
# ----------------------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------------------
conv_2 = conv conv_2 = conv
...@@ -562,7 +562,6 @@ class UNET: ...@@ -562,7 +562,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv) conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip) skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
...@@ -570,7 +569,7 @@ class UNET: ...@@ -570,7 +569,7 @@ class UNET:
conv = conv + skip conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv) conv = tf.keras.layers.LeakyReLU()(conv)
print(conv.shape) print('2d: ', conv.shape)
# ---------------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------------
conv_3 = conv conv_3 = conv
...@@ -579,7 +578,6 @@ class UNET: ...@@ -579,7 +578,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv) conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip) skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
...@@ -587,6 +585,7 @@ class UNET: ...@@ -587,6 +585,7 @@ class UNET:
conv = conv + skip conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv) conv = tf.keras.layers.LeakyReLU()(conv)
print('3d: ', conv.shape)
# ----------------------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------------------
conv_4 = conv conv_4 = conv
...@@ -595,7 +594,6 @@ class UNET: ...@@ -595,7 +594,6 @@ class UNET:
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.MaxPool2D(padding=padding)(conv) conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
skip = tf.keras.layers.MaxPool2D(padding=padding)(skip) skip = tf.keras.layers.MaxPool2D(padding=padding)(skip)
...@@ -603,6 +601,7 @@ class UNET: ...@@ -603,6 +601,7 @@ class UNET:
conv = conv + skip conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv) conv = tf.keras.layers.LeakyReLU()(conv)
print('4d: ', conv.shape)
# Expanding (Decoding) branch ------------------------------------------------------------------------------- # Expanding (Decoding) branch -------------------------------------------------------------------------------
print('expanding branch') print('expanding branch')
...@@ -612,26 +611,25 @@ class UNET: ...@@ -612,26 +611,25 @@ class UNET:
conv = tf.keras.layers.concatenate([conv, conv_4]) conv = tf.keras.layers.concatenate([conv, conv_4])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape) print('5: ', conv.shape)
num_filters /= 2 num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_3]) conv = tf.keras.layers.concatenate([conv, conv_3])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape) print('6: ', conv.shape)
num_filters /= 2 num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_2]) conv = tf.keras.layers.concatenate([conv, conv_2])
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv) conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape) print('7: ', conv.shape)
num_filters /= 2 num_filters /= 2
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
conv = tf.keras.layers.concatenate([conv, conv_1]) print('8: ', conv.shape)
print(conv.shape)
if NumClasses == 2: if NumClasses == 2:
activation = tf.nn.sigmoid # For binary activation = tf.nn.sigmoid # For binary
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment