From 5b94b998e7c152164b362e9c3ef74dfd5792c25a Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 25 Jan 2022 21:43:33 -0600
Subject: [PATCH] minor changes to run_evaluate_static

---
 modules/deeplearning/icing_cnn.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 426c07af..3b01ad2c 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -1103,7 +1103,7 @@ def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY',
     return ice_lons, ice_lats, preds_2d
 
 
-def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5,
+def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5,
                         flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
 
     ckpt_dir_s = os.listdir(ckpt_dir_s_path)
@@ -1112,13 +1112,13 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig
     if not use_flight_altitude:
         flight_levels = [0]
 
-    probs_2d_dct = {flvl: None for flvl in flight_levels}
-    preds_2d_dct = {flvl: None for flvl in flight_levels}
+    probs_dct = {flvl: None for flvl in flight_levels}
+    preds_dct = {flvl: None for flvl in flight_levels}
 
     for flvl in flight_levels:
         nn = IcingIntensityNN(day_night=day_night, use_flight_altitude=use_flight_altitude)
         nn.flight_level = flvl
-        nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
+        nn.setup_eval_pipeline(data_dct, num_tiles)
         nn.build_model()
         nn.build_training()
         nn.build_evaluation()
@@ -1129,13 +1129,11 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig
             preds = np.where(probs > prob_thresh, 1, 0)
         else:
             preds = np.argmax(probs, axis=1)
-        preds_2d = preds.reshape((num_lines, num_elems))
-        probs_2d = probs.reshape((num_lines, num_elems))
 
-        probs_2d_dct[flvl] = probs_2d
-        preds_2d_dct[flvl] = preds_2d
+        probs_dct[flvl] = probs
+        preds_dct[flvl] = preds
 
-    return preds_2d_dct, probs_2d_dct
+    return preds_dct, probs_dct
 
 
 if __name__ == "__main__":
-- 
GitLab