Commit 1ed2d21e authored by tomrink

snapshot...

parent 11bd6cf9
Tags: v0.2.0
@@ -14,7 +14,7 @@ import h5py
 LOG_DEVICE_PLACEMENT = False
-PROC_BATCH_SIZE = 50
+PROC_BATCH_SIZE = 10
 PROC_BATCH_BUFFER_SIZE = 50000
 NumClasses = 2
@@ -189,9 +189,11 @@ class UNET:
             label_s.append(nda)

         data = np.concatenate(data_s)
+        data = data[:, 0, :, :]
         data = np.expand_dims(data, axis=3)
         label = np.concatenate(label_s)
+        label = label[:, 0, :, :]
         label = np.expand_dims(label, axis=3)

         data = data.astype(np.float32)
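The two added lines select a single slice along the second axis before the trailing channel dimension is appended. A minimal NumPy sketch of the resulting shape change, assuming each concatenated batch arrives as (N, K, H, W); the actual K and tile size are not shown in the diff:

```python
import numpy as np

# Hypothetical batch: 50 samples, 3 slices per sample, 128x128 tiles (assumed sizes).
data = np.random.rand(50, 3, 128, 128).astype(np.float32)

data = data[:, 0, :, :]              # keep only the first slice along axis 1 -> (50, 128, 128)
data = np.expand_dims(data, axis=3)  # add a trailing channel axis           -> (50, 128, 128, 1)

print(data.shape)  # (50, 128, 128, 1)
```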
@@ -490,6 +492,68 @@ class UNET:
         print(self.logits.shape)

+    def build_upsample(self):
+        print('build_upsample')
+
+        # padding = "VALID"
+        padding = "SAME"
+
+        # activation = tf.nn.relu
+        # activation = tf.nn.elu
+        activation = tf.nn.leaky_relu
+
+        num_filters = self.n_chans * 8
+
+        input_2d = self.inputs[0]
+        print('input: ', input_2d.shape)
+
+        # Expanding (Decoding) branch -------------------------------------------------------------------------------
+        print('expanding branch')
+
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=7, strides=2, padding=padding, activation=activation)(input_2d)
+        conv = tf.keras.layers.BatchNormalization()(conv)
+        print(conv.shape)
+
+        conv = conv[:, 18:146, 18:146, :]
+
+        num_filters /= 2
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv)
+        print(conv.shape)
+
+        num_filters /= 2
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv)
+        print(conv.shape)
+
+        # if NumClasses == 2:
+        #     activation = tf.nn.sigmoid  # For binary
+        # else:
+        #     activation = tf.nn.softmax  # For multi-class
+        activation = tf.nn.sigmoid
+
+        # Called logits, but these are actually probabilities, see activation
+        self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv)
+        print(self.logits.shape)
+
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # conv = tf.keras.layers.concatenate([conv, conv_3])
+        # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
+        # conv = tf.keras.layers.BatchNormalization()(conv)
+        # print('6: ', conv.shape)
+        #
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # conv = tf.keras.layers.concatenate([conv, conv_2])
+        # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
+        # conv = tf.keras.layers.BatchNormalization()(conv)
+        # print('7: ', conv.shape)
+        #
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # print('8: ', conv.shape)
+        #
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # print('9: ', conv.shape)
+
     def build_training(self):
         # if NumClasses == 2:
         #     self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
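The new build_upsample path doubles the spatial size with a single stride-2 transposed convolution, crops a fixed 128x128 window, and finishes with 1x1 convolutions and a sigmoid head. A short sketch of that size arithmetic, assuming a hypothetical 73x73 single-channel input and n_chans = 8; neither value appears in the diff:

```python
import tensorflow as tf

n_chans = 8                                   # assumed
inp = tf.keras.Input(shape=(73, 73, 1))       # assumed input size, chosen so the crop fits

# A stride-2 transposed conv with SAME padding doubles the spatial dims: 73 -> 146
x = tf.keras.layers.Conv2DTranspose(n_chans * 8, kernel_size=7, strides=2,
                                    padding='same', activation=tf.nn.leaky_relu)(inp)
x = tf.keras.layers.BatchNormalization()(x)
print(x.shape)  # (None, 146, 146, 64)

# The fixed crop [18:146, 18:146] then leaves a 128x128 feature map
x = x[:, 18:146, 18:146, :]
print(x.shape)  # (None, 128, 128, 64)
```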
@@ -770,7 +834,8 @@ class UNET:
         # f.close()

     def build_model(self):
-        self.build_unet()
+        # self.build_unet()
+        self.build_upsample()
         self.model = tf.keras.Model(self.inputs, self.logits)

     def restore(self, ckpt_dir):
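build_model now skips build_unet and builds only the upsample path before wrapping the graph with the Keras functional API. A minimal sketch of that wiring, with a hypothetical Input spec standing in for self.inputs:

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(73, 73, 1))   # hypothetical stand-in for self.inputs[0]
x = tf.keras.layers.Conv2DTranspose(8, kernel_size=7, strides=2, padding='same')(inp)
out = tf.keras.layers.Conv2D(1, kernel_size=1, activation='sigmoid', name='probability')(x)

# mirrors: self.model = tf.keras.Model(self.inputs, self.logits)
model = tf.keras.Model([inp], out)
model.summary()
```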
@@ -826,9 +891,17 @@ class UNET:
         self.test_preds = preds

     def run(self, directory):
-        data_files = glob.glob(directory+'mod_res*.npy')
-        label_files = [f.replace('mod', 'img') for f in data_files]
-        self.setup_pipeline(data_files, label_files)
+        train_data_files = glob.glob(directory+'data_train*.npy')
+        valid_data_files = glob.glob(directory+'data_valid*.npy')
+        train_label_files = glob.glob(directory+'label_train*.npy')
+        valid_label_files = glob.glob(directory+'label_valid*.npy')
+        train_data_files.sort()
+        valid_data_files.sort()
+        train_label_files.sort()
+        valid_label_files.sort()
+        self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000)
         self.build_model()
         self.build_training()
         self.build_evaluation()
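run() now discovers separate training and validation data/label files and sorts each list before handing them to setup_pipeline together with a 100000 cap. A hedged sketch of how the sorted glob results keep data and label files paired by index; the directory path and any file suffixes beyond the prefixes shown in the diff are assumptions:

```python
import glob

directory = '/path/to/tiles/'  # hypothetical

train_data_files = sorted(glob.glob(directory + 'data_train*.npy'))
train_label_files = sorted(glob.glob(directory + 'label_train*.npy'))
valid_data_files = sorted(glob.glob(directory + 'data_valid*.npy'))
valid_label_files = sorted(glob.glob(directory + 'label_valid*.npy'))

# Sorting each list the same way is what keeps data[i] and label[i] referring
# to the same tile; a cheap sanity check before calling setup_pipeline:
assert len(train_data_files) == len(train_label_files)
for d, l in zip(train_data_files, train_label_files):
    assert d.split('data_train')[-1] == l.split('label_train')[-1], (d, l)
```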