diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py
index 3820f8beaa4d9a582457d846816074d0b1bb1ebd..10d32ca21954aeeeaaff8a516b584f0f2c3705a8 100644
--- a/modules/deeplearning/srcnn_l1b_l2.py
+++ b/modules/deeplearning/srcnn_l1b_l2.py
@@ -488,12 +488,10 @@ class SRCNN:
         self.train_loss = tf.keras.metrics.Mean(name='train_loss')
         self.test_loss = tf.keras.metrics.Mean(name='test_loss')
 
-    @tf.function
-    def train_step(self, mini_batch):
-        inputs = [mini_batch[0]]
-        labels = mini_batch[1]
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
+    def train_step(self, inputs, labels):
         with tf.GradientTape() as tape:
-            pred = self.model(inputs, training=True)
+            pred = self.model([inputs], training=True)
             loss = self.loss(labels, pred)
             total_loss = loss
             if len(self.model.losses) > 0:
@@ -509,11 +507,9 @@ class SRCNN:
 
         return loss
 
-    @tf.function
-    def test_step(self, mini_batch):
-        inputs = [mini_batch[0]]
-        labels = mini_batch[1]
-        pred = self.model(inputs, training=False)
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
+    def test_step(self, inputs, labels):
+        pred = self.model([inputs], training=False)
         t_loss = self.loss(labels, pred)
 
         self.test_loss(t_loss)
@@ -585,7 +581,7 @@ class SRCNN:
                 trn_ds = trn_ds.batch(BATCH_SIZE)
                 for mini_batch in trn_ds:
                     if self.learningRateSchedule is not None:
-                        loss = self.train_step(mini_batch)
+                        loss = self.train_step(mini_batch[0], mini_batch[1])
 
                     if (step % 100) == 0:
 
@@ -600,7 +596,7 @@ class SRCNN:
                             tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
                             tst_ds = tst_ds.batch(BATCH_SIZE)
                             for mini_batch_test in tst_ds:
-                                self.test_step(mini_batch_test)
+                                self.test_step(mini_batch_test[0], mini_batch_test[1])
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
@@ -629,7 +625,7 @@ class SRCNN:
                 ds = tf.data.Dataset.from_tensor_slices((data, label))
                 ds = ds.batch(BATCH_SIZE)
                 for mini_batch in ds:
-                    self.test_step(mini_batch)
+                    self.test_step(mini_batch[0], mini_batch[1])
 
             print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')