From d23adae66fd284ceed3216a10cbdf677345a63fa Mon Sep 17 00:00:00 2001 From: rink <rink@ssec.wisc.edu> Date: Wed, 23 Sep 2020 09:57:20 -0500 Subject: [PATCH] add separate predict method --- modules/deeplearning/cloudheight.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index 9545cbb9..fe7c878b 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -666,6 +666,15 @@ class CloudHeightNN: self.test_loss(t_loss) self.test_accuracy(labels, pred) + def predict(self, mini_batch): + inputs = [mini_batch[0], mini_batch[1], mini_batch[3]] + labels = mini_batch[2] + pred = self.model(inputs, training=False) + t_loss = self.loss(labels, pred) + + self.test_loss(t_loss) + self.test_accuracy(labels, pred) + m = np.logical_and(labels >= 0.01, labels < 0.2) self.num_0 += np.sum(m) self.accuracy_0(labels[m], pred[m]) -- GitLab