From abe9ab432e7723387d38a28b1ac53007e3cd80d6 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 17 May 2021 14:59:56 -0500 Subject: [PATCH] minor... --- modules/deeplearning/icing_cnn.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 9b3035a0..52b60959 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -297,6 +297,21 @@ class IcingIntensityNN: print('num test samples: ', tst_idxs.shape[0]) print('setup_pipeline: Done') + def setup_test_pipeline(self, filename, seed=None): + self.filename_tst = filename + self.h5f_tst = h5py.File(filename, 'r') + + time = self.h5f_tst['time'] + tst_idxs = np.arange(time.shape[0]) + if seed is not None: + np.random.seed(seed) + np.random.shuffle(tst_idxs) + + self.get_test_dataset(tst_idxs) + + print('num test samples: ', tst_idxs.shape[0]) + print('setup_test_pipeline: Done') + def build_1d_cnn(self): print('build_1d_cnn') # padding = 'VALID' @@ -667,7 +682,7 @@ class IcingIntensityNN: self.test_labels = labels self.test_preds = preds - def run(self, filename_trn, filename_tst, filename_l1b=None): + def run(self, filename_trn, filename_tst): with tf.device('/device:GPU:'+str(self.gpu_device)): self.setup_pipeline(filename_trn, filename_tst) self.build_model() @@ -675,8 +690,8 @@ class IcingIntensityNN: self.build_evaluation() self.do_training() - def run_restore(self, filename, ckpt_dir): - self.setup_pipeline(filename) + def run_restore(self, filename_tst, ckpt_dir): + self.setup_test_pipeline(filename_tst) self.build_model() self.build_training() self.build_evaluation() -- GitLab