From 53b5ac7f2644693e1a949d00a1b49d1bf17d4d7f Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 10 Feb 2022 13:10:17 -0600
Subject: [PATCH] snapshot...don't unnecessarily reload chkpt files

---
 modules/deeplearning/icing_cnn.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index dfe113af..15b3bafc 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -967,10 +967,10 @@ class IcingIntensityNN:
 
     def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
 
-        if ckpt_dir is not None:  # if is None, this has been done already
-            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
-            ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
-            ckpt.restore(ckpt_manager.latest_checkpoint)
+        # if ckpt_dir is not None:  # if is None, this has been done already
+        #     ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
+        #     ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+        #     ckpt.restore(ckpt_manager.latest_checkpoint)
 
         self.reset_test_metrics()
 
@@ -1115,13 +1115,19 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
     probs_dct = {flvl: None for flvl in flight_levels}
     preds_dct = {flvl: None for flvl in flight_levels}
 
+    nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude)
+    nn.num_data_samples = num_tiles
+    nn.build_model()
+    nn.build_training()
+    nn.build_evaluation()
+
+    ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
+    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+    ckpt.restore(ckpt_manager.latest_checkpoint)
+
     for flvl in flight_levels:
-        nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude)
         nn.flight_level = flvl
         nn.setup_eval_pipeline(data_dct, num_tiles)
-        nn.build_model()
-        nn.build_training()
-        nn.build_evaluation()
         nn.do_evaluate(ckpt_dir)
 
         probs = nn.test_probs
-- 
GitLab