Skip to content
Snippets Groups Projects
Commit c3501ca4 authored by tomrink's avatar tomrink
Browse files

minor

parent be50aca5
Branches
No related tags found
No related merge requests found
......@@ -507,20 +507,27 @@ class UNET:
self.get_evaluate_dataset(idxs)
def setup_salt_pipeline(self, data_path, label_path, perc=0.2):
data_files = glob.glob(data_path+'*.png')
label_files = glob.glob(label_path+'*.png')
data_files = np.array(glob.glob(data_path+'*.png'))
label_files = np.array(glob.glob(label_path+'*.png'))
num_data_samples = len(data_files)
idxs = np.arange(num_data_samples)
np.random.shuffle(idxs)
num_train = int(num_data_samples*(1-perc))
self.num_data_samples = num_train
trn_idxs = idxs[0:num_train]
np.sort(trn_idxs)
tst_idxs = idxs[num_train:]
np.sort(tst_idxs)
self.train_data_files = data_files[trn_idxs]
self.train_label_files = label_files[trn_idxs]
self.train_data_files = data_files[0:num_train]
self.train_label_files = label_files[0:num_train]
self.test_data_files = data_files[tst_idxs]
self.test_label_files = label_files[tst_idxs]
self.test_data_files = data_files[num_train:]
self.test_label_files = label_files[num_train:]
tst_idxs = np.arange(num_data_samples-num_train)
trn_idxs = np.arange(len(trn_idxs))
tst_idxs = np.arange(len(tst_idxs))
np.random.shuffle(trn_idxs)
self.get_train_dataset(trn_idxs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment