From 21e9b5c5fde1261b10146a360e9704bbb800d092 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Mon, 5 Sep 2022 18:00:11 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/espcn_l1b_l2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/espcn_l1b_l2.py b/modules/deeplearning/espcn_l1b_l2.py index bed757e2..cc80b373 100644 --- a/modules/deeplearning/espcn_l1b_l2.py +++ b/modules/deeplearning/espcn_l1b_l2.py @@ -310,10 +310,12 @@ class ESPCN: dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8) self.eval_dataset = dataset - def setup_pipeline(self, train_data_files, test_data_files, num_train_samples): + def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples): self.train_data_files = train_data_files self.test_data_files = test_data_files + self.train_label_files = train_label_files + self.test_label_files = test_label_files trn_idxs = np.arange(len(train_data_files)) np.random.shuffle(trn_idxs) @@ -633,11 +635,13 @@ class ESPCN: def run(self, directory): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy') + train_label_files = glob.glob(directory+'label_train*.npy') + valid_label_files = glob.glob(directory+'label_valid*.npy') train_data_files.sort() valid_data_files.sort() - self.setup_pipeline(train_data_files, valid_data_files, 100000) + self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000) self.build_model() self.build_training() self.build_evaluation() -- GitLab