diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index 9545cbb9e4afa1d4493c5975c533cef071de3005..fe7c878bebf6c25d088c67c769dfd30fbc1f0666 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -666,6 +666,15 @@ class CloudHeightNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) + def predict(self, mini_batch): + inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] + labels = mini_batch[2] + pred = self.model(inputs, training=False) + t_loss = self.loss(labels, pred) + + self.test_loss(t_loss) + self.test_accuracy(labels, pred) + m = np.logical_and(labels >= 0.01, labels < 0.2) self.num_0 += np.sum(m) self.accuracy_0(labels[m], pred[m])