Skip to content
Snippets Groups Projects
Commit 35e254c3 authored by tomrink's avatar tomrink
Browse files

minor...

parent afa0d82e
Branches
No related tags found
No related merge requests found
......@@ -309,35 +309,35 @@ class UNET:
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
self.eval_dataset = dataset
def setup_pipeline(self, data_nda, label_nda, perc=0.20):
num_samples = data_nda.shape[0]
num_test = int(num_samples * perc)
self.num_data_samples = num_samples - num_test
num_train = self.num_data_samples
self.train_data_nda = data_nda[0:num_train]
self.train_label_nda = label_nda[0:num_train]
self.test_data_nda = data_nda[num_train:]
self.test_label_nda = label_nda[num_train:]
trn_idxs = np.arange(self.train_data_nda.shape[0])
tst_idxs = np.arange(self.test_data_nda.shape[0])
np.random.shuffle(tst_idxs)
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
print('datetime: ', now)
print('training and test data: ')
print('---------------------------')
print('num train samples: ', self.num_data_samples)
print('BATCH SIZE: ', BATCH_SIZE)
print('num test samples: ', tst_idxs.shape[0])
print('setup_pipeline: Done')
def setup_pipeline_files(self, data_files, label_files, perc=0.20):
# def setup_pipeline(self, data_nda, label_nda, perc=0.20):
#
# num_samples = data_nda.shape[0]
# num_test = int(num_samples * perc)
# self.num_data_samples = num_samples - num_test
# num_train = self.num_data_samples
#
# self.train_data_nda = data_nda[0:num_train]
# self.train_label_nda = label_nda[0:num_train]
# self.test_data_nda = data_nda[num_train:]
# self.test_label_nda = label_nda[num_train:]
#
# trn_idxs = np.arange(self.train_data_nda.shape[0])
# tst_idxs = np.arange(self.test_data_nda.shape[0])
#
# np.random.shuffle(tst_idxs)
#
# self.get_train_dataset(trn_idxs)
# self.get_test_dataset(tst_idxs)
#
# print('datetime: ', now)
# print('training and test data: ')
# print('---------------------------')
# print('num train samples: ', self.num_data_samples)
# print('BATCH SIZE: ', BATCH_SIZE)
# print('num test samples: ', tst_idxs.shape[0])
# print('setup_pipeline: Done')
def setup_pipeline(self, data_files, label_files, perc=0.20):
num_files = len(data_files)
num_test_files = int(num_files * perc)
num_train_files = num_files - num_test_files
......@@ -859,7 +859,7 @@ class UNET:
def run(self, directory):
data_files = glob.glob(directory+'mod_res*.npy')
label_files = [f.replace('mod', 'img') for f in data_files]
self.setup_pipeline_files(data_files, label_files)
self.setup_pipeline(data_files, label_files)
self.build_model()
self.build_training()
self.build_evaluation()
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment