diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 426c07aff1f66cbb7a1ea00462ba3dd7c4530193..3b01ad2cb5d48d8b771a8abbaadd833b98d4f039 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1103,7 +1103,7 @@ def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY', return ice_lons, ice_lats, preds_2d -def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5, +def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', prob_thresh=0.5, flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False): ckpt_dir_s = os.listdir(ckpt_dir_s_path) @@ -1112,13 +1112,13 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig if not use_flight_altitude: flight_levels = [0] - probs_2d_dct = {flvl: None for flvl in flight_levels} - preds_2d_dct = {flvl: None for flvl in flight_levels} + probs_dct = {flvl: None for flvl in flight_levels} + preds_dct = {flvl: None for flvl in flight_levels} for flvl in flight_levels: nn = IcingIntensityNN(day_night=day_night, use_flight_altitude=use_flight_altitude) nn.flight_level = flvl - nn.setup_eval_pipeline(data_dct, num_lines * num_elems) + nn.setup_eval_pipeline(data_dct, num_tiles) nn.build_model() nn.build_training() nn.build_evaluation() @@ -1129,13 +1129,11 @@ def run_evaluate_static(data_dct, num_lines, num_elems, ckpt_dir_s_path, day_nig preds = np.where(probs > prob_thresh, 1, 0) else: preds = np.argmax(probs, axis=1) - preds_2d = preds.reshape((num_lines, num_elems)) - probs_2d = probs.reshape((num_lines, num_elems)) - probs_2d_dct[flvl] = probs_2d - preds_2d_dct[flvl] = preds_2d + probs_dct[flvl] = probs + preds_dct[flvl] = preds - return preds_2d_dct, probs_2d_dct + return preds_dct, probs_dct if __name__ == "__main__":