Skip to content
Snippets Groups Projects
Commit 168c6d9e authored by tomrink's avatar tomrink
Browse files

minor

parent 15e3cbbd
No related branches found
No related tags found
No related merge requests found
...@@ -8,8 +8,11 @@ from sklearn.metrics import confusion_matrix ...@@ -8,8 +8,11 @@ from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
cm = confusion_matrix(correct_labels, predict_labels)
def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
''' '''
Parameters: Parameters:
correct_labels : These are your true classification categories. correct_labels : These are your true classification categories.
...@@ -26,7 +29,6 @@ def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusi ...@@ -26,7 +29,6 @@ def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusi
- Currently, some of the ticks dont line up due to rotations. - Currently, some of the ticks dont line up due to rotations.
''' '''
cm = confusion_matrix(correct_labels, predict_labels)
if normalize: if normalize:
cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis] cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis]
cm = np.nan_to_num(cm, copy=True) cm = np.nan_to_num(cm, copy=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment