diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index dcbfa2de05509c4f54e88e6cddc5a1c78c25c9e7..9b5320380fd1ed70ae0658bf1d86cddf5a3922b3 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -10,10 +10,46 @@ from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.ensemble import RandomForestClassifier +import itertools import sklearn.tree as tree from sklearn.tree import export_graphviz +def plot_confusion_matrix(cm, classes, + normalize=False, + title='Confusion matrix', + cmap=plt.cm.Blues): + """ + This function prints and plots the confusion matrix. + Normalization can be applied by setting `normalize=True`. + """ + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + print("Normalized confusion matrix") + else: + print('Confusion matrix, without normalization') + + print(cm) + + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(title) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes, rotation=45) + plt.yticks(tick_marks, classes) + + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, format(cm[i, j], fmt), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black") + + plt.tight_layout() + plt.ylabel('True label') + plt.xlabel('Predicted label') + + def get_csv_as_dataframe(csv_file, reduce_frac=None): icing_df = pd.read_csv(csv_file) # Random selection of reduce_frac of the rows