Skip to content
Snippets Groups Projects
Commit 53b5ac7f authored by tomrink's avatar tomrink
Browse files

snapshot...don't unnecessarily reload chkpt files

parent c4a81c96
No related branches found
No related tags found
No related merge requests found
...@@ -967,10 +967,10 @@ class IcingIntensityNN: ...@@ -967,10 +967,10 @@ class IcingIntensityNN:
def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5): def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
if ckpt_dir is not None: # if is None, this has been done already # if ckpt_dir is not None: # if is None, this has been done already
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) # ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) # ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint) # ckpt.restore(ckpt_manager.latest_checkpoint)
self.reset_test_metrics() self.reset_test_metrics()
...@@ -1115,13 +1115,19 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l ...@@ -1115,13 +1115,19 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
probs_dct = {flvl: None for flvl in flight_levels} probs_dct = {flvl: None for flvl in flight_levels}
preds_dct = {flvl: None for flvl in flight_levels} preds_dct = {flvl: None for flvl in flight_levels}
for flvl in flight_levels:
nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude) nn = IcingIntensityNN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude)
nn.flight_level = flvl nn.num_data_samples = num_tiles
nn.setup_eval_pipeline(data_dct, num_tiles)
nn.build_model() nn.build_model()
nn.build_training() nn.build_training()
nn.build_evaluation() nn.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
for flvl in flight_levels:
nn.flight_level = flvl
nn.setup_eval_pipeline(data_dct, num_tiles)
nn.do_evaluate(ckpt_dir) nn.do_evaluate(ckpt_dir)
probs = nn.test_probs probs = nn.test_probs
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment