From fc3d456dce3a67a7baaa8aa663625173af85ad43 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 24 May 2022 17:19:57 -0500 Subject: [PATCH] deal with bigger data --- modules/deeplearning/unet_l1b_l2.py | 59 +++++++++++++++++++---------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py index 192f60fe..0afb4728 100644 --- a/modules/deeplearning/unet_l1b_l2.py +++ b/modules/deeplearning/unet_l1b_l2.py @@ -13,7 +13,7 @@ import h5py LOG_DEVICE_PLACEMENT = False -PROC_BATCH_SIZE = 2048 +PROC_BATCH_SIZE = 10 PROC_BATCH_BUFFER_SIZE = 50000 NumClasses = 2 @@ -216,23 +216,38 @@ class UNET: def get_in_mem_data_batch(self, idxs, is_training): np.sort(idxs) + dat_files = [] + lbl_files = [] + if is_training: - f = self.train_data_files - nda = np.load(f) - data = nda[idxs, :, :, :] + for k in idxs: + f = self.train_data_files[k] + nda = np.load(f) + dat_files.append(nda) + + f = self.train_label_files[k] + nda = np.load(f) + lbl_files.append(nda) + + data = np.concatenate(dat_files) + label = np.concatenate(lbl_files) - f = self.train_label_files - nda = np.load(f) - label = nda[idxs, label_idx, :, :] + label = label[:, label_idx, :, :] label = np.expand_dims(label, axis=3) else: - f = self.test_data_files - nda = np.load(f) - data = nda[idxs, :, :, :] + for k in idxs: + f = self.test_data_files[k] + nda = np.load(f) + dat_files.append(nda) - f = self.test_label_files - nda = np.load(f) - label = nda[idxs, label_idx, :, :] + f = self.test_label_files[k] + nda = np.load(f) + lbl_files.append(nda) + + data = np.concatenate(dat_files) + label = np.concatenate(lbl_files) + + label = label[:, label_idx, :, :] label = np.expand_dims(label, axis=3) data = data.astype(np.float32) @@ -363,14 +378,14 @@ class UNET: self.test_data_files = test_data_files self.test_label_files = test_label_files - trn_idxs = np.arange(20925) + trn_idxs = np.arange(len(train_data_files)) 
np.random.shuffle(trn_idxs) - tst_idxs = np.arange(2325) + tst_idxs = np.arange(len(test_data_files)) self.get_train_dataset(trn_idxs) self.get_test_dataset(tst_idxs) - self.num_data_samples = 20925 + self.num_data_samples = 56000 # approximately print('datetime: ', now) print('training and test data: ') @@ -880,11 +895,13 @@ class UNET: self.build_evaluation() self.do_training() - # def run_test(self, directory): - # data_files = glob.glob(directory+'l1b_*.npy') - # label_files = [f.replace('l1b', 'l2') for f in data_files] - def run_test(self, train_data_files, train_label_files, vld_data_files, vld_label_files): - self.setup_pipeline_files(train_data_files, train_label_files, vld_data_files, vld_label_files) + def run_test(self, directory): + train_data_files = glob.glob(directory+'data_train*.npy') + valid_data_files = glob.glob(directory+'data_valid*.npy') + train_label_files = glob.glob(directory+'label_train*.npy') + valid_label_files = glob.glob(directory+'label_valid*.npy') + + self.setup_pipeline_files(train_data_files, train_label_files, valid_data_files, valid_label_files) self.build_model() self.build_training() self.build_evaluation() -- GitLab