diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py
index 5e61e92bdf68e418263c3a1cf513f0f21992d724..136de6fe0efb91c638d2d894cdf0bcd86cff802a 100644
--- a/modules/deeplearning/srcnn_l1b_l2.py
+++ b/modules/deeplearning/srcnn_l1b_l2.py
@@ -641,7 +641,7 @@ class SRCNN:
 
         return pred
 
-    def run(self, directory):
+    def run(self, directory, ckpt_dir=None):
         train_data_files = glob.glob(directory+'data_train*.npy')
         valid_data_files = glob.glob(directory+'data_valid*.npy')
 
@@ -649,7 +649,7 @@ class SRCNN:
         self.build_model()
         self.build_training()
         self.build_evaluation()
-        self.do_training()
+        self.do_training(ckpt_dir=ckpt_dir)
 
     def run_restore(self, directory, ckpt_dir):
         valid_data_files = glob.glob(directory + 'data_*.npy')