diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 7bea589d8298152a8883f9d57ece8e4b53fcee88..4f8a7461a41989b77302295278fd0bf23b6e13a4 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -506,7 +506,7 @@ class UNET:
 
         self.get_evaluate_dataset(idxs)
 
-    def setup_salt_pipeline(self, data_path, label_path, perc=0.2):
+    def setup_salt_pipeline(self, data_path, label_path, perc=0.15):
         data_files = np.array(glob.glob(data_path+'*.png'))
         label_files = np.array(glob.glob(label_path+'*.png'))
 
@@ -533,6 +533,20 @@ class UNET:
         self.get_train_dataset(trn_idxs)
         self.get_test_dataset(tst_idxs)
 
+        f = open(home_dir+'/salt_test_files.pkl', 'wb')
+        pickle.dump((self.test_data_files, self.test_label_files), f)
+        f.close()
+
+    def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'):
+        tup = pickle.load(open(test_files, 'rb'))
+
+        self.test_data_files = tup[0]
+        self.test_label_files = tup[1]
+        self.num_data_samples = len(self.test_data_files)
+
+        tst_idxs = np.arange(self.num_data_samples)
+        self.get_test_dataset(tst_idxs)
+
     def build_unet(self):
         print('build_cnn')
         # padding = "VALID"
@@ -1000,7 +1014,8 @@ class UNET:
         self.do_training()
 
     def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
-        self.setup_salt_pipeline(data_path, label_path)
+        #self.setup_salt_pipeline(data_path, label_path)
+        self.setup_salt_restore()
         self.build_model()
         self.build_training()
         self.build_evaluation()