From d80f3440e03dfc37e8da3c9938dbe6d8a62fe65b Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sun, 21 Aug 2022 11:51:56 -0500
Subject: [PATCH] snapshot..

---
 modules/deeplearning/espcn.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py
index 7eb9e21b..a03d64df 100644
--- a/modules/deeplearning/espcn.py
+++ b/modules/deeplearning/espcn.py
@@ -326,9 +326,10 @@ class ESPCN:
         print('num test samples: ', tst_idxs.shape[0])
         print('setup_pipeline: Done')
 
-    def setup_test_pipeline(self, filename):
-        self.test_data_files = [filename]
-        self.get_test_dataset([0])
+    def setup_test_pipeline(self, test_data_files):
+        self.test_data_files = test_data_files
+        tst_idxs = np.arange(len(test_data_files))
+        self.get_test_dataset(tst_idxs)
         print('setup_test_pipeline: Done')
 
     def setup_eval_pipeline(self, filename):
@@ -658,9 +659,10 @@ class ESPCN:
         self.build_evaluation()
         self.do_training()
 
-    def run_restore(self, filename, ckpt_dir):
+    def run_restore(self, directory, ckpt_dir):
+        valid_data_files = glob.glob(directory + 'data_valid*.npy')
         self.num_data_samples = 1000
-        self.setup_test_pipeline(filename)
+        self.setup_test_pipeline(valid_data_files)
         self.build_model()
         self.build_training()
         self.build_evaluation()
-- 
GitLab