From c3501ca45de0dfe8a49048e4541c3c0fcc26a6b5 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 12 Apr 2022 12:42:31 -0500
Subject: [PATCH] minor

---
 modules/deeplearning/unet.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 84a91bbe..4b6f9302 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -507,20 +507,27 @@ class UNET:
         self.get_evaluate_dataset(idxs)
 
     def setup_salt_pipeline(self, data_path, label_path, perc=0.2):
-        data_files = glob.glob(data_path+'*.png')
-        label_files = glob.glob(label_path+'*.png')
+        data_files = np.array(glob.glob(data_path+'*.png'))
+        label_files = np.array(glob.glob(label_path+'*.png'))
+
         num_data_samples = len(data_files)
         idxs = np.arange(num_data_samples)
+        np.random.shuffle(idxs)
         num_train = int(num_data_samples*(1-perc))
         self.num_data_samples = num_train
         trn_idxs = idxs[0:num_train]
+        np.sort(trn_idxs)
+        tst_idxs = idxs[num_train:]
+        np.sort(tst_idxs)
+
+        self.train_data_files = data_files[trn_idxs]
+        self.train_label_files = label_files[trn_idxs]
 
-        self.train_data_files = data_files[0:num_train]
-        self.train_label_files = label_files[0:num_train]
+        self.test_data_files = data_files[tst_idxs]
+        self.test_label_files = label_files[tst_idxs]
 
-        self.test_data_files = data_files[num_train:]
-        self.test_label_files = label_files[num_train:]
-        tst_idxs = np.arange(num_data_samples-num_train)
+        trn_idxs = np.arange(len(trn_idxs))
+        tst_idxs = np.arange(len(tst_idxs))
 
         np.random.shuffle(trn_idxs)
         self.get_train_dataset(trn_idxs)
-- 
GitLab