diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index a664f450c4530511f291719dd4240bdce420f87c..98493282260142563ddfa7e7cb00b6ad0f73123d 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1103,7 +1103,7 @@ def run_evaluate_static(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY', flig def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_night='DAY', flight_levels=[0, 1, 2, 3, 4], prob_thresh=0.5): ckpt_dir_s = os.listdir(ckpt_dir_s_path) - ckpt_dir = ckpt_dir_s[0] + ckpt_dir = ckpt_dir_s_path + ckpt_dir_s[0] probs_2d_dct = {flvl: None for flvl in flight_levels} preds_2d_dct = {flvl: None for flvl in flight_levels} @@ -1115,12 +1115,7 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, day nn.build_model() nn.build_training() nn.build_evaluation() - - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model) - ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) - ckpt.restore(ckpt_manager.latest_checkpoint) - - nn.do_evaluate() + nn.do_evaluate(ckpt_dir) probs = nn.test_probs if NumClasses == 2: