From e38255e6e7e506794ae6f70c91d48b2cf51dffbd Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 6 Jun 2022 09:41:10 -0500
Subject: [PATCH] minor...

---
 modules/deeplearning/unet_l1b_l2.py | 47 +++++++++++------------------
 1 file changed, 18 insertions(+), 29 deletions(-)

diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py
index fc418e80..aaae7b8c 100644
--- a/modules/deeplearning/unet_l1b_l2.py
+++ b/modules/deeplearning/unet_l1b_l2.py
@@ -214,40 +214,29 @@ class UNET:
         #         print(e)
 
     def get_in_mem_data_batch(self, idxs, is_training):
-
-        dat_files = []
-        lbl_files = []
-
         if is_training:
-            for k in idxs:
-                f = self.train_data_files[k]
-                nda = np.load(f)
-                dat_files.append(nda)
-
-                f = self.train_label_files[k]
-                nda = np.load(f)
-                lbl_files.append(nda)
-
-            data = np.concatenate(dat_files)
-            label = np.concatenate(lbl_files)
-
-            label = label[:, label_idx, :, :]
-            label = np.expand_dims(label, axis=3)
+            data_files = self.train_data_files
+            label_files = self.train_label_files
         else:
-            for k in idxs:
-                f = self.test_data_files[k]
-                nda = np.load(f)
-                dat_files.append(nda)
+            data_files = self.test_data_files
+            label_files = self.test_label_files
+
+        data_s = []
+        label_s = []
+        for k in idxs:
+            f = data_files[k]
+            nda = np.load(f)
+            data_s.append(nda)
 
-                f = self.test_label_files[k]
-                nda = np.load(f)
-                lbl_files.append(nda)
+            f = label_files[k]
+            nda = np.load(f)
+            label_s.append(nda)
 
-            data = np.concatenate(dat_files)
-            label = np.concatenate(lbl_files)
+        data = np.concatenate(data_s)
+        label = np.concatenate(label_s)
 
-            label = label[:, label_idx, :, :]
-            label = np.expand_dims(label, axis=3)
+        label = label[:, label_idx, :, :]
+        label = np.expand_dims(label, axis=3)
 
         data = data.astype(np.float32)
         label = label.astype(np.float32)
-- 
GitLab