Chapter 2: K-Nearest Neighbours
Let’s build a K-Nearest Neighbours model from scratch.
First, we will define a generic KNN class. In the constructor, we pass three parameters:
- The number of neighbours used to make predictions
- The distance measure we want to use
- Whether or not we want to use weighted distances
import numpy as np

from utils.distances import euclidean


class KNN:

    def __init__(self, k, distance=euclidean, weighted=False):
        self.k = k
        self.weighted = weighted    # Whether or not to use weighted distances
        self.distance = distance    # Function used to measure the distance between two instances
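The euclidean function imported from utils.distances is not listed in this chapter. A minimal sketch of what it could look like, assuming it simply returns the straight-line (L2) distance between two vectors:

import numpy as np

def euclidean(a, b):
    # Hypothetical sketch of utils.distances.euclidean:
    # straight-line (L2) distance between two equally sized vectors
    return np.sqrt(np.sum((np.asarray(a) - np.asarray(b)) ** 2))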
Now we will define the fit function, which describes how to train the model. For a K-Nearest Neighbours model, training is very simple: all we need to do is store the training instances as the model's parameters.
    def fit(self, X, y):
        self.X_ = X
        self.y_ = y
        return self
Similarly, we can build an update function which updates the state of the model as more data points are provided for training. Training a model by feeding it data in a stream-like fashion is often referred to as online learning. Not all models allow for computationally efficient online learning, but K-Nearest Neighbours does.
    def update(self, X, y):
        self.X_ = np.concatenate((self.X_, X))
        self.y_ = np.concatenate((self.y_, y))
        return self
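As a quick illustration, here is a sketch of how fit and update could be combined to train on data arriving in batches; the arrays below are made up for the example:

import numpy as np

# Toy data, invented purely for illustration
X_first, y_first = np.array([[0.0], [1.0]]), np.array([0.0, 1.0])
X_later, y_later = np.array([[2.0], [3.0]]), np.array([4.0, 9.0])

model = KNN(k=2)
model.fit(X_first, y_first)        # initial batch
model.update(X_later, y_later)     # new batch arriving later in the stream
print(model.X_.shape)              # (4, 1) -- all four instances are now stored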
In order to make predictions, we also need to create a predict function. For a K-Nearest Neighbours model, a prediction is made in two steps:
- Find the K-nearest neighbours by computing their distances to the data point we want to predict
- Given these neighbours and their distances, compute the predicted output
    def predict(self, X):
        predictions = []
        for x in X:
            neighbours, distances = self._get_neighbours(x)
            prediction = self._vote(neighbours, distances)
            predictions.append(prediction)
        return np.array(predictions)
Retrieving the neighbours is done by computing the distance between the query point and every instance stored in the model's state. Once these distances are known, the targets and distances of the K closest instances are returned.
    def _get_neighbours(self, x):
        distances = np.array([self.distance(x, x_) for x_ in self.X_])
        indices = np.argsort(distances)[:self.k]
        return self.y_[indices], distances[indices]
In case we would like to use weighted distances, we need to compute the weights. By default, these weights are all set to 1 so that every instance counts equally. To weigh the instances, closer neighbours are typically favoured by giving them a weight equal to 1 divided by their distance.
If any neighbour is at distance 0, we cannot divide by zero, so those neighbours keep a weight of 1 and all other weights are set to 0. This is also how scikit-learn handles this edge case, according to its source code.
    def _get_weights(self, distances):
        weights = np.ones_like(distances, dtype=float)
        if self.weighted:
            if np.any(distances == 0):
                # Exact matches take all the weight, every other neighbour gets none
                weights[distances != 0] = 0
            else:
                weights /= distances
        return weights
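To make the weighting concrete, here is a small example using made-up distance arrays; the outputs follow directly from the code above:

import numpy as np

model = KNN(k=3, weighted=True)

# No zero distances: weights are the inverse distances
print(model._get_weights(np.array([1.0, 2.0, 4.0])))   # [1.   0.5  0.25]

# One zero distance: the exact match takes all the weight
print(model._get_weights(np.array([0.0, 2.0, 4.0])))   # [1. 0. 0.]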
The only function that we have yet to define is the _vote function that is called in predict. Depending on how this function is implemented, K-Nearest Neighbours can be used for regression, classification, or even as a meta-learner.
KNN for Regression
In order to use K-Nearest Neighbours for regression, the _vote function is defined as the average of the neighbours' targets. If weighting is used, the _vote function returns a weighted average instead, favouring closer instances.
class KNN_Regressor(KNN):

    def _vote(self, targets, distances):
        weights = self._get_weights(distances)
        return np.sum(weights * targets) / np.sum(weights)
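As a quick sanity check, here is a minimal usage sketch on a made-up one-dimensional dataset; with k=2 and unweighted voting, the prediction is the plain mean of the two closest targets:

import numpy as np

# Toy dataset, invented purely for illustration
X_train = np.array([[1.0], [2.0], [3.0], [10.0]])
y_train = np.array([1.0, 2.0, 3.0, 10.0])

model = KNN_Regressor(k=2).fit(X_train, y_train)
print(model.predict(np.array([[2.4]])))   # [2.5] -- mean of the targets 2.0 and 3.0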
KNN for Classification
In the classification case, the _vote function uses a majority voting scheme: the predicted class is the one whose neighbours carry the largest total weight. If weighting is used, closer neighbours have a larger impact on the prediction.
class KNN_Classifier(KNN):

    def _vote(self, classes, distances):
        weights = self._get_weights(distances)
        # Pick the class whose neighbours carry the largest total weight
        prediction = None
        max_weighted_frequency = 0
        for c in classes:
            weighted_frequency = np.sum(weights[classes == c])
            if weighted_frequency > max_weighted_frequency:
                prediction = c
                max_weighted_frequency = weighted_frequency
        return prediction
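And a minimal usage sketch for the classifier on a made-up two-class dataset; with k=3, the majority class among the three nearest neighbours wins:

import numpy as np

# Toy dataset, invented purely for illustration
X_train = np.array([[0.0], [0.5], [1.0], [5.0], [5.5]])
y_train = np.array(['a', 'a', 'a', 'b', 'b'])

model = KNN_Classifier(k=3).fit(X_train, y_train)
print(model.predict(np.array([[0.6], [5.2]])))   # ['a' 'b']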