Skip to content
Snippets Groups Projects
Commit 5b94b998 authored by tomrink's avatar tomrink
Browse files

minor changes to run_evaluate_static

parent 1aea9a0f
No related branches found
No related tags found
No related merge requests found
......@@ -1103,7 +1103,7 @@ def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY',
return ice_lons, ice_lats, preds_2d
def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5,
def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5,
flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
......@@ -1112,13 +1112,13 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig
if not use_flight_altitude:
flight_levels = [0]
probs_2d_dct = {flvl: None for flvl in flight_levels}
preds_2d_dct = {flvl: None for flvl in flight_levels}
probs_dct = {flvl: None for flvl in flight_levels}
preds_dct = {flvl: None for flvl in flight_levels}
for flvl in flight_levels:
nn = IcingIntensityNN(day_night=day_night, use_flight_altitude=use_flight_altitude)
nn.flight_level = flvl
nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
nn.setup_eval_pipeline(data_dct, num_tiles)
nn.build_model()
nn.build_training()
nn.build_evaluation()
......@@ -1129,13 +1129,11 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig
preds = np.where(probs > prob_thresh, 1, 0)
else:
preds = np.argmax(probs, axis=1)
preds_2d = preds.reshape((num_lines, num_elems))
probs_2d = probs.reshape((num_lines, num_elems))
probs_2d_dct[flvl] = probs_2d
preds_2d_dct[flvl] = preds_2d
probs_dct[flvl] = probs
preds_dct[flvl] = preds
return preds_2d_dct, probs_2d_dct
return preds_dct, probs_dct
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment