MNIST digits classification is one of the most popular case studies in the data science community. It is a classic classification problem in machine learning, and if you are a beginner in machine learning, it is well worth trying to solve. In this article, I will take you through the task of MNIST digits classification with machine learning using Python.
MNIST Digits Classification with Machine Learning
The MNIST digits dataset is a large collection of handwritten digits, each stored as a 28 × 28 grayscale image. It is widely used for training and benchmarking image processing systems, and it is a popular dataset in universities for explaining the concepts of classification in machine learning. It contains 60,000 training images and 10,000 test images, for a total of 70,000. You can download this dataset from various sources on the internet, but scikit-learn can fetch it for us from OpenML with its fetch_openml function, so that is what we will use here.
In the section below, I will take you through the task of MNIST digits classification with machine learning using the Python programming language.
MNIST Digits Classification using Python
Let’s start this task by importing the necessary Python libraries and the dataset:
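```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml

# Downloads mnist_784 from OpenML on first use (cached locally afterwards);
# as_frame=True returns pandas objects, which the .iloc calls below rely on
data = fetch_openml("mnist_784", version=1, as_frame=True)
print(data)
```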
Let’s have a look at the shape of the data before moving forward:
```python
# x holds the pixel values, y holds the digit labels (stored as strings such as '5')
x, y = data["data"], data["target"]
print(x.shape)
```
(70000, 784)
The dataset contains 70,000 rows and 784 columns: each row is one image, and each column holds one of its 28 × 28 = 784 grayscale pixel values. Before moving forward, let's split the data into a training set and a test set:
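```python
# Hold out 20% of the 70,000 images (14,000) for testing;
# random_state makes the split reproducible
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)
```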
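Note that this random 80/20 split gives 56,000 training and 14,000 test images rather than the conventional 60,000/10,000 split of MNIST. If you want results comparable with published MNIST benchmarks, one option is to slice the rows instead of shuffling them. The sketch below assumes, as is commonly done with the OpenML mnist_784 copy, that the first 60,000 rows are the original training images; the helper name standard_mnist_split is hypothetical, and it is demonstrated on a synthetic all-zero array so it runs without downloading anything:

```python
import numpy as np

def standard_mnist_split(x, y, n_train=60_000):
    """Hypothetical helper: conventional MNIST train/test split by row position.

    Assumes the first n_train rows are the original training images and the
    remaining rows are the original test images. Works on NumPy arrays and on
    pandas objects with a default integer index ([:n] slices rows by position).
    """
    xtrain, xtest = x[:n_train], x[n_train:]
    ytrain, ytest = y[:n_train], y[n_train:]
    return xtrain, xtest, ytrain, ytest

# Synthetic stand-in with the same shape as mnist_784 (zero pixels, fake labels)
x_demo = np.zeros((70_000, 784), dtype=np.uint8)
y_demo = np.arange(70_000) % 10

xtrain_s, xtest_s, ytrain_s, ytest_s = standard_mnist_split(x_demo, y_demo)
print(xtrain_s.shape, xtest_s.shape)  # (60000, 784) (10000, 784)
```

With the real data, you would call standard_mnist_split(x, y) in place of train_test_split; everything after the split stays the same.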
Let's look at a sample of the images in the dataset:
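```python
# Each row is a flat vector of 784 pixels; reshape it back into a 28x28 grid
image = np.array(xtrain.iloc[0]).reshape(28, 28)
plt.imshow(image, cmap="binary")
plt.show()
```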
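The reshape(28, 28) call works because each flat row stores the image one pixel row after another. A small NumPy sketch on a synthetic vector (not real MNIST data) makes the index mapping explicit:

```python
import numpy as np

# Synthetic flat "image": pixel k simply holds the value k
flat = np.arange(784)

# Same call as in the article; NumPy's default order="C" fills the grid row by row
grid = flat.reshape(28, 28)

# Pixel k therefore lands at row k // 28, column k % 28
k = 100
row, col = divmod(k, 28)
print(row, col, grid[row, col])  # 3 16 100

# ravel() undoes the reshape, giving back the original 784-long vector
print(np.array_equal(grid.ravel(), flat))  # True
```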

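So the first handwritten image in the training set shows the digit 5 (you can confirm this with ytrain.iloc[0]). Now let's train a classification model. Here, I will use scikit-learn's SGDClassifier, a linear classifier (a linear SVM with its default hinge loss) trained with stochastic gradient descent: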
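```python
from sklearn.linear_model import SGDClassifier

# Default settings: hinge loss, i.e. a linear SVM fitted with SGD;
# random_state fixes the data shuffling so results are reproducible
model = SGDClassifier(random_state=42)
model.fit(xtrain, ytrain)
```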
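For ten digit classes, SGDClassifier uses a one-vs-all strategy: it fits one binary linear classifier per digit and predicts the digit whose classifier gives the highest score. The sketch below checks this on scikit-learn's small bundled load_digits dataset (8 × 8 images, 64 features), used only as a stand-in so the example runs without downloading MNIST; on MNIST, model.coef_ would have shape (10, 784) instead:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier

# Stand-in data: scikit-learn's bundled 8x8 digits (no download needed),
# NOT the 28x28 MNIST images used in the article
X, y = load_digits(return_X_y=True)

model = SGDClassifier(random_state=42)
model.fit(X, y)

# One weight vector per digit class: one binary classifier per class
print(model.coef_.shape)  # (10, 64)

# Each classifier produces a score; the prediction is the highest-scoring class
scores = model.decision_function(X[:5])
print(scores.shape)  # (5, 10)
manual = model.classes_[np.argmax(scores, axis=1)]
print(np.array_equal(manual, model.predict(X[:5])))  # True
```

Looking at decision_function rather than only predict is also a quick way to see how confident the model is: a small gap between the top two scores means the digit was a close call.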
Now let’s test the trained model by making predictions on the test set:
```python
predictions = model.predict(xtest)
print(predictions)
```
['8' '4' '8' ... '2' '7' '1']
So the model predicts 8 for the first test image and 4 for the second. Let's look at those two handwritten images to check the predictions:
```python
# First test image (predicted: 8)
image = np.array(xtest.iloc[0]).reshape(28, 28)
plt.imshow(image, cmap="binary")
plt.show()
```

```python
# Second test image (predicted: 4)
image = np.array(xtest.iloc[1]).reshape(28, 28)
plt.imshow(image, cmap="binary")
plt.show()
```

So this is how you can solve the problem of MNIST digits classification with Machine Learning.
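Looking at two images is only a spot check. A more systematic evaluation compares every prediction with the true label, for example with accuracy_score and confusion_matrix from sklearn.metrics; with the article's variables, that is accuracy_score(ytest, predictions) and confusion_matrix(ytest, predictions). The sketch below runs the same steps on the bundled load_digits dataset as a stand-in, so its numbers say nothing about accuracy on MNIST itself:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Stand-in data: the bundled 8x8 digits instead of MNIST
X, y = load_digits(return_X_y=True)
xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2, random_state=42)

model = SGDClassifier(random_state=42)
model.fit(xtrain, ytrain)
predictions = model.predict(xtest)

# Fraction of test images whose predicted digit matches the true label
acc = accuracy_score(ytest, predictions)
print(f"accuracy: {acc:.3f}")

# Row i, column j counts test images of digit i predicted as digit j;
# large off-diagonal entries show which digits the model confuses
cm = confusion_matrix(ytest, predictions)
print(cm)
```

The diagonal of the confusion matrix holds the correct predictions, so its sum divided by the total number of test images equals the accuracy.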
Summary
This is how you can solve the problem of handwritten digits classification with machine learning using Python. MNIST is one of the most popular classification case studies in the data science community, which makes it a good first project for beginners. I hope you liked this article on handwritten digits classification with machine learning. Feel free to ask questions in the comments section below.