Skip to content
Snippets Groups Projects
Commit fc3d456d authored by tomrink's avatar tomrink
Browse files

deal with bigger data

parent ea55efe4
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ import h5py ...@@ -13,7 +13,7 @@ import h5py
LOG_DEVICE_PLACEMENT = False LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 2048 PROC_BATCH_SIZE = 10
PROC_BATCH_BUFFER_SIZE = 50000 PROC_BATCH_BUFFER_SIZE = 50000
NumClasses = 2 NumClasses = 2
...@@ -216,23 +216,38 @@ class UNET: ...@@ -216,23 +216,38 @@ class UNET:
def get_in_mem_data_batch(self, idxs, is_training): def get_in_mem_data_batch(self, idxs, is_training):
np.sort(idxs) np.sort(idxs)
dat_files = []
lbl_files = []
if is_training: if is_training:
f = self.train_data_files for k in idxs:
nda = np.load(f) f = self.train_data_files[k]
data = nda[idxs, :, :, :] nda = np.load(f)
dat_files.append(nda)
f = self.train_label_files[k]
nda = np.load(f)
lbl_files.append(nda)
data = np.concatenate(dat_files)
label = np.concatenate(lbl_files)
f = self.train_label_files label = label[:, label_idx, :, :]
nda = np.load(f)
label = nda[idxs, label_idx, :, :]
label = np.expand_dims(label, axis=3) label = np.expand_dims(label, axis=3)
else: else:
f = self.test_data_files for k in idxs:
nda = np.load(f) f = self.train_data_files[k]
data = nda[idxs, :, :, :] nda = np.load(f)
dat_files.append(nda)
f = self.test_label_files f = self.train_label_files[k]
nda = np.load(f) nda = np.load(f)
label = nda[idxs, label_idx, :, :] lbl_files.append(nda)
data = np.concatenate(dat_files)
label = np.concatenate(lbl_files)
label = label[:, label_idx, :, :]
label = np.expand_dims(label, axis=3) label = np.expand_dims(label, axis=3)
data = data.astype(np.float32) data = data.astype(np.float32)
...@@ -363,14 +378,14 @@ class UNET: ...@@ -363,14 +378,14 @@ class UNET:
self.test_data_files = test_data_files self.test_data_files = test_data_files
self.test_label_files = test_label_files self.test_label_files = test_label_files
trn_idxs = np.arange(20925) trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs) np.random.shuffle(trn_idxs)
tst_idxs = np.arange(2325) tst_idxs = np.arange(len(train_data_files))
self.get_train_dataset(trn_idxs) self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs) self.get_test_dataset(tst_idxs)
self.num_data_samples = 20925 self.num_data_samples = 56000 # approximately
print('datetime: ', now) print('datetime: ', now)
print('training and test data: ') print('training and test data: ')
...@@ -880,11 +895,13 @@ class UNET: ...@@ -880,11 +895,13 @@ class UNET:
self.build_evaluation() self.build_evaluation()
self.do_training() self.do_training()
# def run_test(self, directory): def run_test(self, directory):
# data_files = glob.glob(directory+'l1b_*.npy') train_data_files = glob.glob(directory+'data_train*.npy')
# label_files = [f.replace('l1b', 'l2') for f in data_files] valid_data_files = glob.glob(directory+'data_valid*.npy')
def run_test(self, train_data_files, train_label_files, vld_data_files, vld_label_files): train_label_files = glob.glob(directory+'label_train*.npy')
self.setup_pipeline_files(train_data_files, train_label_files, vld_data_files, vld_label_files) valid_label_files = glob.glob(directory+'label_valid*.npy')
self.setup_pipeline_files(train_data_files, train_label_files, valid_data_files, valid_label_files)
self.build_model() self.build_model()
self.build_training() self.build_training()
self.build_evaluation() self.build_evaluation()
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment