From 21e9b5c5fde1261b10146a360e9704bbb800d092 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 5 Sep 2022 18:00:11 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/espcn_l1b_l2.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/espcn_l1b_l2.py b/modules/deeplearning/espcn_l1b_l2.py
index bed757e2..cc80b373 100644
--- a/modules/deeplearning/espcn_l1b_l2.py
+++ b/modules/deeplearning/espcn_l1b_l2.py
@@ -310,10 +310,12 @@ class ESPCN:
         dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
         self.eval_dataset = dataset
 
-    def setup_pipeline(self, train_data_files, test_data_files, num_train_samples):
+    def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
 
         self.train_data_files = train_data_files
         self.test_data_files = test_data_files
+        self.train_label_files = train_label_files
+        self.test_label_files = test_label_files
 
         trn_idxs = np.arange(len(train_data_files))
         np.random.shuffle(trn_idxs)
@@ -633,11 +635,13 @@ class ESPCN:
     def run(self, directory):
         train_data_files = glob.glob(directory+'data_train*.npy')
         valid_data_files = glob.glob(directory+'data_valid*.npy')
+        train_label_files = glob.glob(directory+'label_train*.npy')
+        valid_label_files = glob.glob(directory+'label_valid*.npy')
 
         train_data_files.sort()
         valid_data_files.sort()
 
-        self.setup_pipeline(train_data_files, valid_data_files, 100000)
+        self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000)
         self.build_model()
         self.build_training()
         self.build_evaluation()
-- 
GitLab