diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index f95221396bcfa8aeaff417bd5323e090372a7f8f..6d371e2a6a5f2d8c9f0e75b0f736d8dabea29c0b 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -14,7 +14,7 @@ import h5py LOG_DEVICE_PLACEMENT = False -PROC_BATCH_SIZE = 50 +PROC_BATCH_SIZE = 10 PROC_BATCH_BUFFER_SIZE = 50000 NumClasses = 2 @@ -189,9 +189,11 @@ class UNET: label_s.append(nda) data = np.concatenate(data_s) + data = data[:, 0, :, :] data = np.expand_dims(data, axis=3) label = np.concatenate(label_s) + label = label[:, 0, :, :] label = np.expand_dims(label, axis=3) data = data.astype(np.float32) @@ -490,6 +492,68 @@ class UNET: print(self.logits.shape) + def build_upsample(self): + print('build_upsample') + # padding = "VALID" + padding = "SAME" + + # activation = tf.nn.relu + # activation = tf.nn.elu + activation = tf.nn.leaky_relu + + num_filters = self.n_chans * 8 + + input_2d = self.inputs[0] + print('input: ', input_2d.shape) + + # Expanding (Decoding) branch ------------------------------------------------------------------------------- + print('expanding branch') + + conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=7, strides=2, padding=padding, activation=activation)(input_2d) + conv = tf.keras.layers.BatchNormalization()(conv) + print(conv.shape) + conv = conv[:, 18:146, 18:146, :] + + num_filters /= 2 + conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv) + print(conv.shape) + + num_filters /= 2 + conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv) + print(conv.shape) + + # if NumClasses == 2: + # activation = tf.nn.sigmoid # For binary + # else: + # activation = tf.nn.softmax # For multi-class + activation = tf.nn.sigmoid + + # Called logits, but these are actually probabilities, see activation + self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv) + + print(self.logits.shape) + + # num_filters /= 2 + # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) + # conv = tf.keras.layers.concatenate([conv, conv_3]) + # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + # conv = tf.keras.layers.BatchNormalization()(conv) + # print('6: ', conv.shape) + # + # num_filters /= 2 + # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) + # conv = tf.keras.layers.concatenate([conv, conv_2]) + # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) + # conv = tf.keras.layers.BatchNormalization()(conv) + # print('7: ', conv.shape) + # + # num_filters /= 2 + # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) + # print('8: ', conv.shape) + # + # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv) + # print('9: ', conv.shape) + def build_training(self): # if NumClasses == 2: # self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) # for two-class only @@ -770,7 +834,8 @@ class UNET: # f.close() def build_model(self): - self.build_unet() + # self.build_unet() + self.build_upsample() self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir): @@ -826,9 +891,17 @@ class UNET: self.test_preds = preds def run(self, directory): - data_files = glob.glob(directory+'mod_res*.npy') - label_files = [f.replace('mod', 'img') for f in data_files] - self.setup_pipeline(data_files, label_files) + train_data_files = glob.glob(directory+'data_train*.npy') + valid_data_files = glob.glob(directory+'data_valid*.npy') + train_label_files = glob.glob(directory+'label_train*.npy') + valid_label_files = glob.glob(directory+'label_valid*.npy') + + train_data_files.sort() + valid_data_files.sort() + train_label_files.sort() + valid_label_files.sort() + + self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000) self.build_model() self.build_training() self.build_evaluation()