Iris Flower Classification is one of the most popular case studies in the data science community, and almost every data science beginner has worked through it at some point. You are given the measurements of iris flowers along with their species, and your task is to train a machine learning model that classifies iris flowers based on this data. So if you are new to machine learning and have never tried this case study, this article is for you. In this article, I will walk you through Iris Flower Classification with machine learning using Python.
Iris Flower Classification
The iris flower has three species: setosa, versicolor, and virginica, which differ in their measurements. Now assume that you have the measurements of iris flowers labelled by species. Your task is to train a machine learning model that learns from these measurements and classifies a flower into one of the three species.
I hope you now understand the case study of iris flower classification. Although the Scikit-learn library provides a dataset for iris flower classification, you can also download the same dataset from here. In the section below, I will take you through how we can classify the iris flower species with machine learning using the Python programming language.
Iris Flower Classification using Python
I will start the task of Iris flower classification by importing the necessary Python libraries and the dataset that we need for this task:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

iris = pd.read_csv("IRIS.csv")
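If you do not have the CSV file at hand, the copy of the dataset bundled with Scikit-learn can be shaped into the same layout. This is only a minimal sketch and not the code used in this article: the column names and the "Iris-" prefix are assumptions chosen to mirror IRIS.csv, and a few values in Scikit-learn's copy may differ slightly from the CSV, so the summary statistics may not match to the last digit.

```python
import pandas as pd
from sklearn.datasets import load_iris

# Load the copy of the iris data that ships with scikit-learn
bunch = load_iris()

# Column names are chosen here to mirror IRIS.csv (an assumption, not part of scikit-learn)
columns = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
iris = pd.DataFrame(bunch.data, columns=columns)

# scikit-learn stores the species as integer codes 0, 1, 2;
# map them to the "Iris-<name>" strings used in the CSV file
iris["species"] = [f"Iris-{bunch.target_names[code]}" for code in bunch.target]

print(iris.shape)
```

With this frame in place of the one read from IRIS.csv, the rest of the article can be followed unchanged.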
Now let’s have a look at the first five rows of this dataset:
print(iris.head())
   sepal_length  sepal_width  petal_length  petal_width      species
0           5.1          3.5           1.4          0.2  Iris-setosa
1           4.9          3.0           1.4          0.2  Iris-setosa
2           4.7          3.2           1.3          0.2  Iris-setosa
3           4.6          3.1           1.5          0.2  Iris-setosa
4           5.0          3.6           1.4          0.2  Iris-setosa
Now let’s have a look at the descriptive statistics of this dataset:
print(iris.describe())
       sepal_length  sepal_width  petal_length  petal_width
count    150.000000   150.000000    150.000000   150.000000
mean       5.843333     3.054000      3.758667     1.198667
std        0.828066     0.433594      1.764420     0.763161
min        4.300000     2.000000      1.000000     0.100000
25%        5.100000     2.800000      1.600000     0.300000
50%        5.800000     3.000000      4.350000     1.300000
75%        6.400000     3.300000      5.100000     1.800000
max        7.900000     4.400000      6.900000     2.500000
The target labels of this dataset are present in the species column. Let’s have a quick look at the target labels:
print("Target Labels", iris["species"].unique())
Target Labels ['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']
Now let’s visualize the data using a scatter plot of sepal length against sepal width, with each point coloured by its species:
import plotly.express as px

fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species")
fig.show()

Iris Classification Model
Now let’s train a machine learning model for the task of classifying iris species. Here, I will first split the data into training and test sets, and then I will use the k-nearest neighbors (KNN) classification algorithm to train the iris classification model.
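The split and training steps described above can be sketched as follows. This is a minimal sketch under stated assumptions rather than a definitive implementation: the 80/20 split, random_state=0, the stratified split, and n_neighbors=1 are all illustrative choices that you may want to tune. To keep the sketch runnable without IRIS.csv, the features and labels are built from Scikit-learn's bundled copy of the data; the comments show the equivalent lines for the iris DataFrame.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# With the `iris` DataFrame read from IRIS.csv you would write:
#   x = iris.drop("species", axis=1).to_numpy()
#   y = iris["species"].to_numpy()
# Here the same kind of arrays are built from scikit-learn's bundled copy instead.
bunch = load_iris()
x = bunch.data
y = np.array([f"Iris-{bunch.target_names[code]}" for code in bunch.target])

# Hold out 20% of the rows for testing (assumed ratio); stratify keeps
# the three species equally represented in both parts
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=0, stratify=y
)

# Fit a k-nearest neighbors classifier; k = 1 is an assumed value
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(x_train, y_train)

# Accuracy on the held-out rows gives a rough idea of how well the model generalizes
print("Test accuracy: {:.2f}".format(knn.score(x_test, y_test)))
```

The fitted model is stored in knn, which is the name used for prediction in the rest of this section.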
Now let’s input a set of measurements of the iris flower and use the model to predict the iris species:
x_new = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(x_new)
print("Prediction: {}".format(prediction))
Prediction: ['Iris-setosa']
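To see why a nearest-neighbour model gives this answer, it can help to do the computation by hand. The sketch below is an illustration and not the trained model itself: it assumes k = 1 and Euclidean distance, and it uses only seven reference rows (the five setosa rows printed earlier plus the first versicolor and first virginica rows of the standard dataset) instead of the full training set.

```python
import numpy as np

# A tiny hand-picked reference set: five setosa rows from the head() output,
# plus one versicolor row and one virginica row
samples = np.array([
    [5.1, 3.5, 1.4, 0.2],
    [4.9, 3.0, 1.4, 0.2],
    [4.7, 3.2, 1.3, 0.2],
    [4.6, 3.1, 1.5, 0.2],
    [5.0, 3.6, 1.4, 0.2],
    [7.0, 3.2, 4.7, 1.4],
    [6.3, 3.3, 6.0, 2.5],
])
labels = np.array(["Iris-setosa"] * 5 + ["Iris-versicolor", "Iris-virginica"])

# The same measurements that were passed to the model
x_new = np.array([5, 2.9, 1, 0.2])

# Euclidean distance from the new flower to every reference row
distances = np.linalg.norm(samples - x_new, axis=1)

# With k = 1, the prediction is simply the label of the closest row
nearest = int(np.argmin(distances))
print("Nearest row:", nearest, "label:", labels[nearest])
```

The short petal length and petal width of the new flower place it far closer to the setosa rows than to the other two species, which is why the closest row is a setosa one.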
Summary
So this is how you can train a machine learning model for the task of Iris classification using Python. As one of the most popular case studies in the data science community, it is a common starting point for beginners. I hope you liked this article on the task of classifying Iris species with machine learning using Python. Feel free to ask your valuable questions in the comments section below.