diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 5e61e92bdf68e418263c3a1cf513f0f21992d724..136de6fe0efb91c638d2d894cdf0bcd86cff802a 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -641,7 +641,7 @@ class SRCNN: return pred - def run(self, directory): + def run(self, directory, ckpt_dir=None): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy') @@ -649,7 +649,7 @@ class SRCNN: self.build_model() self.build_training() self.build_evaluation() - self.do_training() + self.do_training(ckpt_dir=ckpt_dir) def run_restore(self, directory, ckpt_dir): valid_data_files = glob.glob(directory + 'data_*.npy')