From 9212f77a46838b72ac9b8b098a4f0b4912def9d6 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Sun, 19 May 2024 10:31:46 -0500 Subject: [PATCH] snapshot... --- modules/machine_learning/classification.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 914771c7..78256ac3 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -252,3 +252,13 @@ def gradient_boosting(x_train, y_train, x_test, y_test, n_estimators=100, max_de if saveModel: joblib.dump(gbm, '/Users/tomrink/icing_gbm.pkl') + + +def infer(model, x_test, y_test, does_prob=True): + yhat = model.predict(x_test) + if does_prob: + yhat_prob = model.predict_proba(x_test) + metrics(y_test, yhat, y_pred_prob=yhat_prob) + else: + metrics(y_test, yhat) + -- GitLab