Skip to content
Snippets Groups Projects
Commit f8e91300 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 31a4ecb1
Branches
No related tags found
No related merge requests found
......@@ -515,9 +515,7 @@ class SRCNN:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
def predict(self, mini_batch):
inputs = [mini_batch[0]]
labels = mini_batch[1]
def predict(self, inputs, labels):
pred = self.model([inputs], training=False)
t_loss = self.loss(labels, pred)
......@@ -659,7 +657,7 @@ class SRCNN:
ds = tf.data.Dataset.from_tensor_slices((data, label))
ds = ds.batch(BATCH_SIZE)
for mini_batch_test in ds:
self.predict(mini_batch_test)
self.predict(mini_batch_test[0], mini_batch_test[1])
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment