Skip to content
Snippets Groups Projects
Commit fc3d456d authored by tomrink's avatar tomrink
Browse files

deal with bigger data

parent ea55efe4
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,7 @@ import h5py
LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 2048
PROC_BATCH_SIZE = 10
PROC_BATCH_BUFFER_SIZE = 50000
NumClasses = 2
......@@ -216,23 +216,38 @@ class UNET:
def get_in_mem_data_batch(self, idxs, is_training):
np.sort(idxs)
dat_files = []
lbl_files = []
if is_training:
f = self.train_data_files
nda = np.load(f)
data = nda[idxs, :, :, :]
for k in idxs:
f = self.train_data_files[k]
nda = np.load(f)
dat_files.append(nda)
f = self.train_label_files[k]
nda = np.load(f)
lbl_files.append(nda)
data = np.concatenate(dat_files)
label = np.concatenate(lbl_files)
f = self.train_label_files
nda = np.load(f)
label = nda[idxs, label_idx, :, :]
label = label[:, label_idx, :, :]
label = np.expand_dims(label, axis=3)
else:
f = self.test_data_files
nda = np.load(f)
data = nda[idxs, :, :, :]
for k in idxs:
f = self.train_data_files[k]
nda = np.load(f)
dat_files.append(nda)
f = self.test_label_files
nda = np.load(f)
label = nda[idxs, label_idx, :, :]
f = self.train_label_files[k]
nda = np.load(f)
lbl_files.append(nda)
data = np.concatenate(dat_files)
label = np.concatenate(lbl_files)
label = label[:, label_idx, :, :]
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
......@@ -363,14 +378,14 @@ class UNET:
self.test_data_files = test_data_files
self.test_label_files = test_label_files
trn_idxs = np.arange(20925)
trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs)
tst_idxs = np.arange(2325)
tst_idxs = np.arange(len(train_data_files))
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
self.num_data_samples = 20925
self.num_data_samples = 56000 # approximately
print('datetime: ', now)
print('training and test data: ')
......@@ -880,11 +895,13 @@ class UNET:
self.build_evaluation()
self.do_training()
# def run_test(self, directory):
# data_files = glob.glob(directory+'l1b_*.npy')
# label_files = [f.replace('l1b', 'l2') for f in data_files]
def run_test(self, train_data_files, train_label_files, vld_data_files, vld_label_files):
self.setup_pipeline_files(train_data_files, train_label_files, vld_data_files, vld_label_files)
def run_test(self, directory):
train_data_files = glob.glob(directory+'data_train*.npy')
valid_data_files = glob.glob(directory+'data_valid*.npy')
train_label_files = glob.glob(directory+'label_train*.npy')
valid_label_files = glob.glob(directory+'label_valid*.npy')
self.setup_pipeline_files(train_data_files, train_label_files, valid_data_files, valid_label_files)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment