Skip to content
Snippets Groups Projects
Commit e5a95c4a authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 9b594e17
Branches
No related tags found
No related merge requests found
......@@ -1080,29 +1080,28 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, s
return ice_lons, ice_lats, preds_2d, lons_2d, lats_2d, x_rad, y_rad
def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5):
def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, flight_levels=[0, 1, 2, 3, 4], prob_thresh=0.5):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
ckpt_dir = ckpt_dir_s[0]
nn = IcingIntensityNN()
nn.flight_level = flight_level
nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
nn.build_model()
nn.build_training()
nn.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
probs_2d_s = []
preds_2d_s = []
for flvl in [0, 1, 2, 3, 4]:
for flvl in flight_levels:
nn = IcingIntensityNN()
nn.flight_level = flvl
nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
nn.build_model()
nn.build_training()
nn.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
nn.do_evaluate()
probs = nn.test_probs
probs = nn.test_probs
if NumClasses == 2:
preds = np.where(probs > prob_thresh, 1, 0)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment