diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 914771c7c4f5e88ce6ecac16237be056de9b98a4..78256ac3bab979139c348b3f5cc16167ff78d607 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -252,3 +252,13 @@ def gradient_boosting(x_train, y_train, x_test, y_test, n_estimators=100, max_de
 
     if saveModel:
         joblib.dump(gbm, '/Users/tomrink/icing_gbm.pkl')
+
+
+def infer(model, x_test, y_test, does_prob=True):
+    yhat = model.predict(x_test)
+    if does_prob:
+        yhat_prob = model.predict_proba(x_test)
+        metrics(y_test, yhat, y_pred_prob=yhat_prob)
+    else:
+        metrics(y_test, yhat)
+