Skip to content
Snippets Groups Projects
Commit fbc47efc authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e90508d1
Branches
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ LOG_DEVICE_PLACEMENT = False ...@@ -14,7 +14,7 @@ LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 4 PROC_BATCH_SIZE = 4
PROC_BATCH_BUFFER_SIZE = 5000 PROC_BATCH_BUFFER_SIZE = 5000
NumClasses = 5 NumClasses = 2
if NumClasses == 2: if NumClasses == 2:
NumLogits = 1 NumLogits = 1
else: else:
...@@ -499,7 +499,6 @@ class SRCNN: ...@@ -499,7 +499,6 @@ class SRCNN:
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def train_step(self, inputs, labels): def train_step(self, inputs, labels):
labels = tf.squeeze(labels, axis=[3])
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
pred = self.model([inputs], training=True) pred = self.model([inputs], training=True)
loss = self.loss(labels, pred) loss = self.loss(labels, pred)
...@@ -519,7 +518,6 @@ class SRCNN: ...@@ -519,7 +518,6 @@ class SRCNN:
@tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)]) @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def test_step(self, inputs, labels): def test_step(self, inputs, labels):
labels = tf.squeeze(labels, axis=[3])
pred = self.model([inputs], training=False) pred = self.model([inputs], training=False)
t_loss = self.loss(labels, pred) t_loss = self.loss(labels, pred)
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment