diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 4e618041f2d0079597a321fb9cc4a1a6bfbecd9b..93564c92223ae5ec4db32b88f497c8172759a9d8 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -10,6 +10,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier from sklearn.tree import DecisionTreeClassifier import sklearn.tree as tree +from sklearn.tree import export_graphviz def get_csv_as_dataframe(csv_file, reduce_frac=None): @@ -117,3 +118,7 @@ def decision_tree(x, y, max_depth=4): print('AUC: ', "{:.4f}".format(roc_auc_score(y_test, yhat_prob[:, 1]))) return DT + +# export_graphviz(DT, out_file='tree.dot', filled=True, feature_names=['cld_geo_thick', 'cld_temp_acha', 'conv_cloud_fraction', 'supercooled_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp']) +# !dot -Tpng tree.dot -o tree.png +