From 57f5813c087349e4a7cb31350d9665f37b6bc995 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 18 Apr 2022 17:15:21 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/unet.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 7bea589d..4f8a7461 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -506,7 +506,7 @@ class UNET: self.get_evaluate_dataset(idxs) - def setup_salt_pipeline(self, data_path, label_path, perc=0.2): + def setup_salt_pipeline(self, data_path, label_path, perc=0.15): data_files = np.array(glob.glob(data_path+'*.png')) label_files = np.array(glob.glob(label_path+'*.png')) @@ -533,6 +533,20 @@ class UNET: self.get_train_dataset(trn_idxs) self.get_test_dataset(tst_idxs) + f = open(home_dir+'/salt_test_files.pkl', 'wb') + pickle.dump((self.test_data_files, self.test_label_files), f) + f.close() + + def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'): + tup = pickle.load(open(test_files, 'rb')) + + self.test_data_files = tup[0] + self.test_label_files = tup[1] + self.num_data_samples = len(self.test_data_files) + + tst_idxs = np.arange(self.num_data_samples) + self.get_test_dataset(tst_idxs) + def build_unet(self): print('build_cnn') # padding = "VALID" @@ -1000,7 +1014,8 @@ class UNET: self.do_training() def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None): - self.setup_salt_pipeline(data_path, label_path) + #self.setup_salt_pipeline(data_path, label_path) + self.setup_salt_restore() self.build_model() self.build_training() self.build_evaluation() -- GitLab