# plot_cm.py: helpers to compute and plot confusion matrices.
from textwrap import wrap
import re
tomrink's avatar
tomrink committed
import os
rink's avatar
rink committed
import itertools
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt


def confusion_matrix_values(correct_labels, predict_labels, labels=None):
    """Compute a confusion matrix with scikit-learn.

    Parameters:
        correct_labels : Ground-truth class labels.
        predict_labels : Predicted class labels, aligned element-wise with
                         ``correct_labels``.
        labels         : Optional list of class values that fixes the
                         row/column order of the matrix (and lets classes
                         that never occur in the data still get a row and
                         column). When None, the default, scikit-learn picks
                         the classes from the two label arrays, exactly as
                         before this parameter existed.

    Returns:
        The array returned by ``sklearn.metrics.confusion_matrix``.
    """
    return confusion_matrix(correct_labels, predict_labels, labels=labels)


def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename='confusion_matrix',
                          normalize=False, axis=1, out_dir='/Users/tomrink'):
    """Draw a confusion matrix as an annotated heat map and save it as an image.

    Rows of ``cm`` are plotted along the y axis ("True Label") and columns
    along the x axis ("Predicted"). Every cell is annotated with its value,
    with zeros drawn as '.'.

    Parameters:
        cm        : 2-D array (or nested sequence) of integers. Cells are
                    formatted with 'd', so float values are rejected unless
                    ``normalize`` is True.
        labels    : One display name per row/column of ``cm``. CamelCase
                    names are split into words and wrapped at 40 characters.
        title     : Title drawn bold and left-aligned above the matrix.
        filename  : Name of the image file; joined onto ``out_dir`` and
                    passed to ``Figure.savefig``.
        normalize : If True, convert counts to integer percentages before
                    plotting (see ``axis``).
        axis      : Normalization axis. 1 makes each row sum to ~100 and
                    0 makes each column sum to ~100. Ignored unless
                    ``normalize`` is True.
        out_dir   : Directory the image is written to. Defaults to the
                    previously hard-coded location so existing callers are
                    unaffected.

    Returns:
        The path the image was written to.

    Raises:
        ValueError: if ``normalize`` is True and ``axis`` is not 0 or 1.

    Notes:
        - Depending on the number of classes, the figure size and font sizes
          may need adjusting.
        - Some tick labels may not line up exactly because of the rotation.
    """
    if normalize:
        cm = _normalize_confusion_matrix(cm, axis)
    else:
        cm = np.asarray(cm)

    out_path = os.path.join(out_dir, filename)
    fig = plt.figure(figsize=(3, 3), dpi=320, facecolor='w', edgecolor='k')
    try:
        _draw_confusion_matrix(fig.add_subplot(1, 1, 1), cm, _display_class_names(labels), title)
        # Lay out after the title and annotations exist. This replaces
        # fig.set_tight_layout(True), which newer matplotlib discourages.
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure it creates open until it is closed.
        # Close it even if drawing or saving fails, so repeated calls do not
        # accumulate figures.
        plt.close(fig)
    return out_path


def _normalize_confusion_matrix(cm, axis):
    """Return ``cm`` as integer percentages of its row (axis=1) or column (axis=0) totals.

    Rows/columns whose total is zero come back as all zeros rather than NaN.
    """
    if axis not in (0, 1):
        raise ValueError('axis must be 0 or 1, got {!r}'.format(axis))
    counts = np.asarray(cm, dtype=float)
    totals = counts.sum(axis=axis, keepdims=True)
    # Divide only where the total is non-zero; other cells keep the 0 from `out`.
    fractions = np.divide(counts, totals, out=np.zeros_like(counts), where=totals != 0)
    # Round to the nearest percent instead of truncating (66.7 -> 67, not 66).
    return np.rint(fractions * 100).astype(int)


def _display_class_names(labels, width=40):
    """Turn class names into tick-label text.

    CamelCase names are split into words ("deepConvection" -> "deep
    Convection", "HTTPServer" -> "HTTP Server") and wrapped onto lines of at
    most ``width`` characters. Non-string labels are converted with ``str``.
    """
    spaced = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', str(name)) for name in labels]
    return ['\n'.join(wrap(name, width)) for name in spaced]


def _draw_confusion_matrix(ax, cm, classes, title):
    """Render ``cm`` on ``ax`` as a heat map with axis labels, tick labels and cell values."""
    ax.imshow(cm, cmap='Blues')

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=4, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    n_rows, n_cols = cm.shape
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        cell_text = format(cm[i, j], 'd') if cm[i, j] != 0 else '.'
        # imshow places column j at x=j and row i at y=i.
        ax.text(j, i, cell_text, horizontalalignment='center', verticalalignment='center',
                fontsize=6, color='black')

    # Equivalent to the former plt.title(...), which targeted the current axes.
    ax.set_title(title, loc='left', fontweight='bold', fontsize=6)