diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py index e94fd5ff7c6938301924a82c9cc22129e7fb4f3d..fc418e80012d9693443587257bcb84106647c985 100644 --- a/modules/deeplearning/unet_l1b_l2.py +++ b/modules/deeplearning/unet_l1b_l2.py @@ -370,7 +370,7 @@ class UNET: # print('num test samples: ', tst_idxs.shape[0]) # print('setup_pipeline: Done') - def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files): + def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples): self.train_data_files = train_data_files self.train_label_files = train_label_files @@ -384,7 +384,7 @@ class UNET: self.get_train_dataset(trn_idxs) self.get_test_dataset(tst_idxs) - self.num_data_samples = 34000 # approximately + self.num_data_samples = num_train_samples # approximately print('datetime: ', now) print('training and test data: ')