Skip to content
Snippets Groups Projects
Commit 8b5a0dc7 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent a8547977
Branches
No related tags found
No related merge requests found
...@@ -488,12 +488,10 @@ class SRCNN: ...@@ -488,12 +488,10 @@ class SRCNN:
self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.train_loss = tf.keras.metrics.Mean(name='train_loss')
self.test_loss = tf.keras.metrics.Mean(name='test_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss')
@tf.function @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def train_step(self, mini_batch): def train_step(self, inputs, labels):
inputs = [mini_batch[0]]
labels = mini_batch[1]
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
pred = self.model(inputs, training=True) pred = self.model([inputs], training=True)
loss = self.loss(labels, pred) loss = self.loss(labels, pred)
total_loss = loss total_loss = loss
if len(self.model.losses) > 0: if len(self.model.losses) > 0:
...@@ -509,11 +507,9 @@ class SRCNN: ...@@ -509,11 +507,9 @@ class SRCNN:
return loss return loss
@tf.function @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
def test_step(self, mini_batch): def test_step(self, inputs, labels):
inputs = [mini_batch[0]] pred = self.model([inputs], training=False)
labels = mini_batch[1]
pred = self.model(inputs, training=False)
t_loss = self.loss(labels, pred) t_loss = self.loss(labels, pred)
self.test_loss(t_loss) self.test_loss(t_loss)
...@@ -585,7 +581,7 @@ class SRCNN: ...@@ -585,7 +581,7 @@ class SRCNN:
trn_ds = trn_ds.batch(BATCH_SIZE) trn_ds = trn_ds.batch(BATCH_SIZE)
for mini_batch in trn_ds: for mini_batch in trn_ds:
if self.learningRateSchedule is not None: if self.learningRateSchedule is not None:
loss = self.train_step(mini_batch) loss = self.train_step(mini_batch[0], mini_batch[1])
if (step % 100) == 0: if (step % 100) == 0:
...@@ -600,7 +596,7 @@ class SRCNN: ...@@ -600,7 +596,7 @@ class SRCNN:
tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst)) tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
tst_ds = tst_ds.batch(BATCH_SIZE) tst_ds = tst_ds.batch(BATCH_SIZE)
for mini_batch_test in tst_ds: for mini_batch_test in tst_ds:
self.test_step(mini_batch_test) self.test_step(mini_batch_test[0], mini_batch_test[1])
with self.writer_valid.as_default(): with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step) tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
...@@ -629,7 +625,7 @@ class SRCNN: ...@@ -629,7 +625,7 @@ class SRCNN:
ds = tf.data.Dataset.from_tensor_slices((data, label)) ds = tf.data.Dataset.from_tensor_slices((data, label))
ds = ds.batch(BATCH_SIZE) ds = ds.batch(BATCH_SIZE)
for mini_batch in ds: for mini_batch in ds:
self.test_step(mini_batch) self.test_step(mini_batch[0], mini_batch[1])
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------') print('------------------------------------------------------')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment