KNN classifier with scikit-learn

Train a KNN classifier with scikit-learn: Your first AI project

Artificial intelligence has been around for decades, but it has only recently become mainstream. From 2015 onwards, the development of convolutional networks, coupled with the increasing availability of data and the GPGPU movement, brought a renaissance to the field of AI research. However, it is the popularization of end-user applications like ChatGPT or Midjourney that has truly made AI mainstream.

Jumping right into transformer networks or the diffusion process can be a little bit daunting, but luckily for newcomers, there are simpler yet powerful techniques to start with.

One of the most basic AI algorithms is k-nearest neighbors (KNN). KNN is a simple but effective algorithm that can be used for a variety of tasks, including classification, regression, and anomaly detection. Furthermore, there is no need to implement KNN from scratch: Python libraries like scikit-learn provide solid implementations.

In this post, you will learn how to build a handwritten digit classifier in Python. You will use the scikit-learn library to train a KNN model and run inference with it, and learn some of the algorithm's inner workings. You will also learn about standard practices such as cross-validation and hyperparameter search. By the end of this tutorial, you should have a solid understanding of the KNN algorithm and be able to use it in your own machine learning projects.

How does the KNN algorithm work?

K Nearest Neighbors (KNN) is probably one of the dumbest machine learning algorithms. The machine learning models that fill the headlines these days take weeks to train and usually require datacenters with lots of compute cores and memory. In contrast, the most basic form of KNN does nothing at train time: it just stores the training data. Then, when we need to run the algorithm on a new sample, KNN identifies the training samples closest to the new data point and examines them. Depending on the information we extract from those closest samples, we can perform classification, regression, or outlier detection, but classification is by far the most popular application. For classification, the KNN algorithm simply looks up the labels of the k closest neighbors and returns the most common one.
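
To make the idea concrete, here is a minimal from-scratch sketch of the brute-force classifier described above, written with NumPy. The function name knn_predict and the toy data are hypothetical, for illustration only; scikit-learn's KNeighborsClassifier does the same job with many more options.

```python
import numpy as np
from collections import Counter


def knn_predict(train_X, train_y, x, k=3):
    """Classify one sample x by majority vote among its k nearest training samples."""
    # "Training" is just keeping train_X and train_y around; all the work happens here
    distances = np.linalg.norm(train_X - x, axis=1)  # Euclidean distance to every sample
    nearest = np.argsort(distances)[:k]               # indices of the k closest samples
    votes = Counter(train_y[nearest])                 # count the labels of those neighbors
    return votes.most_common(1)[0][0]                 # the most frequent label wins


# Hypothetical toy dataset: two well-separated clusters labelled 0 and 1
train_X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [5.0, 5.0], [5.1, 4.9], [4.8, 5.2]])
train_y = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict(train_X, train_y, np.array([0.3, 0.3])))  # → 0
print(knn_predict(train_X, train_y, np.array([4.5, 4.5])))  # → 1
```

Note that every prediction scans the whole training set, which is exactly the cost that the search trees discussed below try to reduce.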

If, instead, we want to use KNN for regression or outlier detection, we must be a little more clever. For regression, we can compute the distance, or similarity, between the new sample and each of its neighbors, and average the neighbors' target values using the similarity as the averaging weight. For outlier detection, we can set a distance threshold: if no training sample is close enough to the new sample, we can assume it is an outlier.
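
Both variants can be sketched in the same from-scratch style. This is a minimal illustration built on assumptions of our own: inverse distance is used as the similarity, the outlier threshold is an arbitrary fixed value, and the helper names knn_regress and is_outlier are hypothetical. For real projects, scikit-learn offers KNeighborsRegressor, whose weights='distance' option applies inverse-distance weighting.

```python
import numpy as np


def knn_regress(train_X, train_y, x, k=3, eps=1e-9):
    """Predict a continuous target as a similarity-weighted average of the k nearest targets."""
    distances = np.linalg.norm(train_X - x, axis=1)
    nearest = np.argsort(distances)[:k]
    # Inverse distance as the similarity; eps avoids division by zero on exact matches
    weights = 1.0 / (distances[nearest] + eps)
    return np.sum(weights * train_y[nearest]) / np.sum(weights)


def is_outlier(train_X, x, threshold):
    """Flag x as an outlier if no training sample lies within `threshold` of it."""
    distances = np.linalg.norm(train_X - x, axis=1)
    return bool(distances.min() > threshold)


# Hypothetical 1-D data where the target is 10 times the feature
train_X = np.array([[0.0], [1.0], [2.0], [3.0]])
train_y = np.array([0.0, 10.0, 20.0, 30.0])

print(knn_regress(train_X, train_y, np.array([1.5]), k=2))   # → about 15.0
print(is_outlier(train_X, np.array([1.5]), threshold=1.0))   # → False
print(is_outlier(train_X, np.array([10.0]), threshold=1.0))  # → True
```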

Things can also get trickier if the training dataset is large. In this case, finding the closest samples to a new one can be pretty slow, because KNN has to compare the new sample with every sample in the dataset. This approach is often referred to as brute-force search. Instead, a search tree can be built at train time so that the nearest neighbor search runs faster. This search tree is usually a KD-tree or a ball tree and, although these names may look intimidating, the idea is simple. A KD-tree looks at a single feature at a time and divides the training dataset into two splits: one split stores the samples whose value for the selected feature is lower than a threshold, and the other stores the samples with a higher value. This process is repeated hierarchically, subdividing each split using another feature. A ball tree follows the same hierarchical idea, but partitions the data into nested hyperspheres instead of splitting on one feature at a time.

Then, at inference time, we just have to follow the tree splits according to the corresponding features of the new sample and, when we arrive at a leaf node, perform a brute-force search on a much smaller subset of the training data. To guarantee that the true nearest neighbors are found, the search also backtracks into neighboring branches whose regions could still contain a closer sample, but most branches can be skipped entirely.
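
As a quick check that the tree only speeds up the search without changing its result, the following sketch compares a brute-force search with scikit-learn's KDTree class. The random data, its size, and k = 4 are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
train_X = rng.random((1000, 3))  # 1000 random training points with 3 features
queries = rng.random((5, 3))     # 5 new samples to look up
k = 4

# Brute force: distance from every query to every training sample, then sort
all_dists = np.linalg.norm(queries[:, None, :] - train_X[None, :, :], axis=2)
brute_idx = np.argsort(all_dists, axis=1)[:, :k]

# KD-tree: built once ("train time"), then queried as often as needed
tree = KDTree(train_X, leaf_size=40)
tree_dist, tree_idx = tree.query(queries, k=k)  # neighbors sorted by distance

print(np.array_equal(brute_idx, tree_idx))  # → True
```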

Finally, another tricky part of the KNN algorithm is how we compare a pair of samples to see how close they are. The most basic choice is the Euclidean distance, or L2 distance, which measures the straight-line distance between two points in the feature space. Another common option is the Manhattan distance, also called L1 or taxicab distance, which adds up the absolute differences between the features of the two points. One last common metric is the Minkowski distance, which has a parameter p that controls the order of the distance. It generalizes the previous two: it equals the Manhattan distance for p = 1 and the Euclidean distance for p = 2.
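
For two points a and b, the Minkowski distance of order p is (sum over i of |a_i - b_i|^p)^(1/p). The sketch below computes it directly and compares it with NumPy's norms on a simple 3-4-5 example; the helper name minkowski is ours, chosen for illustration.

```python
import numpy as np


def minkowski(a, b, p):
    """Minkowski distance of order p: (sum |a_i - b_i|^p)^(1/p)."""
    return np.sum(np.abs(a - b) ** p) ** (1.0 / p)


a = np.array([0.0, 0.0])
b = np.array([3.0, 4.0])

euclidean = np.linalg.norm(a - b, ord=2)  # sqrt(3^2 + 4^2)
manhattan = np.linalg.norm(a - b, ord=1)  # |3| + |4|

print(euclidean, manhattan)   # → 5.0 7.0
print(minkowski(a, b, p=1))   # → 7.0, same as Manhattan
print(minkowski(a, b, p=2))   # → 5.0, same as Euclidean
```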

UCI ML hand-written digits dataset

A machine learning algorithm would be of little use if we did not have data to train it. For this example, we will use the UCI ML hand-written digits dataset, a collection of images of handwritten digits from 0 to 9 that are mostly black and white and show nothing but the target digit.

Specifically, we will use the version included with scikit-learn, which is a copy of the dataset's test set. This version contains 1797 images written by 13 people, where each image is 8x8 pixels. During preprocessing, the original 32x32 bitmaps were scaled down by dividing them into non-overlapping blocks of 4x4 pixels and counting the number of white pixels in each block. Dividing the original size of 32 pixels by 4 gives a downscaled size of 8 and, since each 4x4 block contains 16 pixels, each pixel of the downscaled image takes an integer value from 0 to 16.
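
The block-counting step can be reproduced with a NumPy reshape. This is a sketch of the procedure as described, not the original preprocessing code: the 32x32 bitmap below is random rather than a real scanned digit.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical 32x32 binary bitmap (1 = white/"on" pixel) standing in for a scanned digit
bitmap = rng.integers(0, 2, size=(32, 32))

# View the image as an 8x8 grid of non-overlapping 4x4 blocks and count the "on" pixels per block
downscaled = bitmap.reshape(8, 4, 8, 4).sum(axis=(1, 3))

print(downscaled.shape)                                # → (8, 8)
print(downscaled.min() >= 0, downscaled.max() <= 16)   # → True True
```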

Instead of manually downloading the dataset, we can load it using the following code:

from sklearn import datasets
# Load the 8x8 digits dataset bundled with scikit-learn (no download needed)
digits_datset = datasets.load_digits()

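The data attribute, digits_datset.data, contains the pixel values (each image flattened into a row of 64 values), and the target attribute, digits_datset.target, contains the labels for the images. The labels are numbers from 0 to 9, corresponding to the digit shown in each image. The images attribute holds the same pixels arranged as 8x8 arrays, which is convenient for plotting.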
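
To see how these attributes relate, we can print their shapes. This sketch only inspects the object returned by load_digits; the expected sizes follow from the 1797 images of 8x8 pixels described above.

```python
from sklearn import datasets

digits_datset = datasets.load_digits()

print(digits_datset.images.shape)  # → (1797, 8, 8): one 8x8 image per sample
print(digits_datset.data.shape)    # → (1797, 64): the same pixels, flattened per sample
print(digits_datset.target.shape)  # → (1797,): one label per sample
print(digits_datset.data.min(), digits_datset.data.max())  # → 0.0 16.0
```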

Once we have loaded the data, we can display it using matplotlib like this:

import matplotlib.pyplot as plt
# Show the first ten digits side by side
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(digits_datset.images[i], cmap='gray')
plt.show()

This generates the following plot:

Figure 1. UCI ML Digits dataset. This dataset contains images of handwritten digits. Each image is 8x8 pixels.

Additionally, since the dataset included in scikit-learn is just the test set, we should subdivide it into our own training and testing sets. For this, we will be using the train_test_split function of the model_selection module, like this:

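from sklearn.model_selection import train_test_split
# Hold out 20% of the samples as a test set; random_state makes the shuffle reproducible
train_data, test_data, train_labels, test_labels = train_test_split(
    digits_datset.data, digits_datset.target, test_size=0.2, random_state=1337)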

The test_size=0.2 argument reserves 20% of the samples for the test set. The random_state parameter is a seed value that controls how the data is shuffled before splitting. By default, the data is shuffled randomly, so each call to the function produces a different split; if we pass a fixed random_state value, however, the split is reproducible.
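
A quick way to convince yourself of this reproducibility is to split the data twice with the same seed and compare the results. This sketch reuses the seed 1337 from above; the variable same is introduced only for the check.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits_datset = datasets.load_digits()
X, y = digits_datset.data, digits_datset.target

# Two calls with the same random_state should produce exactly the same split
split_a = train_test_split(X, y, test_size=0.2, random_state=1337)
split_b = train_test_split(X, y, test_size=0.2, random_state=1337)
same = all((a == b).all() for a, b in zip(split_a, split_b))

train_data, test_data, train_labels, test_labels = split_a
print(same)                              # → True
print(len(train_data), len(test_data))   # → 1437 360 (20% of 1797, rounded up, goes to the test set)
```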

Training and evaluating the model

First, we import the model class and create an instance of it. We will use 20 neighbors, the KD-tree search algorithm, and p = 2 for the Minkowski distance, which makes it equivalent to the Euclidean distance:

from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier(n_neighbors=20, algorithm='kd_tree', p=2)

Now, we can fit the model to the training data using the .fit method:

model.fit(train_data, train_labels)

At this point we have a trained model, and we can get the test set performance using the following code:

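from sklearn.metrics import accuracy_score
# Predict the labels of the held-out samples and compare them with the true labels
test_predictions = model.predict(test_data)
accuracy = accuracy_score(test_labels, test_predictions)
print(f"Test accuracy: {accuracy:.3f}")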
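
Putting the previous snippets together, a minimal end-to-end script could look like the following. It only combines the loading, splitting, training, and evaluation steps shown so far; we deliberately do not quote a particular accuracy value here.

```python
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Load the data and hold out 20% of it as a test set
digits_datset = datasets.load_digits()
train_data, test_data, train_labels, test_labels = train_test_split(
    digits_datset.data, digits_datset.target, test_size=0.2, random_state=1337)

# "Training" this KNN model stores the samples and builds the KD-tree over them
model = KNeighborsClassifier(n_neighbors=20, algorithm='kd_tree', p=2)
model.fit(train_data, train_labels)

# Predict the held-out samples and compare against their true labels
test_predictions = model.predict(test_data)
accuracy = accuracy_score(test_labels, test_predictions)
print(f"Test accuracy: {accuracy:.3f}")
```

As a shortcut, model.score(test_data, test_labels) returns the same accuracy in a single call.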

We can also save the trained model for later use. For this, we can use the pickle module:

import pickle
# Save the trained model to disk
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
# Load it back, for example in another script
with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)
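
To check that nothing is lost along the way, the sketch below serializes a fitted model to bytes and back in memory with pickle.dumps and pickle.loads, the same mechanism pickle.dump and pickle.load use through a file. Fitting on the full dataset here is purely for the sanity check.

```python
import pickle

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

digits_datset = datasets.load_digits()
model = KNeighborsClassifier(n_neighbors=20, algorithm='kd_tree', p=2)
model.fit(digits_datset.data, digits_datset.target)

# Round-trip the model through pickle without touching the disk
restored = pickle.loads(pickle.dumps(model))

# The restored model should make exactly the same predictions as the original
same = (restored.predict(digits_datset.data) == model.predict(digits_datset.data)).all()
print(same)  # → True
```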

How can we select the best parameters?

In the previous section we used 20 neighbors and p = 2 for the Minkowski distance. However, we didn’t check if we could get better performance using other values.

In this kind of situation, we can use grid search. This technique exhaustively tries every combination of parameter values from a discrete set of candidates (the grid) and selects the combination that achieves the best performance.
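
To get a feel for the size of such a grid, scikit-learn's ParameterGrid class can enumerate it. This sketch uses the same parameter dictionary that appears later in this post; each combination corresponds to one model configuration to train and evaluate.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

parameters = {'n_neighbors': np.arange(2, 20, 2), 'p': np.arange(1, 4, 1)}

# ParameterGrid expands the dictionary into every combination of values
grid = list(ParameterGrid(parameters))

print(len(grid))  # → 27: 9 values of n_neighbors (2, 4, ..., 18) times 3 values of p (1, 2, 3)
```

With cross-validation on top, each of these combinations is trained and evaluated once per fold.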

To make the search more reliable, a common strategy is to combine grid search with cross-validation. Cross-validation splits the data into k subsets, called folds (this k is unrelated to the number of neighbors in KNN). A training and evaluation loop is then executed k times, each time using k-1 folds for training and the remaining fold for evaluation. Finally, instead of computing the accuracy on a single test set, the accuracy is computed on each validation fold and then averaged.
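
The k-fold loop described above can be written out by hand with scikit-learn's KFold splitter. This is a sketch: 5 shuffled folds and 20 neighbors are arbitrary choices, and GridSearchCV runs an equivalent loop internally (for classifiers it uses stratified folds by default).

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier

digits_datset = datasets.load_digits()
X, y = digits_datset.data, digits_datset.target

# k = 5 folds here (this k has nothing to do with the number of neighbors)
folds = KFold(n_splits=5, shuffle=True, random_state=0)
fold_accuracies = []
for train_idx, val_idx in folds.split(X):
    model = KNeighborsClassifier(n_neighbors=20)
    model.fit(X[train_idx], y[train_idx])                         # train on k-1 folds
    fold_accuracies.append(model.score(X[val_idx], y[val_idx]))  # evaluate on the remaining fold

print(len(fold_accuracies))      # → 5
print(np.mean(fold_accuracies))  # cross-validation accuracy: the average over the folds
```

For a single configuration, sklearn.model_selection.cross_val_score performs this loop in one call.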

To do this in Python, we first import the GridSearchCV class and NumPy, and then create the model, the parameter grid, and the grid search object:

from sklearn.model_selection import GridSearchCV
import numpy as np
# Candidate values: n_neighbors = 2, 4, ..., 18 and p = 1, 2, 3
parameters = {'n_neighbors': np.arange(2, 20, 2), 'p': np.arange(1, 4, 1)}
model = KNeighborsClassifier(algorithm='kd_tree')
grid_search_model = GridSearchCV(model, parameters, scoring='accuracy', n_jobs=-1)

The n_jobs=-1 argument tells scikit-learn to use all available CPU cores, evaluating several parameter combinations in parallel so that the search finishes faster.

Then we can call .fit to run the grid search with cross-validation. It will take some time to finish:

grid_search_model.fit(train_data, train_labels)

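The grid_search_model object automatically splits the training data into training and validation folds (5 folds by default, stratified by class for classifiers) and uses cross-validation to measure the performance of each parameter combination. Once finished, we can check the best parameters and the best cross-validation accuracy:

print(grid_search_model.best_params_)
print(grid_search_model.best_score_)
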
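Finally, we can evaluate the best performing model on the original test set to see whether we got any improvement. By default, GridSearchCV refits the best parameter combination on the whole training set, so we can call .predict on it directly: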

test_predictions = grid_search_model.predict(test_data)
accuracy = accuracy_score(test_labels, test_predictions)

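Even if the test set accuracy improves, it comes from a single split of the data, so it depends on which samples happened to land in that split. The cross-validation accuracy is averaged over k folds, which makes it the more reliable criterion for comparing parameter settings. Keep in mind, though, that the best cross-validation score is slightly optimistic, because we picked the parameters that maximized it; the untouched test set remains our final, independent check.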


Conclusion

In this post we learned about the KNN algorithm, a good starting point for those looking to enter the field of machine learning. This algorithm stores the training samples and analyzes new ones by comparing them to the training data. To classify new data, it assigns the most common label among its close neighbors.

Despite the satisfactory results we saw, the KNN algorithm will fall short on more challenging benchmarks like the ImageNet dataset. The images in the dataset we used are really small (8x8 pixels), mostly black and white, and free of any confounding elements, whereas real-world images are way more complex.

The next steps could be reading about feature extraction algorithms for images like HOG or SIFT, and about more complex classifiers like SVMs or random forests. By learning about filter-based feature extractors, you will begin to familiarize yourself with the convolution operation, one of the cornerstones of convolutional networks. On the other hand, by looking at optimization-based models like the SVM, you will get a better understanding of how machine learning algorithms fit their parameters to the training data.

In summary, while we achieved nice results, there is still a long way to go. This is just a first step into the field of AI, but a good starting point, nonetheless.
