From abe9ab432e7723387d38a28b1ac53007e3cd80d6 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 17 May 2021 14:59:56 -0500
Subject: [PATCH] minor...

---
 modules/deeplearning/icing_cnn.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 9b3035a0..52b60959 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -297,6 +297,21 @@ class IcingIntensityNN:
         print('num test samples: ', tst_idxs.shape[0])
         print('setup_pipeline: Done')
 
+    def setup_test_pipeline(self, filename, seed=None):
+        self.filename_tst = filename
+        self.h5f_tst = h5py.File(filename, 'r')
+
+        time = self.h5f_tst['time']
+        tst_idxs = np.arange(time.shape[0])
+        if seed is not None:
+            np.random.seed(seed)
+        np.random.shuffle(tst_idxs)
+
+        self.get_test_dataset(tst_idxs)
+
+        print('num test samples: ', tst_idxs.shape[0])
+        print('setup_test_pipeline: Done')
+
     def build_1d_cnn(self):
         print('build_1d_cnn')
         # padding = 'VALID'
@@ -667,7 +682,7 @@ class IcingIntensityNN:
         self.test_labels = labels
         self.test_preds = preds
 
-    def run(self, filename_trn, filename_tst, filename_l1b=None):
+    def run(self, filename_trn, filename_tst):
         with tf.device('/device:GPU:'+str(self.gpu_device)):
             self.setup_pipeline(filename_trn, filename_tst)
             self.build_model()
@@ -675,8 +690,8 @@ class IcingIntensityNN:
             self.build_evaluation()
             self.do_training()
 
-    def run_restore(self, filename, ckpt_dir):
-        self.setup_pipeline(filename)
+    def run_restore(self, filename_tst, ckpt_dir):
+        self.setup_test_pipeline(filename_tst)
         self.build_model()
         self.build_training()
         self.build_evaluation()
-- 
GitLab