Skip to content
Snippets Groups Projects
Commit a1767ed4 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 631dca6e
No related branches found
No related tags found
No related merge requests found
from textwrap import wrap from textwrap import wrap
import re import re
import os
import itertools import itertools
#import tfplot #import tfplot
import matplotlib import matplotlib
...@@ -8,18 +9,18 @@ from sklearn.metrics import confusion_matrix ...@@ -8,18 +9,18 @@ from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): def plot_confusion_matrix_values(correct_labels, predict_labels, labels, title='Confusion matrix', filename='confusion_matrix', normalize=False):
cm = confusion_matrix(correct_labels, predict_labels) cm = confusion_matrix(correct_labels, predict_labels)
def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False): def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False):
''' '''
Parameters: Parameters:
correct_labels : These are your true classification categories. correct_labels : These are your true classification categories.
predict_labels : These are you predicted classification categories predict_labels : These are you predicted classification categories
labels : This is a list of labels which will be used to display the axis labels labels : This is a list of labels which will be used to display the axis labels
title='Confusion matrix' : Title for your matrix title='Confusion matrix' : Title for your matrix
tensor_name = 'MyFigure/image' : Name for the output summay tensor filename = 'confusion_matrix' : Name for the output summay tensor
Returns: Returns:
summary: TensorFlow summary summary: TensorFlow summary
...@@ -59,3 +60,6 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'M ...@@ -59,3 +60,6 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', tensor_name = 'M
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black") ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")
fig.set_tight_layout(True) fig.set_tight_layout(True)
ImageDirAndName = os.path.join('/Users/tomrink', filename)
fig.savefig(ImageDirAndName)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment