From 53b5ac7f2644693e1a949d00a1b49d1bf17d4d7f Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Thu, 10 Feb 2022 13:10:17 -0600 Subject: [PATCH] snapshot...don't unnecessarily reload chkpt files --- modules/deeplearning/icing_cnn.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index dfe113af..15b3bafc 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -967,10 +967,10 @@ class IcingIntensityNN: def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5): - if ckpt_dir is not None: # if is None, this has been done already - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) - ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) - ckpt.restore(ckpt_manager.latest_checkpoint) + # if ckpt_dir is not None: # if is None, this has been done already + # ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) + # ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + # ckpt.restore(ckpt_manager.latest_checkpoint) self.reset_test_metrics() @@ -1115,13 +1115,19 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l probs_dct = {flvl: None for flvl in flight_levels} preds_dct = {flvl: None for flvl in flight_levels} + nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude) + nn.num_data_samples = num_tiles + nn.build_model() + nn.build_training() + nn.build_evaluation() + + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model) + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) + for flvl in flight_levels: - nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude) nn.flight_level = flvl nn.setup_eval_pipeline(data_dct, num_tiles) - nn.build_model() - nn.build_training() - nn.build_evaluation() nn.do_evaluate(ckpt_dir) probs = nn.test_probs -- GitLab