diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py index c15cc582f715b8e9d79172e48614b22880d91024..69c67880fa0b1fc5e0fae078293198655628992e 100644 --- a/modules/deeplearning/unet_l1b_l2.py +++ b/modules/deeplearning/unet_l1b_l2.py @@ -13,7 +13,7 @@ import h5py LOG_DEVICE_PLACEMENT = False -PROC_BATCH_SIZE = 10 +PROC_BATCH_SIZE = 2 PROC_BATCH_BUFFER_SIZE = 50000 NumClasses = 2 @@ -188,7 +188,7 @@ class UNET: self.inputs.append(self.X_img) # self.inputs.append(tf.keras.Input(shape=(None, None, 5))) - self.inputs.append(tf.keras.Input(shape=(None, None, 6))) + self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans))) self.DISK_CACHE = False @@ -861,7 +861,7 @@ class UNET: train_label_files.sort() valid_label_files.sort() - self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files) + self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 50000) self.build_model() self.build_training() self.build_evaluation()