diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 7fa5f375e7bc19ef4ba0ce0507c25238164b8e2b..b7bc8a759e5e263ae898e5b636a797be1717a7f6 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1085,8 +1085,9 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, fli ckpt_dir_s = os.listdir(ckpt_dir_s_path) ckpt_dir = ckpt_dir_s[0] - probs_2d_s = [] - preds_2d_s = [] + probs_2d_dct = {flvl: None for flvl in flight_levels} + preds_2d_dct = {flvl: None for flvl in flight_levels} + for flvl in flight_levels: nn = IcingIntensityNN() nn.flight_level = flvl @@ -1109,10 +1110,10 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, fli preds_2d = preds.reshape((num_lines, num_elems)) probs_2d = probs.reshape((num_lines, num_elems)) - probs_2d_s.append(probs_2d) - preds_2d_s.append(preds_2d) + probs_2d_dct[flvl] = probs_2d + preds_2d_dct[flvl] = preds_2d - return preds_2d_s, probs_2d_s + return preds_2d_dct, probs_2d_dct if __name__ == "__main__":