diff --git a/modules/util/plot_cm.py b/modules/util/plot_cm.py index 67f09d1ba8727790c5a95cfd0e3878fb73037534..619d00cdf025ac9da2a6e6fb8365ad8984e112f2 100644 --- a/modules/util/plot_cm.py +++ b/modules/util/plot_cm.py @@ -12,7 +12,7 @@ def confusion_matrix_values(correct_labels, predict_labels): return cm -def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False): +def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False, axis=1): ''' Parameters: correct_labels : These are your true classification categories. @@ -30,7 +30,10 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'conf ''' if normalize: - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + if axis == 1: + cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis] + elif axis == 0: + cm = cm.astype('float') / cm.sum(axis=axis)[np.newaxis, :] cm *= 100 cm = np.nan_to_num(cm, copy=True) cm = cm.astype('int')