Visualize a Decision Tree in Machine Learning

In machine learning, a decision tree is a decision support tool that uses a tree-like model of decisions and their possible consequences, including the outcomes of random events, resource costs, and utility. It is a way of displaying an algorithm that contains only conditional control statements. In this article, I will take you through how to visualize a decision tree using Python.
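To make the idea of "only conditional control statements" concrete, here is a minimal hand-written sketch. The function name and the thresholds are my own illustrative assumptions, not values learned from any data:

```python
# A hand-written "decision tree": nothing but nested conditionals.
# The function name and thresholds are hypothetical, chosen for illustration.
def classify_iris(petal_length_cm: float, petal_width_cm: float) -> str:
    if petal_length_cm < 2.5:      # root decision
        return "setosa"
    if petal_width_cm < 1.75:      # second decision
        return "versicolor"
    return "virginica"             # remaining leaf


print(classify_iris(1.4, 0.2))
print(classify_iris(4.5, 1.4))
print(classify_iris(5.8, 2.2))
```

A trained decision tree has the same shape; the difference is that the algorithm picks the features and thresholds from the data instead of a person writing them by hand.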

Visualizing a decision tree is quite different from visualizing data on which a decision tree algorithm has been used. So, if you are not familiar with the decision tree algorithm, I recommend going through it first.

How to Visualize a Decision Tree?

If you are a machine learning practitioner, or you have already applied the decision tree algorithm to many classification tasks, you may be wondering why I am stressing the need to visualize a decision tree. Just look at the picture below.

[Image: the structure of a decision tree (left) and the output of a decision tree algorithm on data (right)]

On the right side, we have a visualization of the output we get when we use a decision tree algorithm on data to make predictions. On the left side, we have the structure that a decision tree algorithm follows to make those predictions.

So, I hope you now see the difference between visualizing the output of the decision tree algorithm on the data and visualizing the structure of the decision tree itself. Now let’s see how we can visualize a decision tree.

Visualize a Decision Tree

To explain the process of visualizing a decision tree, I will use the iris dataset, which contains the sepal length, sepal width, petal length, and petal width of 150 flowers from 3 iris species (Setosa, Versicolour, and Virginica), stored in a 150×4 NumPy array. Now, let’s import the necessary libraries to get started with the task of visualizing a decision tree:

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree

Now, let’s load the iris dataset and have a quick look at the first 5 rows of the data by using the pandas DataFrame.head() method:

iris = load_iris()
df_iris = pd.DataFrame(iris['data'], 
                       columns=iris['feature_names'])
df_iris['target'] = iris['target']
df_iris.head()
[Image: first five rows of the iris dataset]
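The size and class balance of the dataset can also be checked directly. This is only a quick sanity-check sketch; the class_counts variable and the mapping from target numbers to species names are my own additions:

```python
import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()
df_iris = pd.DataFrame(iris["data"], columns=iris["feature_names"])
df_iris["target"] = iris["target"]

# 150 rows; 4 feature columns plus the target column added above
print(df_iris.shape)

# Map the numeric target to species names and count the samples per species
species = df_iris["target"].map(dict(enumerate(iris["target_names"])))
class_counts = species.value_counts()
print(class_counts)
```

Each species should appear the same number of times, which is one reason the iris dataset is a convenient first example for classification.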

Train a Decision Tree

To visualize a decision tree, the first step is to train it on the data, because the visualization shows nothing but the structure that the trained tree uses to make predictions:

clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
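Before drawing anything, it can help to check how large the trained tree is. The following is a small sketch rather than part of the original walkthrough; random_state=0 is an assumption I added so that repeated runs build the same tree:

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()

# random_state=0 is an assumption, added only to make the result repeatable
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(iris.data, iris.target)

# Size of the structure that the visualization will have to display
print("depth:", clf.get_depth())
print("leaves:", clf.get_n_leaves())
print("nodes:", clf.tree_.node_count)

# Accuracy on the same data the tree was trained on (not a held-out test set)
train_accuracy = clf.score(iris.data, iris.target)
print("training accuracy:", train_accuracy)
```

Because every split is binary, the node count is always twice the number of leaves minus one, which is a handy check on what the plot should contain.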

Now we can visualize the structure of the decision tree. For this, we need a package known as graphviz, which can be installed with the pip command – pip install graphviz. Note that this Python package is only an interface; rendering the graph also requires the Graphviz software itself to be installed on your system. Once the package is installed, let’s move on to visualizing the decision tree:

!pip install graphviz
import graphviz

# Export the trained tree as a Graphviz DOT string (out_file=None returns it)
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)

# Wrap the DOT source; in a notebook, evaluating `graph` renders the tree
graph = graphviz.Source(dot_data)
graph
[Image: structure of the trained decision tree]

In the output, we can see the structure of the decision tree that is used to make predictions on the data. Each node shows only numerical values, such as the split threshold, the Gini impurity, and the number of samples. These mean a lot in machine learning, but to make this task more interesting, let’s visualize a graphical representation of each step involved in the structure of the decision tree.
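If Graphviz is not available on your machine, the splits of a tree can also be printed as plain text. The sketch below is not part of the original walkthrough; it uses scikit-learn's export_text function, and the max_depth=2 and random_state=0 settings are my own assumptions to keep the printout short and repeatable:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# A deliberately shallow tree (assumption) so the text output stays readable
small_clf = DecisionTreeClassifier(max_depth=2, random_state=0)
small_clf.fit(iris.data, iris.target)

# One line per split or leaf, indented by depth
rules = export_text(small_clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line is one condition on a feature, and each "class:" line is a leaf, so the text reads like the nested conditions the tree applies when it predicts.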

Graphical Visualization of Each Step

For this task, we need another package known as dtreeviz, which can be installed with the pip command – pip install dtreeviz. Once the package is installed, let’s see how we can visualize a graphical representation of each step involved in making predictions:

!pip install dtreeviz
# This import is the API of dtreeviz releases before 2.0; newer releases
# expose dtreeviz.model(...) and its view() method instead.
from dtreeviz.trees import dtreeviz

viz = dtreeviz(clf,
               iris['data'],
               iris['target'],
               target_name='',
               feature_names=np.array(iris['feature_names']),
               class_names={0:'setosa',1:'versicolor',2:'virginica'})

viz
[Image: dtreeviz visualization of the decision tree]

In the output above, we can see the distribution of each class at each node and where the decision boundary lies for each split. The sample size at each leaf is shown by the size of its circle.
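The numbers behind such a picture can also be read from the trained model itself. The following sketch assumes scikit-learn's tree_ attribute and its arrays (feature, threshold, n_node_samples, children_left, children_right, value); random_state=0 and the variable names are my own choices:

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_  # low-level arrays describing every node

leaf_samples = 0
for node_id in range(t.node_count):
    n = int(t.n_node_samples[node_id])
    if t.children_left[node_id] == t.children_right[node_id]:
        # A leaf has no children (both child ids are -1)
        majority = iris.target_names[t.value[node_id][0].argmax()]
        leaf_samples += n
        print(f"node {node_id}: leaf, {n} samples, majority class {majority}")
    else:
        # An internal node stores the feature index and threshold of its split
        name = iris.feature_names[t.feature[node_id]]
        print(f"node {node_id}: {name} <= {t.threshold[node_id]:.2f}, {n} samples")

print("samples across all leaves:", leaf_samples)
```

The split lines correspond to the decision boundaries in the plot, and the per-leaf sample counts correspond to the sizes of the leaf circles.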

I hope you liked this article on how we can visualize the structure of a decision tree. Feel free to ask your valuable questions in the comments section below. You can also follow me on Medium to learn every topic of Machine Learning.

Aman Kharwal