diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index d5fedef2135b08358b4a8343496c08d8624399e5..b0d64155d65fb4a125fcdfe0173ca5d2d90183b4 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -190,7 +190,7 @@ def decision_tree(x_train, y_train, x_test, y_test, criterion='entropy', max_dep return DT # Use this to plot the tree ----------------------------------------------------------- -# export_graphviz(DT, out_file='tree.dot', filled=True, feature_names=params) +# export_graphviz(DT, out_file='tree.dot', filled=True, class_names=['no icing', 'icing'], feature_names=params) # !dot -Tpng tree.dot -o tree.png