diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 7bea589d8298152a8883f9d57ece8e4b53fcee88..4f8a7461a41989b77302295278fd0bf23b6e13a4 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -506,7 +506,7 @@ class UNET: self.get_evaluate_dataset(idxs) - def setup_salt_pipeline(self, data_path, label_path, perc=0.2): + def setup_salt_pipeline(self, data_path, label_path, perc=0.15): data_files = np.array(glob.glob(data_path+'*.png')) label_files = np.array(glob.glob(label_path+'*.png')) @@ -533,6 +533,20 @@ class UNET: self.get_train_dataset(trn_idxs) self.get_test_dataset(tst_idxs) + f = open(home_dir+'/salt_test_files.pkl', 'wb') + pickle.dump((self.test_data_files, self.test_label_files), f) + f.close() + + def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'): + tup = pickle.load(open(test_files, 'rb')) + + self.test_data_files = tup[0] + self.test_label_files = tup[1] + self.num_data_samples = len(self.test_data_files) + + tst_idxs = np.arange(self.num_data_samples) + self.get_test_dataset(tst_idxs) + def build_unet(self): print('build_cnn') # padding = "VALID" @@ -1000,7 +1014,8 @@ class UNET: self.do_training() def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None): - self.setup_salt_pipeline(data_path, label_path) + #self.setup_salt_pipeline(data_path, label_path) + self.setup_salt_restore() self.build_model() self.build_training() self.build_evaluation()