From 265f47432e80dd0618eb298a95cb49148d92c421 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 29 Apr 2024 16:20:36 -0500
Subject: [PATCH] snapshot...

---
 modules/machine_learning/classification.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 87de1486..506db5e3 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -197,8 +197,8 @@ def k_nearest_neighbors_all(x, y, k_s=10):
     plt.show()
 
 
-def decision_tree(x, y, criterion='entropy', max_depth=4):
-    x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+def decision_tree(x_train, y_train, x_test, y_test, criterion='entropy', max_depth=4):
+    # x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
     print('Train set:', x_train.shape,  y_train.shape)
     print('Test set:', x_test.shape,  y_test.shape)
 
@@ -211,6 +211,7 @@ def decision_tree(x, y, criterion='entropy', max_depth=4):
     yhat = DT.predict(x_test)
     yhat_prob = DT.predict_proba(x_test)
 
+    print(confusion_matrix(y_test, yhat, labels=[1, 0]))
     print('Accuracy:    ', "{:.4f}".format(accuracy_score(y_test, yhat)))
     print('Jaccard Idx: ', "{:.4f}".format(jaccard_score(y_test, yhat)))
     print('Precision:   ', "{:.4f}".format(precision_score(y_test, yhat)))
-- 
GitLab