diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 9b5320380fd1ed70ae0658bf1d86cddf5a3922b3..9ab5ed3171333eecbfbbc891699c2a153899209a 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -88,8 +88,12 @@ def get_train_test_data(data_frame, standardize=True):
     return x, y
 
 
-def logistic_regression(x, y):
-    x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+def logistic_regression(x, y, x_test=None, y_test=None):
+    if x_test is None:
+        x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+    else:
+        x_train = x
+        y_train = y
 
     print('Train set:', x_train.shape, y_train.shape)
     print('Test set:', x_test.shape, y_test.shape)
@@ -220,8 +224,12 @@ def SVM(x, y, kernel='rbf'):
     print('F1: ', "{:.4f}".format(f1_score(y_test, yhat)))
 
 
-def random_forest(x, y, criterion='entropy', max_depth=4):
-    x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+def random_forest(x, y, x_test=None, y_test=None, criterion='entropy', max_depth=4):
+    if x_test is None:
+        x_train, x_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=4)
+    else:
+        x_train = x
+        y_train = y
 
     print('Train set:', x_train.shape, y_train.shape)
     print('Test set:', x_test.shape, y_test.shape)
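
A minimal sketch of how the widened signatures might be called, assuming the module is importable under the path shown in the diff and substituting synthetic data for the output of get_train_test_data(); the caller-side names (x_tr, x_te, y_tr, y_te) are illustrative only:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# Assumed import path, taken from the file path in the diff.
from modules.machine_learning.classification import logistic_regression, random_forest

# Toy data standing in for the features/labels normally produced by
# get_train_test_data(); purely illustrative.
x, y = make_classification(n_samples=500, n_features=8, random_state=4)

# Previous behaviour, still supported: the function makes its own 80/20 split.
logistic_regression(x, y)

# New behaviour: supply one externally created split so both models are
# evaluated on exactly the same test set.
x_tr, x_te, y_tr, y_te = train_test_split(x, y, test_size=0.2, random_state=4)
logistic_regression(x_tr, y_tr, x_test=x_te, y_test=y_te)
random_forest(x_tr, y_tr, x_test=x_te, y_test=y_te)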