Correlation using Python

A dataset can tell many stories, and a good way to start exploring one is to look for correlations between its features. In this article, I will take you through how to find correlations using Python.

What is Correlation?

Correlation is a way to measure whether, and how strongly, two features in a dataset are related. Correlations have many applications in the real world. For example, we can check whether the usage of certain search terms correlates with YouTube views, or whether ads correlate with sales.
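
To make the idea concrete, here is a minimal sketch of the Pearson correlation coefficient, which is the measure pandas uses by default. The helper name pearson_r and the ad_spend and sales numbers are hypothetical and invented only for illustration; the sketch simply checks that a hand-written calculation agrees with pandas:

```python
import numpy as np
import pandas as pd


def pearson_r(x, y):
    """Pearson correlation coefficient of two equal-length numeric sequences."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Center both variables around their means
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    # Sum of co-deviations divided by the product of the deviation magnitudes
    return (x_centered * y_centered).sum() / np.sqrt(
        (x_centered ** 2).sum() * (y_centered ** 2).sum()
    )


# Hypothetical toy data, not taken from any real dataset
ad_spend = [10, 20, 30, 40, 50]
sales = [12, 25, 31, 48, 55]

r_manual = pearson_r(ad_spend, sales)
r_pandas = pd.Series(ad_spend).corr(pd.Series(sales))  # Pearson by default
print(r_manual, r_pandas)
```

The coefficient always lies between -1 and 1: values near 1 or -1 indicate a strong linear relationship, and values near 0 indicate little or no linear relationship.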


When building machine learning models, correlations are an important factor in feature selection. They show us which features are linearly related, and when two features are highly correlated, we can remove one of them to avoid duplicating information.
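
One possible way to act on this is sketched below. The function drop_highly_correlated, the 0.9 threshold, and the toy DataFrame are my own assumptions for illustration, not a standard pandas feature. The sketch looks at each pair of numeric columns once and drops the second column of any pair whose absolute correlation exceeds the threshold:

```python
import numpy as np
import pandas as pd


def drop_highly_correlated(df, threshold=0.9):
    """Drop one column from every pair whose absolute correlation exceeds threshold."""
    # numeric_only requires pandas 1.5 or newer
    corr = df.corr(numeric_only=True).abs()
    # Keep only the upper triangle (k=1 excludes the diagonal) so each pair appears once
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    to_drop = [col for col in upper.columns if (upper[col] > threshold).any()]
    return df.drop(columns=to_drop), to_drop


# Hypothetical toy data: "b" is an exact multiple of "a"
toy = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10],
    "c": [5, 3, 4, 1, 2],
})

reduced, dropped = drop_highly_correlated(toy)
print(dropped)
print(reduced.columns.tolist())
```

The right threshold depends on the model and the data, so treat 0.9 as a starting point rather than a rule.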

Finding Correlation using Python

In this section, I will take you through how to find the correlation between features using Python. Here, I will use three Python libraries: Pandas, Matplotlib, and Seaborn. Now let’s start the task by importing the dataset and doing some data processing:

import pandas as pd
movies = pd.read_csv("MoviesOnStreamingPlatforms_updated.csv")
movies.head()

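The dataset is based on movies from streaming platforms. There is a “Type” column in the data which does not seem to have been entered carefully, so I will drop this column. I will also convert the “Rotten Tomatoes” column from percentage strings to floats so that it is treated as a numeric feature: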

movies['Rotten Tomatoes'] = movies["Rotten Tomatoes"].str.replace("%", "").astype(float)
movies.drop("Type", inplace=True, axis=1)
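
If a percentage column also contains missing or malformed entries, a slightly more defensive conversion may help. The following is only a sketch on a hypothetical toy Series (the values are made up): it strips the trailing percent sign and lets pd.to_numeric turn anything unparseable into NaN instead of raising an error:

```python
import numpy as np
import pandas as pd

# Hypothetical toy ratings, including a missing value and a malformed entry
ratings = pd.Series(["98%", "87%", np.nan, "N/A", "45%"])

# Strip the trailing "%" and coerce anything that is not a number to NaN
cleaned = pd.to_numeric(ratings.str.rstrip("%"), errors="coerce")
print(cleaned)
```

The corr() method ignores NaN values pair by pair, so coerced entries do not break the later correlation step.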

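Using the DataFrame.corr() method from pandas, we can examine the correlations for all the numeric columns in the DataFrame. Since this is a method, all we have to do is call it on the DataFrame. The return value will be a new DataFrame showing the pairwise correlations between the features:

# numeric_only=True (available since pandas 1.5) skips text columns;
# without it, pandas 2.0 and newer raise an error on non-numeric data
correlations = movies.corr(numeric_only=True)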
correlations
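
By default, corr() computes the Pearson coefficient, which only captures linear relationships. As a rough sketch on hypothetical toy data (the column names and values are invented), the method parameter can switch to Spearman rank correlation, which also picks up monotonic but non-linear relationships:

```python
import pandas as pd

# Hypothetical toy data with one text column and a non-linear relationship
toy = pd.DataFrame({
    "title": ["A", "B", "C", "D", "E"],
    "x": [1, 2, 3, 4, 5],
    "y": [1, 8, 27, 64, 125],  # y = x ** 3
})

# numeric_only=True leaves the text column out of the result
pearson = toy.corr(numeric_only=True)
spearman = toy.corr(method="spearman", numeric_only=True)

print(pearson.loc["x", "y"])   # below 1 because the relationship is not linear
print(spearman.loc["x", "y"])  # rank-based, so a perfectly monotonic pair scores 1
```

Which method is appropriate depends on the data, so it can be worth comparing both before drawing conclusions.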

We can also look at the correlations of all the features with respect to only one feature:

print(correlations["Year"])
Unnamed: 0        -0.254391
ID                -0.254391
Year               1.000000
IMDb              -0.021181
Rotten Tomatoes   -0.057137
Netflix            0.258533
Hulu               0.098009
Prime Video       -0.253377
Disney+           -0.046819
Runtime            0.081984
Name: Year, dtype: float64
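
A single column like this is often easier to interpret when it is ranked by strength. The sketch below rebuilds the Series from the values printed above (so that it runs on its own) and sorts it by absolute value after dropping the feature's correlation with itself; the variable names are my own:

```python
import pandas as pd

# Values copied from the printed output above
year_corr = pd.Series(
    {
        "Unnamed: 0": -0.254391,
        "ID": -0.254391,
        "Year": 1.000000,
        "IMDb": -0.021181,
        "Rotten Tomatoes": -0.057137,
        "Netflix": 0.258533,
        "Hulu": 0.098009,
        "Prime Video": -0.253377,
        "Disney+": -0.046819,
        "Runtime": 0.081984,
    },
    name="Year",
)

# Drop the trivial self-correlation, then rank by absolute strength
ranked = year_corr.drop("Year").sort_values(key=abs, ascending=False)
print(ranked)
```

With the real data, the same line works directly on correlations["Year"].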

This is easier to read, and it is sufficient if you only need the correlations for one variable. But if you want to look at the correlations between all the features in the dataset, visualizing them is a better choice than reading the numeric values. We can use Seaborn’s heatmap function to visualize the correlations:

import seaborn as sns
import matplotlib.pyplot as plt
sns.heatmap(correlations)
plt.show()
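
The default heatmap can be made easier to read with a few optional arguments. This is only a sketch: the function name plot_correlation_heatmap, the title, the color map, and the random toy data are my own choices, while annot, fmt, cmap, vmin, vmax, and square are existing parameters of sns.heatmap:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def plot_correlation_heatmap(corr):
    """Draw an annotated heatmap of a correlation matrix and return the Axes."""
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        corr,
        annot=True,        # write the coefficient inside each cell
        fmt=".2f",         # two decimal places
        cmap="coolwarm",   # diverging colors for negative vs positive values
        vmin=-1, vmax=1,   # fix the color scale to the full correlation range
        square=True,
        ax=ax,
    )
    ax.set_title("Feature correlations")
    fig.tight_layout()
    return ax


# Hypothetical toy data: three columns of random numbers
rng = np.random.default_rng(0)
toy = pd.DataFrame(rng.normal(size=(50, 3)), columns=["a", "b", "c"])
toy_corr = toy.corr()

ax = plot_correlation_heatmap(toy_corr)
```

Call plt.show() afterwards to display the figure. Fixing vmin and vmax keeps the colors comparable across different datasets.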

I hope you liked this article on how to find and visualize correlations between the features of a dataset using Python. Feel free to ask your valuable questions in the comments section below.

Aman Kharwal

I'm a writer and data scientist on a mission to educate others about the incredible power of data📈.