Skip to content
Snippets Groups Projects
Commit 63a9e447 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 31f9d640
No related branches found
No related tags found
No related merge requests found
...@@ -973,9 +973,7 @@ class IcingIntensityFCN: ...@@ -973,9 +973,7 @@ class IcingIntensityFCN:
pred_s = [] pred_s = []
for data in self.eval_dataset: for data in self.eval_dataset:
print(data[0].shape, data[1].shape)
pred = self.model([data]) pred = self.model([data])
print(pred.shape, np.histogram(pred.numpy()))
pred_s.append(pred) pred_s.append(pred)
preds = np.concatenate(pred_s) preds = np.concatenate(pred_s)
...@@ -1122,13 +1120,14 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l ...@@ -1122,13 +1120,14 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
for flvl in flight_levels: for flvl in flight_levels:
nn.flight_level = flvl nn.flight_level = flvl
nn.setup_eval_pipeline(data_dct, num_tiles) nn.setup_eval_pipeline(data_dct, num_tiles)
nn.do_evaluate(ckpt_dir) nn.do_evaluate(prob_thresh=prob_thresh)
probs = nn.test_probs probs = nn.test_probs
if NumClasses == 2: preds = nn.test_preds
preds = np.where(probs > prob_thresh, 1, 0) # if NumClasses == 2:
else: # preds = np.where(probs > prob_thresh, 1, 0)
preds = np.argmax(probs, axis=1) # else:
# preds = np.argmax(probs, axis=1)
probs_dct[flvl] = probs probs_dct[flvl] = probs
preds_dct[flvl] = preds preds_dct[flvl] = preds
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment