Skip to content
Snippets Groups Projects
Commit b75a359c authored by tomrink's avatar tomrink
Browse files

bug fix

parent 4a525bf6
Branches
No related tags found
No related merge requests found
...@@ -951,6 +951,9 @@ class IcingIntensityNN: ...@@ -951,6 +951,9 @@ class IcingIntensityNN:
def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path):
ckpt_dir_s = os.listdir(ckpt_dir_s_path) ckpt_dir_s = os.listdir(ckpt_dir_s_path)
cm_s = [] cm_s = []
prob_s = []
labels = None
for ckpt in ckpt_dir_s: for ckpt in ckpt_dir_s:
ckpt_dir = ckpt_dir_s_path + ckpt ckpt_dir = ckpt_dir_s_path + ckpt
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
...@@ -958,13 +961,20 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): ...@@ -958,13 +961,20 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path):
nn = IcingIntensityNN() nn = IcingIntensityNN()
nn.run_restore(filename_l1b, filename_l2, ckpt_dir) nn.run_restore(filename_l1b, filename_l2, ckpt_dir)
cm_s.append(tf.math.confusion_matrix(nn.test_labels.flatten(), nn.test_preds.flatten())) cm_s.append(tf.math.confusion_matrix(nn.test_labels.flatten(), nn.test_preds.flatten()))
prob_s.append(nn.test_probs.flatten())
if labels is None: # These should be the same
labels = nn.test_labels.flatten()
num = len(cm_s) num = len(cm_s)
cm_avg = cm_s[0] cm_avg = cm_s[0]
prob_avg = prob_s[0]
for k in range(num-1): for k in range(num-1):
cm_avg += cm_s[k+1] cm_avg += cm_s[k+1]
prob_avg += prob_s[k+1]
cm_avg /= num cm_avg /= num
prob_avg /= num
return cm_avg return labels, prob_avg, cm_avg
def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'): def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment