Skip to content
Snippets Groups Projects
Commit a1767ed4 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 631dca6e
No related branches found
No related tags found
No related merge requests found
from textwrap import wrap
import re
import os
import itertools
#import tfplot
import matplotlib
......@@ -8,18 +9,18 @@ from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', filename='confusion_matrix', normalize=False):
cm = confusion_matrix(correct_labels, predict_labels)
def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False):
'''
Parameters:
correct_labels : These are your true classification categories.
predict_labels : These are you predicted classification categories
labels : This is a list of labels which will be used to display the axis labels
title='Confusion matrix' : Title for your matrix
tensor_name = 'MyFigure/image' : Name for the output summay tensor
filename = 'confusion_matrix' : Name for the output summay tensor
Returns:
summary: TensorFlow summary
......@@ -59,3 +60,6 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'M
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")
fig.set_tight_layout(True)
ImageDirAndName = os.path.join('/Users/tomrink', filename)
fig.savefig(ImageDirAndName)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment