From 3245acdf564e247133fb90cd5e19efc32f5f1352 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 27 Oct 2021 10:44:25 -0500
Subject: [PATCH] new methods for running predictions

---
 modules/deeplearning/icing_cnn.py | 43 +++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index d8cc5523..afb0cf37 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -1,6 +1,6 @@
 import tensorflow as tf
 import tensorflow_addons as tfa
-from util.setup import logdir, modeldir, cachepath, now, ewa_varsdir
+from util.setup import logdir, modeldir, cachepath, now
 from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict
 from util.geos_nav import get_navigation
 
@@ -939,12 +939,14 @@ class IcingIntensityNN:
 
         self.test_preds = preds
 
-    def do_evaluate(self, ckpt_dir, prob_thresh=0.5):
+    def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
 
-        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
-        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+        if ckpt_dir is not None:  # if is None, this has been done already
+            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
+            ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+            ckpt.restore(ckpt_manager.latest_checkpoint)
 
-        ckpt.restore(ckpt_manager.latest_checkpoint)
+        self.reset_test_metrics()
 
         pred_s = []
         for data in self.eval_dataset:
@@ -1079,6 +1081,37 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, s
     return ice_lons, ice_lats, preds_2d, lons_2d, lats_2d, x_rad, y_rad
 
 
+def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5):
+
+    ckpt_dir_s = os.listdir(ckpt_dir_s_path)
+    ckpt_dir = ckpt_dir_s[0]
+
+    nn = IcingIntensityNN()
+    nn.flight_level = flight_level
+    nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
+    nn.build_model()
+    nn.build_training()
+    nn.build_evaluation()
+
+    ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
+    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+    ckpt.restore(ckpt_manager.latest_checkpoint)
+
+    for flvl in [0, 1, 2, 3, 4]:
+        nn.flight_level = flvl
+        nn.do_evaluate()
+        probs = nn.test_probs
+
+    if NumClasses == 2:
+        preds = np.where(probs > prob_thresh, 1, 0)
+    else:
+        preds = np.argmax(probs, axis=1)
+    preds_2d = preds.reshape((num_lines, num_elems))
+    probs_2d = probs.reshape((num_lines, num_elems))
+
+    return preds_2d, probs_2d
+
+
 if __name__ == "__main__":
     nn = IcingIntensityNN()
     nn.run('matchup_filename')
-- 
GitLab