K-Nearest Neighbors (KNN) with Python

Scikit-learn offers many classification algorithms that we can use to train a machine learning model. In this article, I will walk you through K-Nearest Neighbors (KNN) with the Python programming language.

Introduction to K-Nearest Neighbors (KNN)

Building a K-Nearest Neighbors (KNN) model consists only of storing the training set. To make a prediction for a new data point, the algorithm finds the point in the training set that is closest to the new point. Then it assigns the label of this training point to the new data point.
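To make the idea concrete, here is a minimal from-scratch sketch of the one-nearest-neighbor rule using NumPy. The helper name predict_1nn and the toy data are hypothetical, made up for illustration; the sketch assumes Euclidean distance, which is also what scikit-learn's KNeighborsClassifier uses by default.

```python
import numpy as np

def predict_1nn(X_train, y_train, x_new):
    """Hypothetical helper: return the label of the training point closest to x_new."""
    # "Training" is just keeping X_train and y_train around; all work happens here.
    # Euclidean distance from x_new to every stored training point
    distances = np.linalg.norm(X_train - x_new, axis=1)
    # Index of the single closest training point
    nearest = np.argmin(distances)
    return y_train[nearest]

# Made-up toy training set: two points of class 0, two points of class 1
X_toy = np.array([[1.0, 1.0], [1.5, 2.0], [5.0, 5.0], [6.0, 5.5]])
y_toy = np.array([0, 0, 1, 1])

print(predict_1nn(X_toy, y_toy, np.array([1.2, 1.4])))  # → 0
print(predict_1nn(X_toy, y_toy, np.array([5.5, 4.8])))  # → 1
```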


The k in k-nearest neighbors means that instead of using only the single neighbor closest to the new data point, we can consider any fixed number k of neighbors in the training set (for example, the three or five nearest neighbors). We then make a prediction using the majority class among these neighbors.
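The majority vote can be sketched in a few lines as well. This is an illustrative sketch, not scikit-learn's implementation: the helper predict_knn and the toy points are hypothetical, and the tie-breaking rule (the tied label met first among the sorted neighbors wins) is simply what Counter.most_common does, not necessarily what scikit-learn does.

```python
from collections import Counter

import numpy as np

def predict_knn(X_train, y_train, x_new, k=3):
    """Hypothetical helper: majority vote among the k training points nearest to x_new."""
    # Euclidean distance from x_new to every training point
    distances = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the k closest training points, nearest first
    nearest = np.argsort(distances)[:k]
    # Count the neighbors' labels; on a tie, Counter.most_common keeps
    # first-seen order, so the tied label met first (closer) wins
    votes = Counter(y_train[nearest].tolist())
    return votes.most_common(1)[0][0]

# Made-up toy data: one class-1 point sits very close to the query,
# but most of the other nearby points belong to class 0
X_toy = np.array([[2.5, 2.5], [2.0, 2.0], [2.0, 3.0], [0.0, 0.0], [5.0, 5.0]])
y_toy = np.array([1, 0, 0, 0, 1])
q = np.array([2.6, 2.6])

print(predict_knn(X_toy, y_toy, q, k=1))  # → 1 (single nearest neighbor)
print(predict_knn(X_toy, y_toy, q, k=3))  # → 0 (two of the three nearest are class 0)
```

The example shows why k matters: with k=1 the one very close class-1 point decides, while with k=3 the vote is outweighed by the surrounding class-0 points.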

All of scikit-learn’s machine learning models are implemented in their own classes, called estimator classes. The k-nearest neighbors (KNN) classification algorithm is implemented in the KNeighborsClassifier class in the neighbors module.
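As a quick illustration of the estimator idea, creating a KNeighborsClassifier only records its parameters; no data is involved until fit is called. This small sketch uses sklearn.base.is_classifier and the standard get_params method to inspect the object.

```python
from sklearn.base import is_classifier
from sklearn.neighbors import KNeighborsClassifier

# Instantiating the estimator only stores its parameters; nothing is learned yet
knn = KNeighborsClassifier()

print(is_classifier(knn))                            # → True
print(knn.get_params()["n_neighbors"])               # → 5 (the default number of neighbors)
print(hasattr(knn, "fit"), hasattr(knn, "predict"))  # → True True
```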

Machine Learning Tutorial on K-Nearest Neighbors (KNN) with Python

To implement the KNN algorithm, I will use the Iris dataset, a classic dataset in machine learning and statistics. It is included in the datasets module of scikit-learn, so we can load it by calling the load_iris function:

from sklearn.datasets import load_iris
iris_dataset = load_iris()
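Before splitting the data, it can help to look at what load_iris returns. This short sketch prints the array shape, the feature names, and the class names stored in the dataset object.

```python
from sklearn.datasets import load_iris

iris_dataset = load_iris()

# load_iris returns a Bunch: a dictionary-like object whose values can be
# accessed with keys such as 'data', 'target', 'feature_names' and 'target_names'
print(iris_dataset['data'].shape)     # → (150, 4): 150 flowers, 4 measurements each
print(iris_dataset['feature_names'])  # sepal/petal length and width, in cm
print(iris_dataset['target_names'])   # → ['setosa' 'versicolor' 'virginica']
print(iris_dataset['target'][:5])     # → [0 0 0 0 0]: species are encoded as 0, 1, 2
```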

The next step is to split the data into training and test sets. Scikit-learn provides the train_test_split function, which shuffles the dataset and splits it into two parts:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

By default, this function puts 75% of the rows of the data, along with their labels, into the training set. The remaining 25% of the data, together with its labels, forms the test set. How much data you put into the training and test sets is not fixed, but using 25% of the data for the test set is a reasonable rule of thumb.
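A quick way to confirm the split is to print the array shapes. The second call is an illustrative variation that sets the test fraction explicitly with the test_size parameter; 0.2 is an arbitrary choice for the example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

# 150 rows: 25% (rounded up) go to the test set, the rest to training
print(X_train.shape, X_test.shape)  # → (112, 4) (38, 4)

# The test fraction can be set explicitly, e.g. 20% of the data:
X_tr, X_te, y_tr, y_te = train_test_split(
    iris_dataset['data'], iris_dataset['target'], test_size=0.2, random_state=0)
print(X_tr.shape, X_te.shape)       # → (120, 4) (30, 4)
```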

Data Visualization

Before building a machine learning model, it is often a good idea to inspect the data: to see whether the task is easily solvable without machine learning, or whether the desired information might not be contained in the data at all. One of the best ways to inspect the data is to visualize it.

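One way to do this is with a scatter plot, which places one feature along the x-axis and another along the y-axis and draws a point for each data point. Because the Iris dataset has four features, we use a scatter matrix, which draws a scatter plot for every pair of features and a histogram of each feature on the diagonal: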

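import matplotlib.pyplot as plt
import mglearn  # helper package from "Introduction to Machine Learning with Python" (pip install mglearn)
import pandas as pd

iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# create a scatter matrix from the dataframe, color by y_train
pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15),
                           marker='o', hist_kwds={'bins': 20}, s=60,
                           alpha=.8, cmap=mglearn.cm3)
plt.show()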
Figure: scatter plot for the Iris dataset (scatter matrix of the training set, colored by class)

From the scatter plot above, we can see that the three classes appear relatively well separated using the sepal and petal measurements. This means a machine learning model will likely be able to learn to separate them.
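The separation can also be checked numerically. In this small sketch, the 2.5 cm threshold is an assumption picked by eye for illustration, not something from the original analysis: a single rule on petal length separates setosa from the other two species on the full dataset. Versicolor and virginica overlap in petal length, so they cannot be told apart by one such threshold, which is where a model like KNN helps.

```python
import numpy as np
from sklearn.datasets import load_iris

iris_dataset = load_iris()
X, y = iris_dataset['data'], iris_dataset['target']

# Column 2 is petal length (cm) in load_iris's feature order
petal_length = X[:, 2]

# Hand-picked threshold (illustrative assumption): short petals -> setosa (class 0)
is_setosa_rule = petal_length < 2.5
print(np.array_equal(is_setosa_rule, y == 0))  # → True
```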

K-Nearest Neighbors with Python

Now we can build the actual machine learning model, K-Nearest Neighbors. This is where we define the model parameters. The most important parameter of KNeighborsClassifier is the number of neighbors, n_neighbors, which we will set to 1:

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)
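Because n_neighbors matters so much, a quick sketch can compare a few values using the score method, which returns mean accuracy. The candidate values (1, 3, 5, 7) are arbitrary, and picking k by looking at test-set accuracy is done here only for illustration; for real tuning, a separate validation set or cross-validation would be the usual choice.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

# Fit one model per candidate k and record its accuracy on the test set
scores = {}
for k in (1, 3, 5, 7):
    model = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
    scores[k] = model.score(X_test, y_test)  # fraction of test flowers classified correctly

for k, acc in scores.items():
    print(f"n_neighbors={k}: test accuracy {acc:.2f}")
```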

To build the model on the training set, we call the fit method of the knn object, passing the training data and the training labels:

knn.fit(X_train, y_train)

The fit method modifies the knn object in place and returns the object itself, so in an interactive session such as a Jupyter notebook you see a string representation of the classifier. This representation shows the parameters used to create the model. Recent versions of scikit-learn show only the parameters that differ from their defaults, here n_neighbors=1, the parameter we passed; older versions list every parameter.
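This sketch checks that behavior directly: fit returns the very same object, and after fitting, the classifier exposes the learned classes_ attribute. The exact repr text depends on the scikit-learn version, so it is shown only as an example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

knn = KNeighborsClassifier(n_neighbors=1)
returned = knn.fit(X_train, y_train)

print(returned is knn)  # → True: fit returns the same, now fitted, object
print(repr(knn))        # e.g. KNeighborsClassifier(n_neighbors=1) in recent versions
print(knn.classes_)     # → [0 1 2]: the class labels seen during fit
```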

Making Predictions

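We can now use this model to make predictions on new data for which we may not know the correct labels. The new data must contain the same four measurements, in the same order, as the training data, stored in a two-dimensional array with one row per flower. To make predictions, we call the predict method of the knn object:

import numpy as np
X_new = np.array([[5, 2.9, 1, 0.2]])  # illustrative measurements (cm) of one new flower
prediction = knn.predict(X_new)
print("Prediction:", prediction)
print("Predicted target name:",
      iris_dataset['target_names'][prediction])

Prediction: [0]
Predicted target name: ['setosa']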
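To see why the model made this prediction, the kneighbors method of KNeighborsClassifier returns the distances to, and indices of, the nearest training points. This sketch repeats the setup so it runs on its own; the X_new values are illustrative measurements assumed for the example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)
knn = KNeighborsClassifier(n_neighbors=1).fit(X_train, y_train)

X_new = np.array([[5, 2.9, 1, 0.2]])  # illustrative measurements (assumed)

# Distances to, and indices of, the nearest training point(s) for each row of X_new
distances, indices = knn.kneighbors(X_new)
nearest_label = y_train[indices[0, 0]]

print(distances[0, 0], X_train[indices[0, 0]])      # closest training flower
print(iris_dataset['target_names'][nearest_label])  # → setosa
# With n_neighbors=1, the prediction is exactly that neighbor's label
print(knn.predict(X_new)[0] == nearest_label)       # → True
```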

Our K-Nearest Neighbors model predicts that this new iris belongs to class 0, meaning its species is setosa. I hope you liked this machine learning tutorial on K-Nearest Neighbors (KNN) with the Python programming language. Feel free to ask your questions in the comments section below.

Aman Kharwal

I'm a writer and data scientist on a mission to educate others about the incredible power of data📈.
