Matplotlib Tutorial for Data Science

This article is all about Matplotlib, the basic data visualization library of the Python programming language for Data Science.

Here I will discuss the various plot types available in Matplotlib and the customization techniques commonly used in Data Science.

Introduction to Matplotlib

Matplotlib is the basic plotting library of the Python programming language. It is the most widely used of the Python visualization packages and can handle a wide range of plotting tasks.

It can produce publication-quality figures in a variety of formats. It can export visualizations to common formats such as PDF, SVG, PNG and JPG.

It can create popular visualization types – line plot, scatter plot, histogram, bar chart, error charts, pie chart, box plot, and many more types of plot.

Matplotlib also supports 3D plotting. Many Python libraries are built on top of Matplotlib. For example, Seaborn and the plotting functions of pandas are built on Matplotlib. They give access to Matplotlib’s methods with less code.

Let’s start with Matplotlib by importing the dependencies and Matplotlib itself:

# Import dependencies

import numpy as np
import pandas as pd
# Import Matplotlib

import matplotlib.pyplot as plt 

Displaying Plots in Matplotlib

x1 = np.linspace(0, 10, 100)
# create a plot figure
fig = plt.figure()

plt.plot(x1, np.sin(x1), '-')
plt.plot(x1, np.cos(x1), '--')
plt.show() # to show the plot

Matplotlib Object Hierarchy

In Matplotlib, a plot is a hierarchy of nested Python objects. A hierarchy means that there is a tree-like structure of Matplotlib objects underlying each plot.

The Figure object is the outermost container for a Matplotlib plot. A Figure can contain multiple Axes objects. So, the Figure is the final graphic that may contain one or more Axes. Each Axes represents an individual plot.

So, we can think of the Figure object as a box-like container holding one or more Axes. Each Axes object contains smaller objects such as tick marks, lines, legends, title and text boxes.
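
The hierarchy described above can be inspected directly. The following is only a minimal sketch: the variable names, the title and the plotted values are my own illustrative choices.

```python
import matplotlib.pyplot as plt

# One Figure containing a single Axes
fig, ax = plt.subplots()

# Plotting adds a Line2D object to the Axes
(line,) = ax.plot([0, 1, 2], [0, 1, 4], label="squares")
ax.set_title("Hierarchy demo")

print(fig.axes)              # the Axes objects held by the Figure
print(ax.get_lines())        # the lines held by the Axes
print(ax.xaxis, ax.yaxis)    # the two Axis objects that own the ticks
print(ax.title.get_text())   # the title is itself a Text object
```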

Matplotlib API Overview

Matplotlib has two APIs to work with: a MATLAB-style state-based interface and a more powerful object-oriented (OO) interface. The former is called the pyplot interface and the latter is called the Object-Oriented interface.

There is also a third interface, called the pylab interface. It merges pyplot (for plotting) and NumPy (for mathematical functions) together in an environment closer to MATLAB.

This is considered bad practice nowadays. So, the use of pylab is strongly discouraged and I will not discuss it any further.

The following code uses the pyplot interface to draw two panels in one figure:

# create a plot figure
plt.figure()

# create the first of two panels and set current axis
plt.subplot(2, 1, 1)   # (rows, columns, panel number)
plt.plot(x1, np.sin(x1))


# create the second of two panels and set current axis
plt.subplot(2, 1, 2)   # (rows, columns, panel number)
plt.plot(x1, np.cos(x1))
plt.show()

Visualization with Pyplot

Generating visualizations with Pyplot is very easy. If we provide a single list or array to the plot() command, Matplotlib assumes it is a sequence of y values and automatically generates the x values. In the example below, the x-axis values range from 0 to 3 and the y-axis values from 1 to 4.

Since Python ranges start with 0, the default x vector has the same length as y but starts with 0. Hence the x data are [0, 1, 2, 3] and the y data are [1, 2, 3, 4].

plt.plot([1, 2, 3, 4])
plt.ylabel('Numbers')
plt.show()
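
To check the automatically generated x values, we can read the data back from the line object that plot() returns. This is a small sketch using the Object-Oriented calls, and the variable names are my own.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
(line,) = ax.plot([1, 2, 3, 4])   # only y values are given

# Read the data back from the Line2D object
x_values = [float(v) for v in line.get_xdata()]
y_values = [float(v) for v in line.get_ydata()]

print(x_values)   # generated by Matplotlib, starting at 0
print(y_values)   # the list we passed in
```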

State-machine interface

Pyplot provides the state-machine interface to the underlying object-oriented plotting library. The state-machine implicitly and automatically creates figures and axes to achieve the desired plot. For example:

x = np.linspace(0, 2, 100)

plt.plot(x, x, label='linear')
plt.plot(x, x**2, label='quadratic')
plt.plot(x, x**3, label='cubic')

plt.xlabel('x label')
plt.ylabel('y label')

plt.title("Simple Plot")

plt.legend()

plt.show()
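
The implicitly created objects are still ordinary Figure and Axes objects, and pyplot lets us retrieve them with gcf() and gca(). The sketch below assumes a fresh session, which is why it closes any open figures first.

```python
import matplotlib.pyplot as plt

plt.close("all")                 # start from a clean state

plt.plot([0, 1, 2], [0, 1, 4])   # implicitly creates a Figure and an Axes
plt.title("Implicit axes")       # acts on the current Axes

fig = plt.gcf()   # "get current figure"
ax = plt.gca()    # "get current axes"

print(ax in fig.axes)   # the current Axes belongs to the current Figure
print(ax.get_title())   # the title set through pyplot
```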

Formatting the style of plot

For every x, y pair of arguments, there is an optional third argument: the format string, which indicates the color and line type of the plot. The letters and symbols of the format string come from MATLAB.

We can concatenate a color string with a line style string. The default format string is ‘b-’, which is a solid blue line. For example, to plot the points as red circles, we would issue the following command:

plt.plot([1, 2, 3, 4], [1, 4, 9, 16], 'ro')
plt.axis([0, 6, 0, 20])
plt.show()

The axis() command in the example above takes a list of [xmin, xmax, ymin, ymax] and specifies the viewport of the axes.
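
A format string is a shorthand for keyword arguments. The sketch below draws the same style both ways; the second data set is shifted by one only so that the two series do not overlap, which is my own choice.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Format string: 'r' = red, 'o' = circle marker, no connecting line
(short,) = ax.plot([1, 2, 3, 4], [1, 4, 9, 16], "ro")

# The same style written with keyword arguments
(verbose,) = ax.plot([1, 2, 3, 4], [2, 5, 10, 17],
                     color="r", marker="o", linestyle="none")

# Same viewport as plt.axis([0, 6, 0, 20])
ax.axis([0, 6, 0, 20])
```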

Working with NumPy arrays

Generally, we work with NumPy arrays. In fact, all sequences are converted to NumPy arrays internally.

The example below illustrates plotting several lines with different format styles in one command using arrays.

# evenly sampled time at 200ms intervals
t = np.arange(0., 5., 0.2)

# red dashes, blue squares and green triangles
plt.plot(t, t, 'r--', t, t**2, 'bs', t, t**3, 'g^')
plt.show()
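
When several x, y, format groups are passed in one call, plot() returns one line object per group. The following sketch uses that to attach labels afterwards; the label names are my own.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0., 5., 0.2)

fig, ax = plt.subplots()
# plot() returns one Line2D object per (x, y, format) group
lines = ax.plot(t, t, "r--", t, t**2, "bs", t, t**3, "g^")

# Labels can be attached afterwards so that a legend can be drawn
for line, name in zip(lines, ["linear", "quadratic", "cubic"]):
    line.set_label(name)
ax.legend()
```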

Object-Oriented API

The Object-Oriented API is available for more complex plotting situations. It allows us to exercise more control over the figure.

In the Pyplot API, we depend on some notion of an “active” figure or axes. But in the Object-Oriented API, the plotting functions are methods of explicit Figure and Axes objects.

Figure is the top-level container for all the plot elements. We can think of the Figure object as a box-like container holding one or more Axes.

Each Axes represents an individual plot. The Axes object contains smaller objects such as axis, tick marks, lines, legends, title and text boxes.

The following code produces sine and cosine curves using the Object-Oriented API:

# First create a grid of plots
# ax will be an array of two Axes objects
fig, ax = plt.subplots(2)


# Call plot() method on the appropriate object
ax[0].plot(x1, np.sin(x1), 'b-')
ax[1].plot(x1, np.cos(x1), 'b-')
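
Because every panel is an explicit object, each one can be customized independently. The sketch below is one possible variation of the code above; the shared x-axis, the titles and the tight_layout() call are my own additions.

```python
import numpy as np
import matplotlib.pyplot as plt

x1 = np.linspace(0, 10, 100)

# Two rows, one column; both panels share the same x-axis
fig, axs = plt.subplots(2, 1, sharex=True)

axs[0].plot(x1, np.sin(x1), "b-")
axs[0].set_title("sin(x)")

axs[1].plot(x1, np.cos(x1), "b--")
axs[1].set_title("cos(x)")
axs[1].set_xlabel("x")

fig.tight_layout()   # avoid overlapping titles and labels
```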

Objects and Reference

The main idea of the Object-Oriented API is to have objects on which one can apply functions and actions. The real advantage of this approach becomes apparent when more than one figure is created or when a figure contains more than one subplot.

We create a reference to the figure instance in the fig variable. Then, we create a new Axes instance, axes, using the add_axes method of the Figure instance fig. The list passed to add_axes gives the left, bottom, width and height of the new Axes as fractions of the figure size. The code is as follows:

fig = plt.figure()

x2 = np.linspace(0, 5, 10)
y2 = x2 ** 2

axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])

axes.plot(x2, y2, 'r')

axes.set_xlabel('x2')
axes.set_ylabel('y2')
axes.set_title('title')
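
Since add_axes places an Axes anywhere inside the figure, it can be called twice to draw an inset. This is a sketch only; the inset position and the data shown in it are arbitrary choices of mine.

```python
import numpy as np
import matplotlib.pyplot as plt

x2 = np.linspace(0, 5, 10)
y2 = x2 ** 2

fig = plt.figure()

# [left, bottom, width, height] as fractions of the figure size
main_axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
inset_axes = fig.add_axes([0.2, 0.5, 0.4, 0.3])

main_axes.plot(x2, y2, "r")
main_axes.set_title("main")

# The inset shows the same data with the axes swapped
inset_axes.plot(y2, x2, "g")
inset_axes.set_title("inset")
```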

Parts of a Plot

A plot has several parts, such as the title, legend, grid, axes and labels. They are marked in the following figure:

[Figure: Parts of a plot]
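
Each of these parts can be set with its own method. The following is a minimal sketch with data and label texts that I have chosen for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label="sin(x)")   # lines with legend entries
ax.plot(x, np.cos(x), label="cos(x)")

ax.set_title("Parts of a plot")   # title
ax.set_xlabel("x")                # x-axis label
ax.set_ylabel("y")                # y-axis label
ax.legend(loc="upper right")      # legend
ax.grid(True)                     # grid lines
```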

First plot with Matplotlib

Now, I will start producing plots. Here is the first example:

plt.plot([1, 3, 2, 4], 'b-')
plt.show()

plt.plot([1, 3, 2, 4], 'b-')

This line is the actual plotting command. Only one list of values has been passed; it represents the vertical coordinates of the points to be plotted.

Matplotlib will use an implicit list of horizontal values, from 0 (the first value) to N-1 (where N is the number of items in the list).

Specify both Lists

Also, we can explicitly specify both the lists as follows:

x3 = np.arange(0.0, 6.0, 0.01) 
plt.plot(x3, [xi**2 for xi in x3], 'b-') 
plt.show()

Multiline Plots

A multiline plot draws more than one line on the same figure. It can be achieved by plotting all the lines before calling show(), as follows:

x4 = range(1, 5)
plt.plot(x4, [xi*1.5 for xi in x4])
plt.plot(x4, [xi*3 for xi in x4])
plt.plot(x4, [xi/3.0 for xi in x4])
plt.show()

Saving the Plot

We can save figures in a wide variety of formats using the savefig() method of the Figure object as follows:

# Saving the figure
fig.savefig('plot1.png')
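
The output format is chosen from the file extension. The sketch below saves one figure as PNG and PDF and lists the formats that the active backend supports; the temporary directory and the dpi value are assumptions made for this example.

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

# Write into a temporary directory (an assumption for this sketch)
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "plot1.png")
pdf_path = os.path.join(out_dir, "plot1.pdf")

# The file extension selects the output format
fig.savefig(png_path, dpi=150, bbox_inches="tight")
fig.savefig(pdf_path)

# Formats supported by the active backend
supported = fig.canvas.get_supported_filetypes()
print(sorted(supported))
```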

Scatter Plot

Another commonly used plot type is the scatter plot. Here the points are represented individually with a dot or a circle.

Scatter Plot with plt.plot()

We have used plt.plot/ax.plot to produce line plots. We can use the same functions to produce scatter plots as follows:

x7 = np.linspace(0, 10, 30)
y7 = np.sin(x7)
plt.plot(x7, y7, 'o', color = 'black')
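
When every point needs its own color or size, the scatter() function is the usual alternative to plot(). This is a sketch with randomly chosen marker sizes; the seed, the colormap and the size range are arbitrary choices of mine.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)   # fixed seed, arbitrary choice
x = np.linspace(0, 10, 30)
y = np.sin(x)
sizes = rng.uniform(20, 200, size=x.size)   # marker areas in points squared

fig, ax = plt.subplots()
# c maps each y value to a color, s sets each marker size
points = ax.scatter(x, y, c=y, s=sizes, cmap="viridis")
fig.colorbar(points, ax=ax, label="sin(x)")
```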

Histogram

Histograms are a graphical display of frequencies, represented as bars. They show what portion of the dataset falls into each category, usually specified as non-overlapping intervals. These categories are called bins.

The plt.hist() function can be used to plot a simple histogram as follows:

data1 = np.random.randn(1000)
plt.hist(data1)
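
The number of bins can be set with the bins parameter, and hist() returns the counts and bin edges it computed. The sketch below assumes 20 bins and a fixed random seed, both chosen only for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
data1 = rng.standard_normal(1000)

fig, ax = plt.subplots()
# hist() returns the bin counts, the bin edges and the drawn bars
counts, edges, bars = ax.hist(data1, bins=20, edgecolor="black")

print(counts.sum())   # every sample falls into exactly one bin
print(len(edges))     # one more edge than there are bins
```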

Bar Chart

Bar charts display rectangular bars in either vertical or horizontal form. Their length is proportional to the values they represent. They are used to compare two or more values.

We can plot a bar chart using the plt.bar() function as follows:

data2 = [5. , 25. , 50. , 20.]
plt.bar(range(len(data2)), data2)
plt.show() 

Horizontal Bar Chart

We can produce a horizontal bar chart using the plt.barh() function. It is the horizontal equivalent of the plt.bar() function.

data2 = [5. , 25. , 50. , 20.]
plt.barh(range(len(data2)), data2)
plt.show() 
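
Both functions also accept category names in place of numeric positions. The following sketch draws the two variants side by side; the category names are hypothetical.

```python
import matplotlib.pyplot as plt

categories = ["A", "B", "C", "D"]   # hypothetical category names
data2 = [5., 25., 50., 20.]

fig, (ax_v, ax_h) = plt.subplots(1, 2, figsize=(8, 3))

# Vertical bars: the values set the bar heights
v_bars = ax_v.bar(categories, data2)

# Horizontal bars: the values set the bar widths
h_bars = ax_h.barh(categories, data2)
```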

Error Bar Chart

In experimental work, measurements lack perfect precision, so we have to repeat them. This results in a set of values.

The distribution of these values is usually represented by plotting a single data point (the mean value of the dataset) and an error bar that indicates the overall spread of the data.

We can use Matplotlib’s errorbar() function to represent the distribution of data values. It can be done as follows:

x9 = np.arange(0, 4, 0.2)
y9 = np.exp(-x9)
e1 = 0.1 * np.abs(np.random.randn(len(y9)))
plt.errorbar(x9, y9, yerr = e1, fmt = '.-')
plt.show()
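
The example above uses random error sizes. A sketch closer to the experimental situation described earlier computes the mean and the standard deviation from repeated measurements; the simulated data, the noise level and the number of repeats are all assumptions of mine.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
x = np.arange(0, 4, 0.5)

# Simulated data: 10 noisy repeats of exp(-x) at every x
repeats = np.exp(-x) + 0.05 * rng.standard_normal((10, x.size))

means = repeats.mean(axis=0)          # one mean per x value
stds = repeats.std(axis=0, ddof=1)    # sample standard deviation

fig, ax = plt.subplots()
# Each point is a mean; each error bar spans one standard deviation
ax.errorbar(x, means, yerr=stds, fmt="o-", capsize=3)
```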

Stacked Bar Chart

We can draw a stacked bar chart by using the bottom parameter of the plt.bar() function. It can be done as follows:

A = [15., 30., 45., 22.]
B = [15., 25., 50., 20.]
z2 = range(4)
plt.bar(z2, A, color = 'b')
plt.bar(z2, B, color = 'r', bottom = A)
plt.show()
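
With more than two series, each new layer has to start at the sum of all the layers below it. The sketch below adds a hypothetical third series C to show this.

```python
import numpy as np
import matplotlib.pyplot as plt

A = np.array([15., 30., 45., 22.])
B = np.array([15., 25., 50., 20.])
C = np.array([10., 5., 20., 8.])   # hypothetical third series
z2 = np.arange(4)

fig, ax = plt.subplots()
ax.bar(z2, A, color="b", label="A")
ax.bar(z2, B, color="r", bottom=A, label="B")
# The third series starts where A and B together end
top = ax.bar(z2, C, color="g", bottom=A + B, label="C")
ax.legend()
```

NumPy arrays are used here instead of lists so that A + B adds the values element by element.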

Pie Chart

Pie charts are circular representations, divided into sectors. The sectors are also called wedges. The arc length of each sector is proportional to the quantity we are describing.

It is an effective way to represent information when we are interested mainly in comparing the wedge against the whole pie, instead of wedges against each other.

Matplotlib provides the pie() function to plot pie charts from an array X. Wedges are created proportionally, so that each value x of array X generates a wedge proportional to x/sum(X).

plt.figure(figsize=(7,7))
x10 = [35, 25, 20, 20]
labels = ['Computer', 'Electronics', 'Mechanical', 'Chemical']
plt.pie(x10, labels=labels)
plt.show()
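
The proportions x/sum(X) can be computed by hand and printed on the wedges with the autopct parameter. This is a minimal sketch; the percentage format string is my own choice.

```python
import numpy as np
import matplotlib.pyplot as plt

x10 = np.array([35, 25, 20, 20])
labels = ["Computer", "Electronics", "Mechanical", "Chemical"]

fractions = x10 / x10.sum()   # x / sum(X) for every wedge
angles = 360 * fractions      # angular size of each wedge in degrees

fig, ax = plt.subplots(figsize=(7, 7))
# autopct prints each wedge's share of the whole as a percentage
ax.pie(x10, labels=labels, autopct="%1.1f%%")
```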

Boxplot

A boxplot allows us to compare distributions of values by showing the median, quartiles, maximum and minimum of a set of values.

We can plot a boxplot with the boxplot() function as follows:

data3 = np.random.randn(100)
plt.boxplot(data3)
plt.show()

The boxplot() function takes a set of values and computes the median, quartiles and other statistical quantities. The following points describe the preceding boxplot. The colors and markers named here are the defaults of Matplotlib 2.0 and later; older versions drew a red median line, a blue box and cross markers.

• The orange bar is the median of the distribution.

• The box includes 50 percent of the data, from the lower quartile to the upper quartile. The median lies inside the box, but not necessarily at its center.

• The lower whisker extends to the lowest value within 1.5 times the IQR (interquartile range) from the lower quartile.

• The upper whisker extends to the highest value within 1.5 times the IQR from the upper quartile.

• Values beyond the whiskers are shown with a circle marker.
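
The quantities in the points above can be reproduced with NumPy. The sketch below assumes the default whisker length of 1.5 times the IQR and a fixed random seed; the variable names are my own.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(7)
data3 = rng.standard_normal(100)

# Quartiles and interquartile range
q1, median, q3 = np.percentile(data3, [25, 50, 75])
iqr = q3 - q1

# Whiskers end at the most extreme data points inside the 1.5 * IQR fences
whisker_low = data3[data3 >= q1 - 1.5 * iqr].min()
whisker_high = data3[data3 <= q3 + 1.5 * iqr].max()

# Points beyond the whiskers are drawn as individual outliers
outliers = data3[(data3 < whisker_low) | (data3 > whisker_high)]

fig, ax = plt.subplots()
parts = ax.boxplot(data3)   # dict of the artists that make up the boxplot
```

The returned dictionary holds the drawn artists under keys such as 'medians', 'whiskers' and 'fliers', so the computed values can be compared with what Matplotlib actually drew.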

Aman Kharwal

Data Strategist at Statso. My aim is to decode data science for the real world in the most simple words.
