Decision Trees in Machine Learning

Decision Trees are versatile Machine Learning algorithms that can perform both classification and regression tasks, and even multi-output tasks. They are powerful algorithms, capable of fitting complex datasets.

Decision trees are also the fundamental components of Random Forests, which are among the most powerful Machine Learning algorithms available today.

In this article, I will start by discussing how to train, visualize, and make predictions with Decision Trees. Then I will go through the CART training algorithm used by Scikit-Learn, and I will discuss how to regularize trees and use them for regression tasks.


Training and Visualizing Decision Trees

To understand Decision Trees, let’s build one and take a look at how it makes predictions. The following code trains a DecisionTreeClassifier on the iris dataset:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=2,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=42,
            splitter='best')

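You can visualize the trained Decision Tree by using the export_graphviz() function to produce a graph definition in Graphviz’s DOT format, and then rendering it with the Source class from the graphviz Python package:

from graphviz import Source
from sklearn.tree import export_graphviz

# With out_file=None, export_graphviz() returns the DOT source as a string
dot_data = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)

Source(dot_data)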
Iris Decision Tree
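If Graphviz is not installed, Scikit-Learn’s export_text() function (available since version 0.21) prints the same tree as indented text rules. This is a minimal sketch that retrains the same tree so it runs on its own; the variable name tree_rules is just an illustrative choice:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# export_text() returns the learned rules as an indented plain-text string
tree_rules = export_text(tree_clf, feature_names=list(iris.feature_names[2:]))
print(tree_rules)
```

Each `|---` line is one test on a feature; the lines that contain "class:" are the leaves.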

Making Predictions

Let’s see how the tree represented in the above figure makes predictions. Suppose you find an iris flower and you want to classify it. You start at the root node (depth 0, at the top): this node asks whether the flower’s petal length is smaller than 2.45 cm. If it is, you move down to the root’s left child node (depth 1, left). In this case, it is a leaf node (it has no child nodes), so it does not ask any questions: you simply look at the predicted class for that node, and the Decision Tree predicts that your flower is an Iris setosa.

Now suppose you find another flower, and this time the petal length is greater than 2.45 cm. You must move down to the root’s right child node (depth 1, right), which is not a leaf node, so the node asks another question: is the petal width smaller than 1.75 cm? If it is, then your flower is most likely an Iris versicolor (depth 2, left). If not, it is likely an Iris virginica (depth 2, right). It’s really that simple.
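The walk described above can be reproduced by hand from the fitted tree’s low-level arrays (tree_.feature, tree_.threshold, tree_.children_left, tree_.children_right and tree_.value). The sketch below is illustrative rather than how Scikit-Learn implements predict(); walk_tree() is a hypothetical helper, and it assumes, as Scikit-Learn does, that an instance goes left when its feature value is less than or equal to the node’s threshold:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)

def walk_tree(clf, x):
    """Hypothetical helper: follow one sample from the root down to a leaf."""
    t = clf.tree_
    # Scikit-Learn compares float32 features internally, so mirror that here
    x = np.asarray(x, dtype=np.float32)
    node = 0  # node 0 is the root
    while t.children_left[node] != -1:  # -1 means "no child", i.e. a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]   # condition true: go left
        else:
            node = t.children_right[node]  # condition false: go right
    # t.value[node] holds the leaf's class distribution; return the majority class
    return clf.classes_[np.argmax(t.value[node])]

for flower in ([1.4, 0.2], [5.0, 1.5], [5.5, 2.1]):
    print(flower, "->", iris.target_names[walk_tree(tree_clf, flower)])
```

The result should always match tree_clf.predict(), because it follows exactly the same sequence of yes/no questions.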

One of the many qualities of Decision Trees is that they require very little data preparation. In fact, they don’t require feature scaling or centering at all.
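One way to see why scaling is unnecessary: each node compares a single feature with a threshold, so standardizing a feature only moves the threshold, not which training instances end up on each side. A minimal sketch, assuming StandardScaler as the example transformation, trains the same tree on raw and on standardized petal features and compares their predictions:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

# Standardize the features (center them and scale them to unit variance)
X_scaled = StandardScaler().fit_transform(X)

raw_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)
scaled_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_scaled, y)

# The root thresholds differ (they live in different units)...
print("root thresholds:", raw_tree.tree_.threshold[0], scaled_tree.tree_.threshold[0])
# ...but every training instance is routed to the same leaf, so predictions agree
same = np.array_equal(raw_tree.predict(X), scaled_tree.predict(X_scaled))
print("identical training-set predictions:", same)
```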

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Minimal helper to save figures (the "images" folder name is an arbitrary choice)
IMAGES_PATH = "images"
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, fig_extension="png"):
    plt.tight_layout()
    plt.savefig(os.path.join(IMAGES_PATH, fig_id + "." + fig_extension))

def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):
    # Predict the class of every point on a 100 x 100 grid covering the plot area
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    # Shade each predicted region
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    if plot_training:
        # Overlay the training instances, one marker style per class
        plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris setosa")
        plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris versicolor")
        plt.plot(X[:, 0][y==2], X[:, 1][y==2], "g^", label="Iris virginica")
    plt.axis(axes)
    if iris:
        plt.xlabel("Petal length", fontsize=14)
        plt.ylabel("Petal width", fontsize=14)
    else:
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
    if legend:
        plt.legend(loc="lower right", fontsize=14)

plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf, X, y)
# Split thresholds: root (solid), depth-1 node (dashed), depth-2 nodes if max_depth=3 (dotted)
plt.plot([2.45, 2.45], [0, 3], "k-", linewidth=2)
plt.plot([2.45, 7.5], [1.75, 1.75], "k--", linewidth=2)
plt.plot([4.95, 4.95], [0, 1.75], "k:", linewidth=2)
plt.plot([4.85, 4.85], [1.75, 3], "k:", linewidth=2)
plt.text(1.40, 1.0, "Depth=0", fontsize=15)
plt.text(3.2, 1.80, "Depth=1", fontsize=13)
plt.text(4.05, 0.5, "(Depth=2)", fontsize=11)

save_fig("decision_tree_decision_boundaries_plot")
plt.show()
Decision Tree decision boundaries

The above figure shows this Decision Tree’s decision boundaries. The thick vertical line represents the decision boundary of the root node (depth 0): petal length = 2.45 cm. Since the left-hand area is pure (it contains only Iris setosa), it cannot be split any further.

However, the right-hand area is impure, so the depth-1 right node splits it at petal width = 1.75 cm (represented by the dashed line). Since max_depth was set to 2, the Decision Tree stops right there. If you set max_depth to 3, then the two depth-2 nodes would each add another decision boundary (represented by the dotted lines).
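To check the effect of max_depth directly, you can compare the actual depth and the number of leaves of trees trained with max_depth=2 and max_depth=3. This short sketch assumes Scikit-Learn 0.21 or later, where get_depth() and get_n_leaves() are available:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target

results = {}
for depth in (2, 3):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42).fit(X, y)
    # Record (actual depth, number of leaves) for each setting
    results[depth] = (clf.get_depth(), clf.get_n_leaves())
    print(f"max_depth={depth}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}")
```

Going from max_depth=2 to max_depth=3 adds two leaves: each impure depth-2 node is split once more, which is where the two extra decision boundaries come from.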

Estimating Class Probabilities

Decision Trees can also estimate the probability that an instance belongs to a particular class k. First it traverses the tree to find the leaf node for this instance, and then it returns the ratio of training instances of class k in this node.

For example, suppose you have found a flower whose petals are 5 cm long and 1.5 cm wide. The corresponding leaf node is the depth-2 left node, so the Decision Tree should output the following probabilities: 0% for Iris setosa (0/54), 90.7% for Iris versicolor (49/54), and 9.3% for Iris virginica (5/54).

And if you ask it to predict the class, it should output Iris versicolor (class 1) because it has the highest probability. Let’s check this:

tree_clf.predict_proba([[5, 1.5]])
array([[0. , 0.90740741, 0.09259259]])
tree_clf.predict([[5, 1.5]])
array([1])

Perfect! Notice that the estimated probabilities would be identical anywhere else in the bottom-right rectangle of the figure below, for example, if the petals were 6 cm long and 1.5 cm wide.

Decision Tree decision boundaries
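The leaf-ratio rule can be verified by hand. This sketch assumes the same max_depth=2 tree on petal length and width; it uses apply() to find the leaf a flower falls into, counts the training instances of each class in that leaf, and compares the ratios with predict_proba():

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)

flower = np.array([[5.0, 1.5]])
leaf_id = tree_clf.apply(flower)[0]      # index of the leaf this flower lands in
in_leaf = tree_clf.apply(X) == leaf_id   # training instances that share that leaf

counts = np.bincount(y[in_leaf], minlength=3)  # training instances of each class in the leaf
manual_proba = counts / counts.sum()           # ratio of class-k instances in the leaf

print("class counts in leaf:", counts)
print("manual ratios:       ", manual_proba)
print("predict_proba():     ", tree_clf.predict_proba(flower)[0])
# A 6 cm x 1.5 cm flower falls into the same leaf, hence gets the same estimate
print("same leaf for [6, 1.5]:", tree_clf.apply([[6.0, 1.5]])[0] == leaf_id)
```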


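Hopefully by now you are convinced that Decision Trees have a lot going for them: they are simple to understand and interpret, easy to use, versatile, and powerful. However, they do have a few limitations. First, as you may have noticed, Decision Trees love orthogonal decision boundaries (all splits are perpendicular to an axis), which makes them sensitive to training set rotation. More generally, Decision Trees are very sensitive to small variations in the training data. For example, let’s remove the widest Iris versicolor from the iris training set (the one with petals 4.8 cm long and 1.8 cm wide) and train a new Decision Tree: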

X[(X[:, 1]==X[:, 1][y==1].max()) & (y==1)] # widest Iris versicolor flower
array([[4.8, 1.8]])
# Drop that flower: keep every instance whose petal width is not 1.8 cm, plus all Iris virginica
not_widest_versicolor = (X[:, 1]!=1.8) | (y==2)
X_tweaked = X[not_widest_versicolor]
y_tweaked = y[not_widest_versicolor]

tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)
tree_clf_tweaked.fit(X_tweaked, y_tweaked)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini', max_depth=2, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort='deprecated', random_state=40, splitter='best')
plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf_tweaked, X_tweaked, y_tweaked, legend=False)
plt.plot([0, 7.5], [0.8, 0.8], "k-", linewidth=2)
plt.plot([0, 7.5], [1.75, 1.75], "k--", linewidth=2)
plt.text(1.0, 0.9, "Depth=0", fontsize=15)
plt.text(1.0, 1.80, "Depth=1", fontsize=13)

save_fig("decision_tree_instability_plot")
plt.show()
Sensitivity to training set details
from sklearn.datasets import make_moons
Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)

deep_tree_clf1 = DecisionTreeClassifier(random_state=42)
deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)
deep_tree_clf1.fit(Xm, ym)
deep_tree_clf2.fit(Xm, ym)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)
plt.title("No restrictions", fontsize=16)
plt.sca(axes[1])
plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)
plt.title("min_samples_leaf = {}".format(deep_tree_clf2.min_samples_leaf), fontsize=14)

save_fig("min_samples_leaf_plot")
plt.show()

The above figure shows two Decision Trees trained on the moons dataset. On the left, the Decision Tree is trained with the default hyperparameters (i.e., no restrictions), and on the right it’s trained with min_samples_leaf = 4, which requires every leaf to contain at least 4 training instances. It is quite obvious that the model on the left is overfitting, and the model on the right will probably generalize better.
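A sketch to quantify what min_samples_leaf does, assuming a separately generated moons test set (the size 1,000 and random_state=0 are arbitrary choices, not from the original). For both trees it reports the number of leaves, the size of the smallest leaf, and training and test accuracy:

```python
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)
# Assumed held-out set: size and seed are arbitrary choices for this sketch
Xm_test, ym_test = make_moons(n_samples=1000, noise=0.25, random_state=0)

deep_tree_clf1 = DecisionTreeClassifier(random_state=42).fit(Xm, ym)
deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42).fit(Xm, ym)

for name, clf in (("no restrictions", deep_tree_clf1),
                  ("min_samples_leaf=4", deep_tree_clf2)):
    t = clf.tree_
    is_leaf = t.children_left == -1                  # boolean mask over all nodes
    smallest_leaf = t.n_node_samples[is_leaf].min()  # training instances in the smallest leaf
    print(f"{name}: leaves={is_leaf.sum()}, smallest leaf={smallest_leaf}, "
          f"train acc={clf.score(Xm, ym):.2f}, test acc={clf.score(Xm_test, ym_test):.2f}")
```

The unrestricted tree keeps splitting until every leaf is pure, so it scores 100% on its own training set; the test accuracies are where any generalization benefit of min_samples_leaf=4 should show up.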

Xs = np.random.rand(100, 2) - 0.5
ys = (Xs[:, 0] > 0).astype(np.float32) * 2

# Rotate the dataset by 45 degrees
angle = np.pi / 4
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)

tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs, ys)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_sr.fit(Xsr, ys)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)
plt.sca(axes[1])
plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)

save_fig("sensitivity_to_rotation_plot")
plt.show()

The above figure shows a simple linearly separable dataset: on the left, a Decision Tree can split it easily, while on the right, after the dataset is rotated by 45°, the decision boundary looks unnecessarily convoluted. Although both Decision Trees fit the training set perfectly, it is very likely that the model on the right will not generalize well.
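One way to quantify that extra complexity is to compare the depth and number of leaves of both trees. This sketch regenerates the dataset with a seeded NumPy generator (np.random.default_rng(42) is an assumption added only for reproducibility):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(42)  # assumed seed, for reproducibility only
Xs = rng.random((100, 2)) - 0.5
ys = (Xs[:, 0] > 0).astype(np.float32) * 2  # classes separated by the line x1 = 0

# Rotate the same points by 45 degrees
angle = np.pi / 4
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle), np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)

tree_clf_s = DecisionTreeClassifier(random_state=42).fit(Xs, ys)
tree_clf_sr = DecisionTreeClassifier(random_state=42).fit(Xsr, ys)

for name, clf, data in (("original", tree_clf_s, Xs), ("rotated", tree_clf_sr, Xsr)):
    print(f"{name}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}, "
          f"train acc={clf.score(data, ys):.2f}")
```

The tree on the original data needs a single split (two leaves), because one threshold on the first feature separates the classes; on the rotated data the boundary becomes diagonal, so the tree has to approximate it with a staircase of axis-aligned splits.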

Regression with Decision Trees

Decision Trees are also capable of performing regression tasks. Let’s build a regression tree using Scikit-Learn’s DecisionTreeRegressor class, training it on a noisy quadratic dataset with max_depth = 2:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Quadratic training set + noise
m = 200
X = np.random.rand(m, 1)
y = 4 * (X - 0.5) ** 2
y = y + np.random.randn(m, 1) / 10

tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg.fit(X, y)
DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=2, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort='deprecated', random_state=42, splitter='best')
tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)

def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel="$y$"):
    # Plot the training data and the tree's piecewise-constant predictions
    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel("$x_1$", fontsize=18)
    if ylabel:
        plt.ylabel(ylabel, fontsize=18, rotation=0)
    plt.plot(X, y, "b.")
    plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_regression_predictions(tree_reg1, X, y)
# Hard-coded split thresholds of the trained trees (they change with the random training set)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
plt.text(0.21, 0.65, "Depth=0", fontsize=15)
plt.text(0.01, 0.2, "Depth=1", fontsize=13)
plt.text(0.65, 0.8, "Depth=1", fontsize=13)
plt.legend(loc="upper center", fontsize=18)
plt.title("max_depth=2", fontsize=14)
plt.sca(axes[1])
plot_regression_predictions(tree_reg2, X, y, ylabel=None)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
for split in (0.0458, 0.1298, 0.2873, 0.9040):
    plt.plot([split, split], [-0.2, 1], "k:", linewidth=1)
plt.text(0.3, 0.5, "Depth=2", fontsize=13)
plt.title("max_depth=3", fontsize=14)

save_fig("tree_regression_plot")
plt.show()
Predictions of two Decision Tree regression models

Now let’s have a look at the resulting regression tree (tree_reg, trained with max_depth=2), again using export_graphviz():

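from graphviz import Source
from sklearn.tree import export_graphviz

# With out_file=None, export_graphviz() returns the DOT source as a string
dot_data = export_graphviz(
    tree_reg,
    out_file=None,
    rounded=True,
    filled=True
)
Source(dot_data)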
A Decision Tree for Regression

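This tree looks very similar to the classification tree we built earlier. The main difference is that instead of predicting a class in each node, it predicts a value: the average target value of the training instances associated with that node. A new instance receives the value of the leaf it lands in.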
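A quick sketch to confirm that a regression leaf predicts the average target of its training instances, assuming the default squared-error criterion. It regenerates a noisy quadratic dataset with a seeded generator (an assumption for reproducibility) and uses x1 = 0.6 as an arbitrary query point:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Noisy quadratic training set (seeded here only so the sketch is reproducible)
rng = np.random.default_rng(42)
m = 200
X = rng.random((m, 1))
y = 4 * (X[:, 0] - 0.5) ** 2 + rng.standard_normal(m) / 10

tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42).fit(X, y)

x_new = np.array([[0.6]])
leaf_id = tree_reg.apply(x_new)[0]      # leaf reached by the query point
in_leaf = tree_reg.apply(X) == leaf_id  # training instances in that same leaf

print("instances in leaf:  ", in_leaf.sum())
print("mean target in leaf:", y[in_leaf].mean())
print("tree prediction:    ", tree_reg.predict(x_new)[0])
```

With a different criterion (for example absolute_error, which uses the median) the leaf value would no longer be the mean.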

Regularization of Decision Tree Regressor

Just like for classification tasks, Decision Trees are prone to overfitting when dealing with regression tasks. Let’s compare a regression tree trained without any restrictions with one trained with min_samples_leaf=10:

tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)

x1 = np.linspace(0, 1, 500).reshape(-1, 1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plt.plot(X, y, "b.")
plt.plot(x1, y_pred1, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", fontsize=18, rotation=0)
plt.legend(loc="upper center", fontsize=18)
plt.title("No restrictions", fontsize=14)
plt.sca(axes[1])
plt.plot(X, y, "b.")
plt.plot(x1, y_pred2, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.title("min_samples_leaf={}".format(tree_reg2.min_samples_leaf), fontsize=14)

save_fig("tree_regression_regularization_plot")
plt.show()
Regularization of Decision Tree Regressor


The predictions of the unrestricted tree (left) are obviously overfitting the training set badly, while setting min_samples_leaf=10 (right) results in a much more reasonable model. So I hope you liked this article on Decision Trees in Machine Learning. I covered both classification and regression tasks with Decision Trees, along with their disadvantages. Feel free to ask questions in the comments section below on any topic you want.

Aman Kharwal

I'm a writer and data scientist on a mission to educate others about the incredible power of data📈.
