SKL Decision Trees
Notes on the decision tree implementation in scikit-learn.
For more details, see the official documentation.
Some points to note:
The module does not support categorical variables (and missing-value support for trees was only added in recent scikit-learn versions).
The cost of prediction is logarithmic in the number of data points used to train.
Can be unstable, with small variations in the data leading to a completely different tree.
Can create biased trees if some classes dominate. It is recommended to balance the dataset before fitting.
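As an alternative to resampling the data by hand, the classifier's class_weight parameter can reweight classes inversely to their frequency. A minimal sketch on a made-up imbalanced dataset (the data here is purely illustrative):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical imbalanced dataset: 90 samples of class 0, 10 of class 1
rng = np.random.RandomState(0)
X = np.vstack([rng.normal(0, 1, (90, 2)), rng.normal(1.5, 1, (10, 2))])
y = np.array([0] * 90 + [1] * 10)

# class_weight="balanced" weights each class inversely to its frequency,
# so the minority class is not drowned out during fitting
clf = DecisionTreeClassifier(class_weight="balanced", random_state=0)
clf = clf.fit(X, y)
```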
Classification
Very simple to use in the base case:
from sklearn import tree
X = [[0, 0], [1, 1]]
y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
clf.predict([[2, 2]])
clf.predict_proba([[2, 2]])  # predict probability of each class
Capable of binary and multiclass classification.
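Multiclass works with no extra configuration; predict_proba simply returns one column per class. A quick sketch on the three-class iris dataset:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()  # three classes: setosa, versicolor, virginica
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

proba = clf.predict_proba(iris.data[:1])
print(proba.shape)  # one probability per class: (1, 3)
```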
Plotting the Tree
You can plot the tree with plot_tree, export it to Graphviz with export_graphviz, or export it as text with export_text.
Check out the Graphviz export docs for colour options.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text, plot_tree, export_graphviz
import graphviz
iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)  # get a nice string repr of the tree
plot_tree(decision_tree) # get a funky tree diagram
# or use graphviz
dot_data = export_graphviz(decision_tree, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("iris")
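For a more readable diagram, plot_tree also accepts feature and class names and can colour nodes by majority class. A sketch (the Agg backend and the output filename iris_tree.png are my choices, not from the docs):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(8, 5))
plot_tree(
    decision_tree,
    feature_names=iris.feature_names,  # label splits with feature names
    class_names=iris.target_names,     # label leaves with class names
    filled=True,                       # colour nodes by majority class
    ax=ax,
)
fig.savefig("iris_tree.png")
```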