Skip to content
Snippets Groups Projects
Commit b75a359c authored by tomrink's avatar tomrink
Browse files

bug fix

parent 4a525bf6
No related branches found
No related tags found
No related merge requests found
...@@ -951,6 +951,9 @@ class IcingIntensityNN: ...@@ -951,6 +951,9 @@ class IcingIntensityNN:
def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path):
ckpt_dir_s = os.listdir(ckpt_dir_s_path) ckpt_dir_s = os.listdir(ckpt_dir_s_path)
cm_s = [] cm_s = []
prob_s = []
labels = None
for ckpt in ckpt_dir_s: for ckpt in ckpt_dir_s:
ckpt_dir = ckpt_dir_s_path + ckpt ckpt_dir = ckpt_dir_s_path + ckpt
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
...@@ -958,13 +961,20 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): ...@@ -958,13 +961,20 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path):
nn = IcingIntensityNN() nn = IcingIntensityNN()
nn.run_restore(filename_l1b, filename_l2, ckpt_dir) nn.run_restore(filename_l1b, filename_l2, ckpt_dir)
cm_s.append(tf.math.confusion_matrix(nn.test_labels.flatten(), nn.test_preds.flatten())) cm_s.append(tf.math.confusion_matrix(nn.test_labels.flatten(), nn.test_preds.flatten()))
prob_s.append(nn.test_probs.flatten())
if labels is None: # These should be the same
labels = nn.test_labels.flatten()
num = len(cm_s) num = len(cm_s)
cm_avg = cm_s[0] cm_avg = cm_s[0]
prob_avg = prob_s[0]
for k in range(num-1): for k in range(num-1):
cm_avg += cm_s[k+1] cm_avg += cm_s[k+1]
prob_avg += prob_s[k+1]
cm_avg /= num cm_avg /= num
prob_avg /= num
return cm_avg return labels, prob_avg, cm_avg
def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'): def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment