In data science, a single dataset can tell many stories. A good way to start exploring it is to look for correlations between its features. In this article, I will take you through how to find correlations using Python.
What is Correlation?
Correlation is a way to measure whether, and how strongly, two features in a dataset are related. Correlations have many applications in the real world. We can see if the usage of certain search terms correlates with YouTube views. Or, we can see if ads correlate with sales.
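To make the idea concrete, the most widely used measure is the Pearson correlation coefficient, which ranges from -1 (perfect negative linear relationship) to 1 (perfect positive linear relationship). The sketch below computes it by hand and compares the result with NumPy's np.corrcoef. The ad_spend and sales numbers are invented purely for illustration and do not come from any real dataset:

```python
import numpy as np

# Hypothetical toy data, invented for illustration only.
ad_spend = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
sales = np.array([12.0, 25.0, 31.0, 48.0, 52.0])


def pearson_r(x, y):
    """Pearson correlation: covariance scaled by both standard deviations."""
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    numerator = (x_centered * y_centered).sum()
    denominator = np.sqrt((x_centered ** 2).sum() * (y_centered ** 2).sum())
    return numerator / denominator


r_manual = pearson_r(ad_spend, sales)
r_numpy = np.corrcoef(ad_spend, sales)[0, 1]  # off-diagonal entry of the 2x2 matrix
print(r_manual, r_numpy)
```

Both values should agree up to floating-point error. Keep in mind that a high coefficient only signals a linear association; it says nothing about which feature, if either, causes the other.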
When building machine learning models, correlations are an important factor in deciding which features to use. Not only do they show us which features are linearly related, but when two features are highly correlated, we can remove one of them to avoid duplicating information.
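One possible way to act on this is sketched below. The drop_highly_correlated helper, the 0.9 threshold, and the toy columns are all assumptions made for illustration, not a standard pandas function or a recommended cutoff:

```python
import numpy as np
import pandas as pd


def drop_highly_correlated(df, threshold=0.9):
    """Drop one column from each pair whose absolute correlation exceeds threshold.

    Hypothetical helper for illustration; the threshold is an arbitrary choice.
    """
    corr = df.select_dtypes(include="number").corr().abs()
    # Keep only the strict upper triangle so each pair is inspected once
    # and the diagonal (always 1.0) is ignored.
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    to_drop = [col for col in upper.columns if (upper[col] > threshold).any()]
    return df.drop(columns=to_drop), to_drop


# Invented toy data: "b" is an exact multiple of "a".
toy = pd.DataFrame({
    "a": [1.0, 2.0, 3.0, 4.0, 5.0],
    "b": [2.0, 4.0, 6.0, 8.0, 10.0],
    "c": [5.0, 3.0, 4.0, 1.0, 2.0],
})
reduced, dropped = drop_highly_correlated(toy)
print(dropped)
print(list(reduced.columns))
```

This version always keeps the first column of a correlated pair and drops the later one. Which of the two you keep in practice is a judgment call that depends on the data.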
Finding Correlation using Python
In this section, I will take you through how to find the correlation between the features by using Python. Here, I will use three Python libraries: Pandas, matplotlib, and Seaborn. Now let’s start the task by importing the dataset and doing some data processing:
import pandas as pd

movies = pd.read_csv("MoviesOnStreamingPlatforms_updated.csv")
movies.head()
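The dataset is based on the movies from streaming platforms. The “Rotten Tomatoes” column stores its scores as text with a “%” sign, so I will strip the sign and convert the column to float. There is also a “Type” column which seems like it was not entered carefully, so I will drop this column: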
# Strip the "%" sign and convert the scores to floats
movies["Rotten Tomatoes"] = movies["Rotten Tomatoes"].str.replace("%", "").astype(float)
# Drop the unreliable "Type" column
movies.drop("Type", inplace=True, axis=1)
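The conversion above assumes every non-missing score is a clean string ending in "%". If some entries might be malformed, one hedged alternative is pd.to_numeric with errors="coerce", which turns anything unparseable into NaN instead of raising. The sample values below are made up for illustration:

```python
import pandas as pd

# Invented sample values, including one missing entry.
scores = pd.Series(["87%", "45%", None, "100%"], name="Rotten Tomatoes")

# Remove a trailing "%" and convert; values that cannot be parsed become NaN.
cleaned = pd.to_numeric(scores.str.rstrip("%"), errors="coerce")
print(cleaned.tolist())
```

Missing scores stay missing after the conversion, and corr() leaves them out when computing each pairwise correlation.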
Using the pandas DataFrame.corr() method, we can examine the correlations between all the numeric columns in the DataFrame. Since this is a method, all we have to do is call it on the DataFrame. The return value is a new DataFrame that shows the correlation between each pair of features:
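# numeric_only=True (available from pandas 1.5) skips non-numeric columns,
# which pandas 2.0 and later would otherwise raise an error on
correlations = movies.corr(numeric_only=True)
correlations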
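To see what the returned DataFrame looks like without the movies file at hand, here is a small sketch on invented data. The column names mimic the dataset but the values are made up. It also shows the method parameter of corr(), which switches from the default Pearson coefficient to the rank-based Spearman coefficient:

```python
import pandas as pd

# Invented toy data; only the column names resemble the movies dataset.
toy = pd.DataFrame({
    "Year": [2000, 2005, 2010, 2015, 2020],
    "IMDb": [7.1, 6.8, 7.4, 6.5, 7.0],
    "Title": ["A", "B", "C", "D", "E"],
})

# Select the numeric columns explicitly so the text column is ignored.
numeric = toy.select_dtypes(include="number")

pearson = numeric.corr()                     # default: method="pearson"
spearman = numeric.corr(method="spearman")   # rank-based alternative
print(pearson)
print(spearman)
```

The result is a symmetric matrix with ones on the diagonal, because every feature is perfectly correlated with itself.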
We can also look at the correlations of all the features with respect to only one feature, for example the “Year” column:
correlations["Year"]
Unnamed: 0        -0.254391
ID                -0.254391
Year               1.000000
IMDb              -0.021181
Rotten Tomatoes   -0.057137
Netflix            0.258533
Hulu               0.098009
Prime Video       -0.253377
Disney+           -0.046819
Runtime            0.081984
Name: Year, dtype: float64
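The listing above is labeled "Name: Year", which is what pandas prints when a single column of the correlation matrix is selected. A minimal sketch of that selection, with the feature's correlation with itself removed and the rest ranked by strength, might look like this. The toy values are invented and are not the movies data:

```python
import pandas as pd

# Invented toy data for illustration.
toy = pd.DataFrame({
    "Year": [2000, 2005, 2010, 2015, 2020],
    "Runtime": [90, 95, 100, 110, 120],
    "IMDb": [7.1, 6.8, 7.4, 6.5, 7.0],
})
correlations = toy.corr()

# Take one column of the matrix and drop the trivial self-correlation (always 1.0).
year_corr = correlations["Year"].drop("Year")

# Order by absolute value so strong negative correlations rank as high as positive ones.
ranked = year_corr.reindex(year_corr.abs().sort_values(ascending=False).index)
print(ranked)
```

Sorting by absolute value matters here because a correlation of -0.25 is just as informative as one of 0.25.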
It’s a bit easier to read, and sufficient, if you only need the correlations for one variable. But if you want to look at the correlations between all the features in the dataset, visualizing them is a better choice than reading the numeric values. We can use Seaborn’s heatmap function to visualize the correlations:
import seaborn as sns
import matplotlib.pyplot as plt

sns.heatmap(correlations)
plt.show()
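The default heatmap can be hard to read because the color scale is not centered on zero and the cells are unlabeled. The sketch below shows a few optional sns.heatmap arguments that may help: annot to print the values, vmin and vmax to fix the scale to the range of a correlation, and mask to hide the redundant upper triangle. The toy data and the styling choices are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display; remove in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Invented toy data for illustration.
toy = pd.DataFrame({
    "Year": [2000, 2005, 2010, 2015, 2020],
    "Runtime": [90, 95, 100, 110, 120],
    "IMDb": [7.1, 6.8, 7.4, 6.5, 7.0],
})
correlations = toy.corr()

# The matrix is symmetric, so hide the cells above the diagonal.
mask = np.triu(np.ones(correlations.shape, dtype=bool), k=1)

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    correlations,
    mask=mask,
    annot=True,        # write the coefficient inside each cell
    fmt=".2f",
    cmap="coolwarm",   # diverging palette: one color per sign
    vmin=-1, vmax=1,   # correlations always lie in [-1, 1]
    center=0,
    ax=ax,
)
ax.set_title("Feature correlations (toy data)")
fig.tight_layout()
# In an interactive session, call plt.show() here to display the figure.
```

Fixing the scale to [-1, 1] keeps the colors comparable between different datasets, since the same shade always stands for the same strength of correlation.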
I hope you liked this article on how to find and visualize correlations between the features of a dataset using Python. Feel free to ask your valuable questions in the comments section below.