diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 914771c7c4f5e88ce6ecac16237be056de9b98a4..78256ac3bab979139c348b3f5cc16167ff78d607 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -252,3 +252,13 @@ def gradient_boosting(x_train, y_train, x_test, y_test, n_estimators=100, max_de if saveModel: joblib.dump(gbm, '/Users/tomrink/icing_gbm.pkl') + + +def infer(model, x_test, y_test, does_prob=True): + yhat = model.predict(x_test) + if does_prob: + yhat_prob = model.predict_proba(x_test) + metrics(y_test, yhat, y_pred_prob=yhat_prob) + else: + metrics(y_test, yhat) +