def get_area_under_perf_diagram(success_ratio_by_threshold, pod_by_threshold):
    """Computes area under performance diagram.

    T = number of binarization thresholds

    The curve is integrated with the trapezoidal rule after sorting the points
    by ascending success ratio.  Thresholds where either value is NaN are
    ignored.

    :param success_ratio_by_threshold: length-T numpy array (or sequence) of
        success ratios.
    :param pod_by_threshold: length-T numpy array (or sequence) of
        corresponding POD values.
    :return: area_under_curve: Area under performance diagram.  NaN if fewer
        than two thresholds have both a success ratio and a POD, because the
        area is undefined in that case.
    :raises ValueError: if the inputs are not one-dimensional or do not have
        the same length.

    Credit: https://github.com/thunderhoser/GewitterGefahr/blob/master/gewittergefahr/gg_utils/model_evaluation.py
    """

    success_ratios = np.asarray(success_ratio_by_threshold, dtype=float)
    pods = np.asarray(pod_by_threshold, dtype=float)

    # Without this check a longer POD array would be silently truncated by the
    # fancy indexing below, pairing values from the wrong thresholds.
    if success_ratios.ndim != 1 or pods.shape != success_ratios.shape:
        raise ValueError(
            "success_ratio_by_threshold and pod_by_threshold must be 1-D "
            "arrays of equal length; got shapes "
            f"{success_ratios.shape} and {pods.shape}."
        )

    # Sort by success ratio first (argsort places NaN success ratios last),
    # then drop every pair in which either value is NaN.
    sort_indices = np.argsort(success_ratios)
    success_ratios = success_ratios[sort_indices]
    pods = pods[sort_indices]

    real_flags = np.invert(
        np.logical_or(np.isnan(success_ratios), np.isnan(pods))
    )
    if np.count_nonzero(real_flags) < 2:
        return np.nan

    success_ratios = success_ratios[real_flags]
    pods = pods[real_flags]

    # Trapezoidal rule: sum of (bin width) * (mean POD over the bin).  This is
    # the same quantity sklearn.metrics.auc returns for monotonically
    # increasing x, without requiring scikit-learn for a one-line integral.
    bin_widths = np.diff(success_ratios)
    mean_pods = (pods[1:] + pods[:-1]) / 2.0
    return np.sum(bin_widths * mean_pods)
"""Plot monthly skill scores (AUPD, max CSI, best threshold) for the 2023 test set.

For each experiment, every ``month*/eval_results.pkl`` file is read, the
POD/FAR values stored per probability threshold are turned into a performance
curve, and the area under the performance diagram (AUPD), the maximum CSI and
the probability threshold at which that maximum occurs are collected.  The
results for all experiments are drawn on one figure and saved as a PNG.
"""

import glob
import pickle

import matplotlib.pyplot as plt
import numpy as np

import performance_diagrams

# goes east conus

months = [
    "Jan-Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov-Dec",
]

EVAL_ROOT = "/ships22/grain/probsevere/LC/tests/2019-22"

# (experiment directory, legend suffix, line colour)
EXPERIMENTS = [
    ("c0205ref1315", "w/ Ref10", "red"),
    ("c02051315", "w/o Ref10", "orange"),
    ("control", "control", "yellow"),
]


def _threshold_curves(scores):
    """Return (labels, pod, far) ordered by ascending probability threshold.

    Reads every ``pod<NN>_index0`` entry of ``scores`` together with its
    matching ``far<NN>_index0`` entry.  The arrays are padded with the end
    points (threshold 0: POD=1, FAR=1) and (threshold 100: POD=0, FAR=0).
    """
    labels = []
    pod = []
    far = []
    for key, val in scores.items():
        if "pod" in key and "index0" in key:
            prob = key.split("_")[0][3:5]
            labels.append(int(prob))
            pod.append(val)
            far.append(scores[f"far{prob}_index0"])

    labels = np.array(labels)
    pod = np.array(pod)
    far = np.array(far)

    # Apply the SAME ordering to labels as to pod/far.  Previously labels were
    # left in dict order, so the "best threshold" lookup could be misaligned
    # with the CSI values.
    order = labels.argsort()
    pod = np.concatenate([np.array([1]), pod[order], np.array([0])])
    far = np.concatenate([np.array([1]), far[order], np.array([0])])
    labels = np.concatenate([np.array([0]), labels[order], np.array([100])])
    return labels, pod, far


def _monthly_scores(test):
    """Return per-month lists (aupd, max_csi, best_thresh) for one experiment."""
    pattern = f"{EVAL_ROOT}/{test}/eval2023/month[0,1]*/eval_results.pkl"

    aupd = []
    max_csi = []
    best_thresh = []
    for path in sorted(glob.glob(pattern)):
        print(path)
        # NOTE: pickle can execute arbitrary code while loading; these files
        # are expected to be our own evaluation output, never untrusted input.
        with open(path, "rb") as pkl_file:
            scores = pickle.load(pkl_file)

        labels, pod, far = _threshold_curves(scores)
        sr = 1 - far

        aupd.append(performance_diagrams.get_area_under_perf_diagram(sr, pod))

        # The padded end points have SR=0 or POD=0; the resulting 1/0 = inf
        # correctly yields CSI=0 there, so silence the divide warning.
        with np.errstate(divide="ignore"):
            csi = 1 / (1 / sr + (1 / pod) - 1)
        max_csi.append(np.max(csi))
        best_thresh.append(labels[np.argmax(csi)])

    # Fail early with a clear message instead of an obscure matplotlib shape
    # error when the files on disk do not match the x-axis labels.
    if len(aupd) != len(months):
        raise ValueError(
            f"Expected {len(months)} monthly result files for '{test}' "
            f"but found {len(aupd)} matching {pattern}"
        )
    return aupd, max_csi, best_thresh


results = {test: _monthly_scores(test) for test, _, _ in EXPERIMENTS}

plt.figure(num=None, figsize=(10, 7), dpi=300, facecolor="w", edgecolor="k")
ax = plt.gca()
fs = 12
ms = 6
lw = 3
labs = []
lines = []
x_positions = range(len(months))

# Draw control first and the reference experiment last so that the latter
# ends up on top; AUPD lines are solid, max-CSI lines are dashed.
for score_index, score_name, linestyle in (
    (0, "AUPD", "solid"),
    (1, "max CSI", "dashed"),
):
    for test, suffix, colour in reversed(EXPERIMENTS):
        (line,) = ax.plot(
            x_positions,
            results[test][score_index],
            linestyle=linestyle,
            color=colour,
            lw=lw,
            marker="o",
            markersize=ms,
        )
        lines.append(line)
        labs.append(f"{score_name} - {suffix}")

# Best probability threshold (as a fraction) for the first experiment only.
ref_test, ref_suffix, _ = EXPERIMENTS[0]
(threshLine,) = ax.plot(
    x_positions,
    np.array(results[ref_test][2]) / 100.0,
    linestyle="solid",
    color="blue",
    lw=1,
    marker="None",
)
lines.append(threshLine)
labs.append(f"best prob. thresh. - {ref_suffix}")

ax.set_ylim(0.1, 0.8)
ax.set_xlim(-0.5, len(months) - 0.5)
ax.set_xticks(range(len(months)))
ax.set_xticklabels(months, fontsize=fs)
plt.xticks(fontsize=fs - 2, rotation=45)
ax.yaxis.grid(True, linestyle=":")
ax.xaxis.grid(True, linestyle=":")
ax.set_ylabel("Score or prob. thresh.", fontsize=fs)
ax.set_xlabel("Month", fontsize=fs)
plt.yticks(fontsize=fs)
plt.title("Scores for 2023 test set", fontsize=fs + 4)
# Legend is placed just outside the right edge of the axes.
ax.legend(lines, labs, fontsize=fs - 2, loc=(1.01, 0.65))

figname = f"{EVAL_ROOT}/{ref_test}/eval2023/goes_east_conus_by_month.png"
plt.savefig(figname, format="png", bbox_inches="tight")
print("Saved " + figname)