Alex's Notes

SKL Decision Trees

The implementation of Decision Trees in Scikit Learn

For more details see the official documentation

Some points to note:

  • The module does not support categorical variables, and missing-value support only arrived in recent scikit-learn releases (1.3+); older versions reject NaNs.

  • The cost of prediction is logarithmic in the number of data points used to train.

  • Can be unstable, with small variations in the data leading to a completely different tree.

  • Can create biased trees if some classes dominate. Recommended to balance the dataset prior to fitting.
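The instability point above can be seen directly: refitting on slightly different subsets of the same data can change which splits the tree picks. A minimal sketch (drops a different handful of iris rows per run and compares the printed trees; the exact degree of difference depends on the data):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

trees = []
for seed in (0, 1):
    # keep all but 5 rows, a different 5 each time
    idx = np.random.default_rng(seed).permutation(len(iris.data))[:-5]
    t = DecisionTreeClassifier(max_depth=2, random_state=0)
    t.fit(iris.data[idx], iris.target[idx])
    trees.append(export_text(t, feature_names=iris.feature_names))

for tr in trees:
    print(tr)  # compare the split thresholds between the two runs
```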

Classification

Very simple to use in the base case:

from sklearn import tree

X = [[0, 0], [1, 1]]
y = [0, 1]

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
clf.predict([[2, 2]])        # predicted class for a new sample
clf.predict_proba([[2, 2]])  # predict probability of each class

Capable of binary and multiclass classification.
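For the multiclass case nothing changes in the API: fit on labels with more than two values and predict_proba returns one column per class. A quick sketch on iris (three classes):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)  # three classes: setosa, versicolor, virginica
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

print(clf.predict(X[:1]))        # a single class label
print(clf.predict_proba(X[:1]))  # shape (1, 3): one probability per class
```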

Plotting the Tree

You can plot the tree with plot_tree, export to Graphviz with export_graphviz, or export to text with export_text.

Check out the graphviz export docs for colour options.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text, plot_tree, export_graphviz
import graphviz

iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r) # get a nice string repr of the tree

plot_tree(decision_tree) # get a funky tree diagram (rendered via matplotlib)

# or use graphviz
dot_data = export_graphviz(decision_tree, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("iris")
