Skip to content
Snippets Groups Projects
Commit 08060ba0 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e19a6cc5
Branches
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ def confusion_matrix_values(correct_labels, predict_labels):
return cm
def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False):
def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False, axis=1):
'''
Parameters:
correct_labels : These are your true classification categories.
......@@ -30,7 +30,10 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'conf
'''
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
if axis == 1:
cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis]
elif axis == 0:
cm = cm.astype('float') / cm.sum(axis=axis)[np.newaxis, :]
cm *= 100
cm = np.nan_to_num(cm, copy=True)
cm = cm.astype('int')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment