Skip to content
Snippets Groups Projects
Commit 9477cc98 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 5805939a
Branches
No related tags found
No related merge requests found
......@@ -602,7 +602,7 @@ class SRCNN:
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def train_step(self, inputs, labels):
labels = tf.squeeze(labels)
labels = tf.squeeze(labels, axis=[3])
with tf.GradientTape() as tape:
pred = self.model([inputs], training=True)
loss = self.loss(labels, pred)
......@@ -622,7 +622,7 @@ class SRCNN:
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def test_step(self, inputs, labels):
labels = tf.squeeze(labels)
labels = tf.squeeze(labels, axis=[3])
pred = self.model([inputs], training=False)
t_loss = self.loss(labels, pred)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment