diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 4e618041f2d0079597a321fb9cc4a1a6bfbecd9b..93564c92223ae5ec4db32b88f497c8172759a9d8 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -10,6 +10,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.tree import DecisionTreeClassifier
 import sklearn.tree as tree
+from sklearn.tree import export_graphviz
 
 
 def get_csv_as_dataframe(csv_file, reduce_frac=None):
@@ -117,3 +118,7 @@ def decision_tree(x, y, max_depth=4):
     print('AUC:         ', "{:.4f}".format(roc_auc_score(y_test, yhat_prob[:, 1])))
 
     return DT
+
+# export_graphviz(DT, out_file='tree.dot', filled=True, feature_names=['cld_geo_thick', 'cld_temp_acha', 'conv_cloud_fraction', 'supercooled_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp'])
+# !dot -Tpng tree.dot -o tree.png
+