Skip to content
Snippets Groups Projects
Commit 57f5813c authored by tomrink's avatar tomrink
Browse files

snapshot...

parent b3de3ac2
Branches
No related tags found
No related merge requests found
......@@ -506,7 +506,7 @@ class UNET:
self.get_evaluate_dataset(idxs)
def setup_salt_pipeline(self, data_path, label_path, perc=0.2):
def setup_salt_pipeline(self, data_path, label_path, perc=0.15):
data_files = np.array(glob.glob(data_path+'*.png'))
label_files = np.array(glob.glob(label_path+'*.png'))
......@@ -533,6 +533,20 @@ class UNET:
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
f = open(home_dir+'/salt_test_files.pkl', 'wb')
pickle.dump((self.test_data_files, self.test_label_files), f)
f.close()
def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'):
tup = pickle.load(open(test_files, 'rb'))
self.test_data_files = tup[0]
self.test_label_files = tup[1]
self.num_data_samples = len(self.test_data_files)
tst_idxs = np.arange(self.num_data_samples)
self.get_test_dataset(tst_idxs)
def build_unet(self):
print('build_cnn')
# padding = "VALID"
......@@ -1000,7 +1014,8 @@ class UNET:
self.do_training()
def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
self.setup_salt_pipeline(data_path, label_path)
#self.setup_salt_pipeline(data_path, label_path)
self.setup_salt_restore()
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment