diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index f95221396bcfa8aeaff417bd5323e090372a7f8f..6d371e2a6a5f2d8c9f0e75b0f736d8dabea29c0b 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -14,7 +14,7 @@ import h5py
 
 LOG_DEVICE_PLACEMENT = False
 
-PROC_BATCH_SIZE = 50
+PROC_BATCH_SIZE = 10
 PROC_BATCH_BUFFER_SIZE = 50000
 
 NumClasses = 2
@@ -189,9 +189,13 @@ class UNET:
             label_s.append(nda)
 
         data = np.concatenate(data_s)
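+        # keep only the first slice along axis 1; a trailing channel axis is re-added below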
+        data = data[:, 0, :, :]
         data = np.expand_dims(data, axis=3)
 
         label = np.concatenate(label_s)
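+        # same reduction for the labels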
+        label = label[:, 0, :, :]
         label = np.expand_dims(label, axis=3)
 
         data = data.astype(np.float32)
@@ -490,6 +492,73 @@ class UNET:
 
         print(self.logits.shape)
 
+    def build_upsample(self):
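+        """Decoder-only head: one stride-2 transposed conv upsamples the
+        input, two 1x1 layers refine it, and a sigmoid Conv2D emits
+        per-pixel probabilities. Swapped in for build_unet via build_model."""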
+        print('build_upsample')
+        # padding = "VALID"
+        padding = "SAME"
+
+        # activation = tf.nn.relu
+        # activation = tf.nn.elu
+        activation = tf.nn.leaky_relu
+
+        num_filters = self.n_chans * 8
+
+        input_2d = self.inputs[0]
+        print('input: ', input_2d.shape)
+
+        # Expanding (Decoding) branch -------------------------------------------------------------------------------
+        print('expanding branch')
+
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=7, strides=2, padding=padding, activation=activation)(input_2d)
+        conv = tf.keras.layers.BatchNormalization()(conv)
+        print(conv.shape)
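+        # crop the upsampled feature map down to 128x128 (indices 18:146)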
+        conv = conv[:, 18:146, 18:146, :]
+
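+        # kernel_size=1, strides=1 transposed convs act as per-pixel (1x1) projections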
+        num_filters //= 2  # floor division keeps the filters argument an int
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv)
+        print(conv.shape)
+
+        num_filters //= 2
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv)
+        print(conv.shape)
+
+        # if NumClasses == 2:
+        #     activation = tf.nn.sigmoid  # For binary
+        # else:
+        #     activation = tf.nn.softmax  # For multi-class
+        activation = tf.nn.sigmoid
+
+        # Called logits, but these are actually probabilities, see activation
+        self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv)
+
+        print(self.logits.shape)
+
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # conv = tf.keras.layers.concatenate([conv, conv_3])
+        # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
+        # conv = tf.keras.layers.BatchNormalization()(conv)
+        # print('6: ', conv.shape)
+        #
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # conv = tf.keras.layers.concatenate([conv, conv_2])
+        # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
+        # conv = tf.keras.layers.BatchNormalization()(conv)
+        # print('7: ', conv.shape)
+        #
+        # num_filters /= 2
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # print('8: ', conv.shape)
+        #
+        # conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        # print('9: ', conv.shape)
+
     def build_training(self):
         # if NumClasses == 2:
         #     self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
@@ -770,7 +834,9 @@ class UNET:
         # f.close()
 
     def build_model(self):
-        self.build_unet()
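+        # Upsample-only head for this experiment; re-enable build_unet for the full model.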
+        # self.build_unet()
+        self.build_upsample()
         self.model = tf.keras.Model(self.inputs, self.logits)
 
     def restore(self, ckpt_dir):
@@ -826,9 +891,13 @@ class UNET:
         self.test_preds = preds
 
     def run(self, directory):
-        data_files = glob.glob(directory+'mod_res*.npy')
-        label_files = [f.replace('mod', 'img') for f in data_files]
-        self.setup_pipeline(data_files, label_files)
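+        # sorted() keeps the data and label lists index-aligned (assumes matching filename suffixes)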
+        train_data_files = sorted(glob.glob(directory+'data_train*.npy'))
+        valid_data_files = sorted(glob.glob(directory+'data_valid*.npy'))
+        train_label_files = sorted(glob.glob(directory+'label_train*.npy'))
+        valid_label_files = sorted(glob.glob(directory+'label_valid*.npy'))
+
+        self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000)
         self.build_model()
         self.build_training()
         self.build_evaluation()