diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 3820f8beaa4d9a582457d846816074d0b1bb1ebd..10d32ca21954aeeeaaff8a516b584f0f2c3705a8 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -488,12 +488,10 @@ class SRCNN: self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss') - @tf.function - def train_step(self, mini_batch): - inputs = [mini_batch[0]] - labels = mini_batch[1] + @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) + def train_step(self, inputs, labels): with tf.GradientTape() as tape: - pred = self.model(inputs, training=True) + pred = self.model([inputs], training=True) loss = self.loss(labels, pred) total_loss = loss if len(self.model.losses) > 0: @@ -509,11 +507,9 @@ class SRCNN: return loss - @tf.function - def test_step(self, mini_batch): - inputs = [mini_batch[0]] - labels = mini_batch[1] - pred = self.model(inputs, training=False) + @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) + def test_step(self, inputs, labels): + pred = self.model([inputs], training=False) t_loss = self.loss(labels, pred) self.test_loss(t_loss) @@ -585,7 +581,7 @@ class SRCNN: trn_ds = trn_ds.batch(BATCH_SIZE) for mini_batch in trn_ds: if self.learningRateSchedule is not None: - loss = self.train_step(mini_batch) + loss = self.train_step(mini_batch[0], mini_batch[1]) if (step % 100) == 0: @@ -600,7 +596,7 @@ class SRCNN: tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst)) tst_ds = tst_ds.batch(BATCH_SIZE) for mini_batch_test in tst_ds: - self.test_step(mini_batch_test) + self.test_step(mini_batch_test[0], mini_batch_test[1]) with self.writer_valid.as_default(): tf.summary.scalar('loss_val', self.test_loss.result(), step=step) @@ -629,7 +625,7 @@ class SRCNN: ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch in ds: - self.test_step(mini_batch) + self.test_step(mini_batch[0], mini_batch[1]) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('------------------------------------------------------')