Cross-Validation in Machine Learning

In cross-validation, we run our modelling process on different subsets of the data to get several measures of model quality. For example, we could start by dividing the data into 5 parts (called folds), each 20% of the full dataset. Each fold then takes a turn as the validation set while the model is trained on the other four.
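To make the five-part split concrete, here is a minimal sketch in plain Python. The helper name k_fold_indices and the 10-row example are assumptions made for illustration only; scikit-learn provides its own splitters that would normally be used instead.

```python
# Illustrative sketch: split row indices into k contiguous folds.
def k_fold_indices(n_samples, k=5):
    """Yield (train_indices, validation_indices) for each of k folds."""
    # Spread any remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0)
                  for i in range(k)]
    indices = list(range(n_samples))
    start = 0
    for size in fold_sizes:
        validation = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, validation
        start += size


# Hypothetical tiny dataset of 10 rows, split into 5 folds of 2 rows each.
folds = list(k_fold_indices(n_samples=10, k=5))
for fold_number, (train, validation) in enumerate(folds, start=1):
    print(f"Fold {fold_number}: validate on {validation}, "
          f"train on {len(train)} rows")
```

Every row lands in exactly one validation fold, so each observation is used for validation once and for training in the remaining rounds.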

Cross-validation gives a more accurate measure of model quality than a single train/validation split, which is especially important if you make a lot of decisions based on your machine learning model. In this article, I will introduce you to the concept of cross-validation in machine learning.

When to Use Cross-Validation?

So now the main question arises: when should you use cross-validation in your machine learning task? It comes down to a trade-off between a more reliable estimate and extra computation:

  • For small datasets, where the additional computational load isn’t a big deal, you should run cross-validation.
  • For large datasets, a single validation set is usually sufficient. Your code will generally run faster, and you may have enough data that there is little need to reuse some of it for holdout.

There are no simple criteria for what constitutes a large or small data set. But if your model takes a few minutes or less to run, it’s probably worth switching to cross-validation.

You can also run cross-validation and see if the scores for each experiment look close. If each experiment gives roughly the same score, a single validation set is probably sufficient.

Cross-Validation in Action

Let’s see how we can use cross-validation in machine learning with the cross_val_score function provided by scikit-learn. I will start by loading the dataset, which is read from the file melb_data.csv:

import pandas as pd

# Read the data
data = pd.read_csv('melb_data.csv')

# Select subset of predictors
cols_to_use = ['Rooms', 'Distance', 'Landsize', 'BuildingArea', 'YearBuilt']
X = data[cols_to_use]

# Select target
y = data.Price

Now we need to define a pipeline that uses an imputer to fill in missing values and a random forest model to make predictions. While it is possible to do cross-validation without pipelines, it is quite difficult! Using a pipeline makes the code remarkably simple:

from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer

my_pipeline = Pipeline(steps=[('preprocessor', SimpleImputer()),
                              ('model', RandomForestRegressor(n_estimators=50,
                                                              random_state=0))
                             ])
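For comparison, here is a rough sketch of what the same procedure might look like without a pipeline. It runs on synthetic stand-in data (X_demo and y_demo are assumptions for illustration, not the housing data) and uses KFold directly. The point to notice is that the imputer has to be refit on the training rows of every fold, which is the bookkeeping a pipeline takes care of.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold

# Synthetic stand-in data (an assumption for illustration only).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))
y_demo = X_demo @ np.array([1.0, 2.0, 3.0]) + rng.normal(scale=0.1, size=100)
X_demo[rng.random(X_demo.shape) < 0.1] = np.nan  # blank out roughly 10% of values

manual_scores = []
for train_idx, valid_idx in KFold(n_splits=5).split(X_demo):
    # Fit the imputer on the training fold only, then apply it to both parts,
    # so nothing from the validation rows influences the preprocessing.
    imputer = SimpleImputer()
    X_train = imputer.fit_transform(X_demo[train_idx])
    X_valid = imputer.transform(X_demo[valid_idx])

    model = RandomForestRegressor(n_estimators=50, random_state=0)
    model.fit(X_train, y_demo[train_idx])
    predictions = model.predict(X_valid)
    manual_scores.append(mean_absolute_error(y_demo[valid_idx], predictions))

print("Manual MAE per fold:", manual_scores)
```

Every extra preprocessing step would add more lines inside the loop, whereas with a pipeline the whole sequence is handled as a single estimator.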

We get the cross-validation scores with the cross_val_score() function of scikit-learn. We define the number of folds with the cv parameter:

from sklearn.model_selection import cross_val_score

# Multiply by -1 since sklearn calculates *negative* MAE
scores = -1 * cross_val_score(my_pipeline, X, y,
                              cv=5,
                              scoring='neg_mean_absolute_error')

print("MAE scores:\n", scores)
MAE scores:
 [301628.7893587  303164.4782723  287298.331666   236061.84754543
 260383.45111427]

Now we generally want a single measure of model quality to compare alternative models. We, therefore, take the average of the experiments:

print("Average MAE score (across experiments):")
print(scores.mean())
Average MAE score (across experiments):
277707.3795913405
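The average hides how much the individual experiments differ, which is what the earlier advice about comparing scores across experiments depends on. As a small sketch, the spread can be summarised with Python's statistics module; the values are copied from the output above, and the variable names are chosen here purely for illustration.

```python
from statistics import mean, stdev

# MAE scores copied from the cross_val_score output above.
mae_scores = [301628.7893587, 303164.4782723, 287298.331666,
              236061.84754543, 260383.45111427]

average = mean(mae_scores)
spread = stdev(mae_scores)  # sample standard deviation across folds
score_range = max(mae_scores) - min(mae_scores)

print(f"Mean MAE: {average:,.0f}")
print(f"Std dev:  {spread:,.0f}")
print(f"Range:    {score_range:,.0f}")
```

Here the best and worst experiments differ by roughly 67,000, so the estimate from any single split would depend noticeably on which rows happened to be held out.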

Cross-validation gives a much better measure of model quality, with the added benefit of cleaning up our code: note that we no longer need to keep track of separate training and validation sets. So, especially for small datasets, this is a good improvement.

I hope you liked this article on the concept of Cross Validation in Machine Learning. Feel free to ask your valuable questions in the comments section below. You can also follow me on Medium to learn every topic of Machine Learning.

Aman Kharwal