Binary Classification Model

Binary classification is a type of classification in which the model predicts one of two class labels. For example, an email spam detection model has two class labels: spam and not spam. Most of the time, one label in a binary classification task represents a normal state and the other an abnormal state. In this article, I will take you through Binary Classification in Machine Learning using Python.

MNIST Dataset

I will be using the MNIST dataset, which is a set of 70,000 small images of digits handwritten by high school students and employees of the US Census Bureau. Each image is labeled with the digit it represents. This set has been studied so much that it is often called the “hello world” of Machine Learning.

Whenever people come up with a new classification algorithm, they are curious to see how it will perform on MNIST, and anyone who learns Machine Learning tackles this dataset sooner or later. So let’s import some libraries to start with our Binary Classification model:

# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"

# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

Scikit-Learn provides many helper functions to download popular datasets. MNIST is one of them. The following code fetches the MNIST dataset:

from sklearn.datasets import fetch_openml
# as_frame=False returns NumPy arrays, so rows can be indexed as X[0]
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()

dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'DESCR', 'details', 'categories', 'url'])

Now let’s look at the data:

X, y = mnist["data"], mnist["target"]
X.shape

(70000, 784)

y.shape

(70000,)

28 * 28

784

There are 70,000 images, and each image has 784 features. This is because each image is 28×28 pixels, and each feature simply represents one pixel’s intensity, from 0 (white) to 255 (black). Let’s take a peek at one digit from the dataset. All you need to do is grab an instance’s feature vector, reshape it to a 28×28 array, and display it using Matplotlib’s imshow() function:

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)

# Draw the 28×28 array; the "binary" colormap maps 0 to white and 255 to black
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

This looks like a 5, and indeed that’s what the label tells us:

y[0]

'5'

Note that the label is a string. Most Machine Learning Algorithms expect numbers, so let’s cast y to integer:

y = y.astype(np.uint8)
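
The cast matters for more than tidiness. As a minimal sketch, using a small hand-made label array (an assumption standing in for the real `y`), comparing string labels against the integer 5 never matches, while the cast labels compare as expected:

```python
import numpy as np

# Hypothetical stand-in for the string labels returned by fetch_openml
labels_str = np.array(["5", "0", "4", "1", "5"], dtype=object)

# Comparing a string with an integer is always False
mask_before = np.array([label == 5 for label in labels_str])

# After the cast, the comparison is an ordinary numeric test
labels_int = labels_str.astype(np.uint8)
mask_after = (labels_int == 5)

print(mask_before)
print(mask_after)
```

Without the cast, a target vector built with `== 5` would contain no positive examples at all.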

Now, before training a Binary Classification model, let’s have a look at the digits:

def plot_digit(data):
    # Show a single digit; the "binary" colormap is assumed (0 = white, 255 = black)
    image = data.reshape(28, 28)
    plt.imshow(image, cmap="binary")
    plt.axis("off")

def plot_digits(instances, images_per_row=10, **options):
    # Tile the digits into one large image, row by row
    size = 28
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(size, size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1
    row_images = []
    # Pad the last row with a blank strip when it is not full
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((size, size * n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row : (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap="binary", **options)
    plt.axis("off")

example_images = X[:100]
plot_digits(example_images, images_per_row=10)
plt.show()

You should always create a test set and set it aside before inspecting the data closely. The MNIST dataset is actually already split into a training set and a test set:

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

Training a Binary Classification Model

Let’s simplify the problem for now and only try to identify one digit, for example the number 5. This “5 detector” will be an example of a binary classifier, capable of distinguishing between just two classes, 5 and not 5. Let’s create the target vectors for the classification task:

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
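
Because these targets are boolean, their mean is the fraction of positive examples, which is worth checking before training. The sketch below uses a small synthetic label array (an assumption, not the MNIST labels) to illustrate the idea:

```python
import numpy as np

# Synthetic digit labels standing in for y_train (illustrative only)
rng = np.random.default_rng(42)
y_demo = rng.integers(0, 10, size=10_000).astype(np.uint8)

# Same construction as y_train_5: True where the digit is 5
y_demo_5 = (y_demo == 5)

# The mean of a boolean array is the share of True values
positive_fraction = y_demo_5.mean()
print(f"{positive_fraction:.3f}")  # close to 0.1 for uniformly drawn digits
```

On the real data, `y_train_5.mean()` would give the corresponding share of 5s in the training set.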

Now let’s pick a classification model and train it. A good place to start is a Stochastic Gradient Descent (SGD) classifier, using Scikit-Learn’s SGDClassifier class. SGD deals with training instances independently, one at a time, as we will see later in our binary classification model. Let’s build a binary classifier using the SGDClassifier and train it on the whole training set:

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)

SGDClassifier(alpha=0.0001, average=False, class_weight=None, early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=1000, n_iter_no_change=5, n_jobs=None, penalty='l2', power_t=0.5, random_state=42, shuffle=True, tol=0.001, validation_fraction=0.1, verbose=0, warm_start=False)

sgd_clf.predict([some_digit])

array([ True])

The classifier guesses that this image represents a 5 (True). Looks like it guessed right in this particular case. Now let’s evaluate the performance of our binary classification model.
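
As an aside, the "one instance at a time" behaviour of SGD can be made concrete with a minimal NumPy sketch of a linear classifier trained on the hinge loss (the loss shown in the SGDClassifier output above). This is not Scikit-Learn's implementation: the function name `sgd_hinge_fit`, the fixed learning rate, and the toy data are all assumptions made for illustration.

```python
import numpy as np

def sgd_hinge_fit(X, y, n_epochs=20, eta=0.01, alpha=0.0001, seed=42):
    """Train a linear classifier with plain SGD on the L2-regularized hinge loss.

    y must contain +1 / -1 labels. Hypothetical helper, not part of Scikit-Learn.
    """
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(n_epochs):
        for i in rng.permutation(len(X)):   # visit instances in random order
            margin = y[i] * (X[i] @ w + b)  # computed from a single instance
            w *= (1 - eta * alpha)          # gradient step for the L2 penalty
            if margin < 1:                  # instance violates the margin
                w += eta * y[i] * X[i]
                b += eta * y[i]
    return w, b

# Toy, well-separated data: two Gaussian blobs
rng = np.random.default_rng(0)
X_pos = rng.normal(loc=2.0, scale=0.5, size=(100, 2))
X_neg = rng.normal(loc=-2.0, scale=0.5, size=(100, 2))
X_toy = np.vstack([X_pos, X_neg])
y_toy = np.array([1] * 100 + [-1] * 100)

w, b = sgd_hinge_fit(X_toy, y_toy)
y_pred = np.where(X_toy @ w + b >= 0, 1, -1)
accuracy = (y_pred == y_toy).mean()
print(accuracy)
```

Each update touches only one training instance, which is why this style of training scales to datasets too large to process in one batch.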

Performance Measures

Evaluating a classifier is often trickier than evaluating a regressor, so we will spend a good part of this article evaluating our binary classification model.

Implementing Cross-Validation on Binary Classification Model

Occasionally you will need more control over the cross-validation process than what Scikit-Learn provides off the shelf. In these cases, you can implement cross-validation yourself. The following code does roughly the same thing as Scikit-Learn’s cross_val_score() function, and it prints the same result:

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

# No shuffling, so the folds match the ones cross_val_score() builds by default
# (passing random_state without shuffle=True raises an error in recent versions)
skfolds = StratifiedKFold(n_splits=3)

for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)  # untrained copy with the same hyperparameters
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]

    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))


The StratifiedKFold class performs stratified sampling to produce folds that contain a representative ratio of each class. At each iteration, the code creates a clone of the classification model, trains that clone on the training folds, and makes predictions on the test fold. Then it counts the number of correct predictions and outputs the ratio of correct predictions.
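
To illustrate what "a representative ratio of each class" means, here is a rough sketch of stratified fold assignment. It is a simplification, not StratifiedKFold's actual algorithm: the helper name `stratified_fold_indices` and the round-robin dealing strategy are assumptions chosen for clarity.

```python
import numpy as np

def stratified_fold_indices(y, n_splits=3):
    """Deal the indices of each class round-robin into n_splits folds.

    Hypothetical helper for illustration; returns a list of index arrays.
    """
    folds = [[] for _ in range(n_splits)]
    for cls in np.unique(y):
        cls_indices = np.flatnonzero(y == cls)
        for position, index in enumerate(cls_indices):
            folds[position % n_splits].append(index)
    return [np.sort(np.array(fold)) for fold in folds]

# Skewed toy labels: 10% positives, similar in spirit to "5 vs. not 5"
y_toy = np.array([True] * 30 + [False] * 270)

folds = stratified_fold_indices(y_toy, n_splits=3)
for fold in folds:
    # Every fold keeps the overall 10% share of positives
    print(len(fold), y_toy[fold].mean())
```

With plain (unstratified) splitting of a skewed dataset, a fold could end up with noticeably more or fewer positives than the dataset as a whole, which would make the per-fold scores harder to compare.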

Let’s use the cross_val_score() function to evaluate our SGDClassifier model, using K-fold cross-validation with three folds. Remember that K-fold cross-validation means splitting the training set into K folds, then making predictions and evaluating them on each fold using a model trained on the remaining folds:

from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")

array([0.95035, 0.96035, 0.9604 ])

Wow! Above 95% accuracy on all cross-validation folds. Well, before you get too excited, let’s look at a very dumb classifier that just classifies every single image in the “not 5” class:

from sklearn.base import BaseEstimator

class Never5Classifier(BaseEstimator):
    def fit(self, X, y=None):
        return self  # nothing to learn
    def predict(self, X):
        # Always predict "not 5"
        return np.zeros(len(X), dtype=bool)

never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")

array([0.91125, 0.90855, 0.90915])


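That’s right, it has over 90% accuracy. This is simply because only about 10% of the images are 5s, so if you always guess that an image is not a 5, you will be right about 90% of the time. This is why accuracy alone is generally a poor performance measure for classifiers, especially on skewed datasets where some classes are much more frequent than others.

I hope you liked this article on the Binary Classification Model in Machine Learning. Feel free to ask your valuable questions in the comments section below.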

Aman Kharwal