diff --git a/modules/util/plot_cm.py b/modules/util/plot_cm.py index f7c52b697addc4beb26f4b052faab263341e6810..c821cee3d6b6612ada98b997cca309f307118709 100644 --- a/modules/util/plot_cm.py +++ b/modules/util/plot_cm.py @@ -8,8 +8,11 @@ from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt +def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): + cm = confusion_matrix(correct_labels, predict_labels) + -def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): +def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): ''' Parameters: correct_labels : These are your true classification categories. @@ -26,7 +29,6 @@ def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusi - Currently, some of the ticks dont line up due to rotations. ''' - cm = confusion_matrix(correct_labels, predict_labels) if normalize: cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis] cm = np.nan_to_num(cm, copy=True)