Skip to content
Snippets Groups Projects
Commit fe16edde authored by tomrink's avatar tomrink
Browse files

snapshot...

parent adc02cea
No related branches found
No related tags found
No related merge requests found
...@@ -1095,19 +1095,24 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, fli ...@@ -1095,19 +1095,24 @@ def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, fli
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint) ckpt.restore(ckpt_manager.latest_checkpoint)
probs_2d_s = []
preds_2d_s = []
for flvl in [0, 1, 2, 3, 4]: for flvl in [0, 1, 2, 3, 4]:
nn.flight_level = flvl nn.flight_level = flvl
nn.do_evaluate() nn.do_evaluate()
probs = nn.test_probs probs = nn.test_probs
if NumClasses == 2: if NumClasses == 2:
preds = np.where(probs > prob_thresh, 1, 0) preds = np.where(probs > prob_thresh, 1, 0)
else: else:
preds = np.argmax(probs, axis=1) preds = np.argmax(probs, axis=1)
preds_2d = preds.reshape((num_lines, num_elems)) preds_2d = preds.reshape((num_lines, num_elems))
probs_2d = probs.reshape((num_lines, num_elems)) probs_2d = probs.reshape((num_lines, num_elems))
probs_2d_s.append(probs_2d)
preds_2d_s.append(preds_2d)
return preds_2d, probs_2d return preds_2d_s, probs_2d_s
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment