diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 7716228eb0f324ba2fe3ea1834a7accb3a25d604..24a6fc44cf5b4e90d4b73b032474e44897a0f6d5 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -638,11 +638,11 @@ class SRCNN: return pred - def run(self, directory, ckpt_dir=None): + def run(self, directory, ckpt_dir=None, num_data_samples=50000): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy') - self.setup_pipeline(train_data_files, valid_data_files, 50000) + self.setup_pipeline(train_data_files, valid_data_files, num_data_samples) self.build_model() self.build_training() self.build_evaluation()