diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index e8358bfcee46cd91de469b65a7a12724887738ad..4e618041f2d0079597a321fb9cc4a1a6bfbecd9b 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -4,7 +4,7 @@ import numpy as np import scipy.optimize as opt from sklearn import preprocessing import matplotlib.pyplot as plt -from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, jaccard_score, f1_score, precision_score, recall_score, roc_auc_score +from sklearn.metrics import confusion_matrix, accuracy_score, jaccard_score, f1_score, precision_score, recall_score, roc_auc_score from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier @@ -114,4 +114,6 @@ def decision_tree(x, y, max_depth=4): print('Precision: ', "{:.4f}".format(precision_score(y_test, yhat))) print('Recall: ', "{:.4f}".format(recall_score(y_test, yhat))) print('F1: ', "{:.4f}".format(f1_score(y_test, yhat))) - print('AUC: ', "{:.4f}".format(roc_auc_score(y_test, yhat_prob[:, 1]))) \ No newline at end of file + print('AUC: ', "{:.4f}".format(roc_auc_score(y_test, yhat_prob[:, 1]))) + + return DT