diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 9b5320380fd1ed70ae0658bf1d86cddf5a3922b3..9ab5ed3171333eecbfbbc891699c2a153899209a 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -88,8 +88,14 @@ def get_train_test_data(data_frame, standardize=True):
     return x, y
 
 
-def logistic_regression(x, y):
-    x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+def logistic_regression(x, y, x_test=None, y_test=None):
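+    # Use a 20% hold-out split when no external test set is supplied.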
+    if x_test is None:
+        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=4)
+    else:
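+        # External test set supplied: train on all of x/y, evaluate on x_test/y_test.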
+        x_train = x
+        y_train = y
     print('Train set:', x_train.shape,  y_train.shape)
     print('Test set:', x_test.shape,  y_test.shape)
 
@@ -220,8 +226,14 @@ def SVM(x, y, kernel='rbf'):
     print('F1:          ', "{:.4f}".format(f1_score(y_test, yhat)))
 
 
-def random_forest(x, y, criterion='entropy', max_depth=4):
-    x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+def random_forest(x, y, x_test=None, y_test=None, criterion='entropy', max_depth=4):
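+    # Use a 20% hold-out split when no external test set is supplied.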
+    if x_test is None:
+        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=4)
+    else:
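+        # External test set supplied: train on all of x/y, evaluate on x_test/y_test.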
+        x_train = x
+        y_train = y
     print('Train set:', x_train.shape,  y_train.shape)
     print('Test set:', x_test.shape,  y_test.shape)