From 412b207c768d543d91286db85a608aedda30c6a4 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 22 Aug 2022 09:51:18 -0500 Subject: [PATCH] snapshot.. --- modules/deeplearning/srcnn.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/modules/deeplearning/srcnn.py b/modules/deeplearning/srcnn.py index 30f95980..8e10dad5 100644 --- a/modules/deeplearning/srcnn.py +++ b/modules/deeplearning/srcnn.py @@ -327,9 +327,10 @@ class SRCNN: print('num test samples: ', tst_idxs.shape[0]) print('setup_pipeline: Done') - def setup_test_pipeline(self, filename): - self.test_data_files = [filename] - self.get_test_dataset([0]) + def setup_test_pipeline(self, test_data_files): + self.test_data_files = test_data_files + tst_idxs = np.arange(len(test_data_files)) + self.get_test_dataset(tst_idxs) print('setup_test_pipeline: Done') def setup_eval_pipeline(self, filename): @@ -602,27 +603,14 @@ class SRCNN: self.reset_test_metrics() - for data0, data1, label in self.test_dataset: - ds = tf.data.Dataset.from_tensor_slices((data0, data1, label)) + for data, label in self.test_dataset: + ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: self.predict(mini_batch_test) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) - labels = np.concatenate(self.test_labels) - self.test_labels = labels - - preds = np.concatenate(self.test_preds) - self.test_probs = preds - - if NumClasses == 2: - preds = np.where(preds > 0.5, 1, 0) - else: - preds = np.argmax(preds, axis=1) - - self.test_preds = preds - def do_evaluate(self, nda_lr, param, ckpt_dir): ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) @@ -655,8 +643,10 @@ class SRCNN: self.build_evaluation() self.do_training() - def run_restore(self, filename, ckpt_dir): - self.setup_test_pipeline(filename) + def run_restore(self, directory, ckpt_dir): + valid_data_files = glob.glob(directory + 'data_valid*.npy') + self.num_data_samples = 1000 + self.setup_test_pipeline(valid_data_files) self.build_model() self.build_training() self.build_evaluation() -- GitLab