From e38255e6e7e506794ae6f70c91d48b2cf51dffbd Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 6 Jun 2022 09:41:10 -0500 Subject: [PATCH] minor... --- modules/deeplearning/unet_l1b_l2.py | 47 +++++++++++------------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py index fc418e80..aaae7b8c 100644 --- a/modules/deeplearning/unet_l1b_l2.py +++ b/modules/deeplearning/unet_l1b_l2.py @@ -214,40 +214,29 @@ class UNET: # print(e) def get_in_mem_data_batch(self, idxs, is_training): - - dat_files = [] - lbl_files = [] - if is_training: - for k in idxs: - f = self.train_data_files[k] - nda = np.load(f) - dat_files.append(nda) - - f = self.train_label_files[k] - nda = np.load(f) - lbl_files.append(nda) - - data = np.concatenate(dat_files) - label = np.concatenate(lbl_files) - - label = label[:, label_idx, :, :] - label = np.expand_dims(label, axis=3) + data_files = self.train_data_files + label_files = self.train_label_files else: - for k in idxs: - f = self.test_data_files[k] - nda = np.load(f) - dat_files.append(nda) + data_files = self.test_data_files + label_files = self.test_label_files + + data_s = [] + label_s = [] + for k in idxs: + f = data_files[k] + nda = np.load(f) + data_s.append(nda) - f = self.test_label_files[k] - nda = np.load(f) - lbl_files.append(nda) + f = label_files[k] + nda = np.load(f) + label_s.append(nda) - data = np.concatenate(dat_files) - label = np.concatenate(lbl_files) + data = np.concatenate(data_s) + label = np.concatenate(label_s) - label = label[:, label_idx, :, :] - label = np.expand_dims(label, axis=3) + label = label[:, label_idx, :, :] + label = np.expand_dims(label, axis=3) data = data.astype(np.float32) label = label.astype(np.float32) -- GitLab