From 412b207c768d543d91286db85a608aedda30c6a4 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 22 Aug 2022 09:51:18 -0500
Subject: [PATCH] snapshot..

---
 modules/deeplearning/srcnn.py | 30 ++++++++++--------------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/modules/deeplearning/srcnn.py b/modules/deeplearning/srcnn.py
index 30f95980..8e10dad5 100644
--- a/modules/deeplearning/srcnn.py
+++ b/modules/deeplearning/srcnn.py
@@ -327,9 +327,10 @@ class SRCNN:
         print('num test samples: ', tst_idxs.shape[0])
         print('setup_pipeline: Done')
 
-    def setup_test_pipeline(self, filename):
-        self.test_data_files = [filename]
-        self.get_test_dataset([0])
+    def setup_test_pipeline(self, test_data_files):
+        self.test_data_files = test_data_files
+        tst_idxs = np.arange(len(test_data_files))
+        self.get_test_dataset(tst_idxs)
         print('setup_test_pipeline: Done')
 
     def setup_eval_pipeline(self, filename):
@@ -602,27 +603,14 @@ class SRCNN:
 
         self.reset_test_metrics()
 
-        for data0, data1, label in self.test_dataset:
-            ds = tf.data.Dataset.from_tensor_slices((data0, data1, label))
+        for data, label in self.test_dataset:
+            ds = tf.data.Dataset.from_tensor_slices((data, label))
             ds = ds.batch(BATCH_SIZE)
             for mini_batch_test in ds:
                 self.predict(mini_batch_test)
 
         print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
 
-        labels = np.concatenate(self.test_labels)
-        self.test_labels = labels
-
-        preds = np.concatenate(self.test_preds)
-        self.test_probs = preds
-
-        if NumClasses == 2:
-            preds = np.where(preds > 0.5, 1, 0)
-        else:
-            preds = np.argmax(preds, axis=1)
-
-        self.test_preds = preds
-
     def do_evaluate(self, nda_lr, param, ckpt_dir):
 
         ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
@@ -655,8 +643,10 @@ class SRCNN:
         self.build_evaluation()
         self.do_training()
 
-    def run_restore(self, filename, ckpt_dir):
-        self.setup_test_pipeline(filename)
+    def run_restore(self, directory, ckpt_dir):
+        valid_data_files = glob.glob(directory + 'data_valid*.npy')
+        self.num_data_samples = 1000
+        self.setup_test_pipeline(valid_data_files)
         self.build_model()
         self.build_training()
         self.build_evaluation()
-- 
GitLab