diff --git a/modules/deeplearning/espcn_l1b_l2.py b/modules/deeplearning/espcn_l1b_l2.py index bed757e2177b0411732d411ab127363ca726119c..cc80b373424b63e0acfb94bf71804ef9bd12d8bc 100644 --- a/modules/deeplearning/espcn_l1b_l2.py +++ b/modules/deeplearning/espcn_l1b_l2.py @@ -310,10 +310,12 @@ class ESPCN: dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8) self.eval_dataset = dataset - def setup_pipeline(self, train_data_files, test_data_files, num_train_samples): + def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples): self.train_data_files = train_data_files self.test_data_files = test_data_files + self.train_label_files = train_label_files + self.test_label_files = test_label_files trn_idxs = np.arange(len(train_data_files)) np.random.shuffle(trn_idxs) @@ -633,11 +635,13 @@ class ESPCN: def run(self, directory): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy') + train_label_files = glob.glob(directory+'label_train*.npy') + valid_label_files = glob.glob(directory+'label_valid*.npy') train_data_files.sort() valid_data_files.sort() - self.setup_pipeline(train_data_files, valid_data_files, 100000) + self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 100000) self.build_model() self.build_training() self.build_evaluation()