From c3501ca45de0dfe8a49048e4541c3c0fcc26a6b5 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 12 Apr 2022 12:42:31 -0500 Subject: [PATCH] minor --- modules/deeplearning/unet.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 84a91bbe..4b6f9302 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -507,20 +507,27 @@ class UNET: self.get_evaluate_dataset(idxs) def setup_salt_pipeline(self, data_path, label_path, perc=0.2): - data_files = glob.glob(data_path+'*.png') - label_files = glob.glob(label_path+'*.png') + data_files = np.array(glob.glob(data_path+'*.png')) + label_files = np.array(glob.glob(label_path+'*.png')) + num_data_samples = len(data_files) idxs = np.arange(num_data_samples) + np.random.shuffle(idxs) num_train = int(num_data_samples*(1-perc)) self.num_data_samples = num_train trn_idxs = idxs[0:num_train] + np.sort(trn_idxs) + tst_idxs = idxs[num_train:] + np.sort(tst_idxs) + + self.train_data_files = data_files[trn_idxs] + self.train_label_files = label_files[trn_idxs] - self.train_data_files = data_files[0:num_train] - self.train_label_files = label_files[0:num_train] + self.test_data_files = data_files[tst_idxs] + self.test_label_files = label_files[tst_idxs] - self.test_data_files = data_files[num_train:] - self.test_label_files = label_files[num_train:] - tst_idxs = np.arange(num_data_samples-num_train) + trn_idxs = np.arange(len(trn_idxs)) + tst_idxs = np.arange(len(tst_idxs)) np.random.shuffle(trn_idxs) self.get_train_dataset(trn_idxs) -- GitLab