In machine learning, a decision tree is a decision support tool that uses a tree-like model of decisions and their possible consequences, including the outcomes of chance events, resource costs, and utility. It is one way of displaying an algorithm that consists only of conditional control statements. In this article, I will take you through how to visualize a decision tree using Python.
Visualizing a decision tree is very different from visualizing the data on which a decision tree algorithm was used. So, if you are not very familiar with the decision tree algorithm, I recommend that you first go through the decision tree algorithm from here.
How to Visualize a Decision Tree?
If you are a machine learning practitioner, or you have applied the decision tree algorithm to a lot of classification tasks before, you may be wondering why I am stressing the visualization of a decision tree. Just look at the picture below.
On the right is a visualization of the output we get when we use a decision tree algorithm on data to make predictions. On the left is the structure that a decision tree algorithm builds to make those predictions.
So, I hope you now see the difference between visualizing the results of a decision tree algorithm on data and visualizing the structure of the decision tree itself. Now let’s see how we can visualize a decision tree.
Visualize a Decision Tree
To explain the process of visualizing a decision tree, I will use the iris dataset, which contains the petal and sepal measurements of 3 types of iris species (Setosa, Versicolour, and Virginica), stored in a NumPy array of dimensions 150×4. Now, let’s import the necessary libraries to get started with the task of visualizing a decision tree:
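The article does not show the import cell itself, so here is a minimal sketch of the libraries the later steps rely on (the exact set is an assumption based on the code that follows):

```python
# Minimal imports for this walkthrough (an assumption; the original
# import cell is not shown in the article).
from sklearn import tree                 # DecisionTreeClassifier, export_graphviz
from sklearn.datasets import load_iris   # the iris dataset
import pandas as pd                      # for a quick look at the data
```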
Now, let’s load the iris dataset and have a quick look at the first 5 rows of the data by using the pandas DataFrame.head() method:
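A sketch of this step, wrapping the NumPy array in a DataFrame so the columns are labeled (the variable names here are illustrative, not from the article):

```python
from sklearn.datasets import load_iris
import pandas as pd

# Load iris and wrap it in a DataFrame so head() shows labeled columns
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["species"] = iris.target  # 0, 1, 2 -> setosa, versicolor, virginica
print(df.head())  # first 5 rows
```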
Train a Decision Tree
For visualizing a decision tree, the first step is to train it on the data, because the visualization of a decision tree is nothing more than the structure it uses to make predictions. So, to visualize that structure, we first need to train the tree on the data:
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
Now, we can visualize the structure of the decision tree. For this, we need to use a package known as graphviz, which can be easily installed with the pip command – pip install graphviz. Once you have installed this package successfully, let’s move on to the task of visualizing the decision tree:
!pip install graphviz
import graphviz

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph
In the output, we can see the structure of the decision tree that is used to make predictions on the data. But these are numerical values, which mean a lot in machine learning; to make this task more interesting, let’s visualize a graphical representation of each step involved in the structure of the decision tree.
Graphical Visualization of Each Step
For this task, we need to install another package known as dtreeviz, which can be easily installed with the pip command – pip install dtreeviz. Once you have installed this package successfully, let’s see how we can visualize a graphical representation of each step involved in making predictions:
In the output above, we can see the distribution of each class at each node, where the decision boundary lies for each split, and the sample size at each leaf, represented by the size of the circle.
I hope you liked this article on how we can visualize the structure of a decision tree. Feel free to ask your valuable questions in the comments section below. You can also follow me on Medium to learn every topic of Machine Learning.