diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 02c30c6bfadc4c0ea3484cf6f9f27f282e5b3d4e..16c07c20d9b05553fba252a2b352a07a39bdba62 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -515,9 +515,7 @@ class SRCNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) - def predict(self, mini_batch): - inputs = [mini_batch[0]] - labels = mini_batch[1] + def predict(self, inputs, labels): pred = self.model([inputs], training=False) t_loss = self.loss(labels, pred) @@ -659,7 +657,7 @@ class SRCNN: ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: - self.predict(mini_batch_test) + self.predict(mini_batch_test[0], mini_batch_test[1]) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())