From 9212f77a46838b72ac9b8b098a4f0b4912def9d6 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sun, 19 May 2024 10:31:46 -0500
Subject: [PATCH] snapshot...

---
 modules/machine_learning/classification.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 914771c7..78256ac3 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -252,3 +252,13 @@ def gradient_boosting(x_train, y_train, x_test, y_test, n_estimators=100, max_de
 
     if saveModel:
         joblib.dump(gbm, '/Users/tomrink/icing_gbm.pkl')
+
+
+def infer(model, x_test, y_test, does_prob=True):
+    yhat = model.predict(x_test)
+    if does_prob:
+        yhat_prob = model.predict_proba(x_test)
+        metrics(y_test, yhat, y_pred_prob=yhat_prob)
+    else:
+        metrics(y_test, yhat)
+
-- 
GitLab