diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 36acfb27e57280754afe2e68eba0140d4d05d16f..c8d9ff479143d29ecc87c713bbe0fe2e267732bf 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1095,19 +1095,24 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, fli ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt.restore(ckpt_manager.latest_checkpoint) + probs_2d_s = [] + preds_2d_s = [] for flvl in [0, 1, 2, 3, 4]: nn.flight_level = flvl nn.do_evaluate() probs = nn.test_probs - if NumClasses == 2: - preds = np.where(probs > prob_thresh, 1, 0) - else: - preds = np.argmax(probs, axis=1) - preds_2d = preds.reshape((num_lines, num_elems)) - probs_2d = probs.reshape((num_lines, num_elems)) + if NumClasses == 2: + preds = np.where(probs > prob_thresh, 1, 0) + else: + preds = np.argmax(probs, axis=1) + preds_2d = preds.reshape((num_lines, num_elems)) + probs_2d = probs.reshape((num_lines, num_elems)) + + probs_2d_s.append(probs_2d) + preds_2d_s.append(preds_2d) - return preds_2d, probs_2d + return preds_2d_s, probs_2d_s if __name__ == "__main__":