diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index fe7c878bebf6c25d088c67c769dfd30fbc1f0666..0210c863c46ab5dfc6863a802725d11485235755 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -637,7 +637,7 @@ class CloudHeightNN: self.variable_averages = tf.train.ExponentialMovingAverage(0.999, self.global_step) self.variable_averages.apply(self.model.trainable_variables) - # @tf.function + @tf.function def train_step(self, mini_batch): inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] labels = mini_batch[2] @@ -656,7 +656,7 @@ class CloudHeightNN: return loss - # @tf.function + @tf.function def test_step(self, mini_batch): inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] labels = mini_batch[2] @@ -806,7 +806,7 @@ class CloudHeightNN: ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: - self.test_step(mini_batch_test) + self.predict(mini_batch_test) print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result()) print('acc_0', self.num_0, self.accuracy_0.result()) print('acc_1', self.num_1, self.accuracy_1.result())