diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index 7eb9e21b4f52e20af28ac4cc472a6de13755b95b..a03d64df77cd450729b493ed22ee4d2e73f97cc5 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -326,9 +326,10 @@ class ESPCN: print('num test samples: ', tst_idxs.shape[0]) print('setup_pipeline: Done') - def setup_test_pipeline(self, filename): - self.test_data_files = [filename] - self.get_test_dataset([0]) + def setup_test_pipeline(self, test_data_files): + self.test_data_files = test_data_files + tst_idxs = np.arange(len(test_data_files)) + self.get_test_dataset(tst_idxs) print('setup_test_pipeline: Done') def setup_eval_pipeline(self, filename): @@ -658,9 +659,10 @@ class ESPCN: self.build_evaluation() self.do_training() - def run_restore(self, filename, ckpt_dir): + def run_restore(self, directory, ckpt_dir): + valid_data_files = glob.glob(directory + 'data_valid*.npy') self.num_data_samples = 1000 - self.setup_test_pipeline(filename) + self.setup_test_pipeline(valid_data_files) self.build_model() self.build_training() self.build_evaluation()