Chapter 2: Mean Shift

Let’s implement Mean Shift from scratch. First, we’ll have to define a Mean Shift object.

import numpy as np

from utils.kernels import RBF
from utils.distances import euclidean

class MeanShift:

    def __init__(self, bandwidth=1, tol=1E-7):

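        self.bandwidth = bandwidth
        # Kernel values above 1 - tol (very close to 1) mean two points
        # are treated as being at the same position when merging centers.
        self.tol = 1 - tol
        self.kernel = RBF(gamma=self.bandwidth)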
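
The utils.kernels and utils.distances modules are not shown in this chapter. As a rough stand-in so the snippets can be run, here is a minimal sketch that assumes the kernel has the common Gaussian form exp(-gamma * ||x - y||^2) and that euclidean returns the L2 distance. Both definitions are assumptions, not the actual contents of utils.

```python
import numpy as np


class RBF:
    """Hypothetical stand-in for utils.kernels.RBF.

    Assumed form: k(x, y) = exp(-gamma * ||x - y||^2).
    """

    def __init__(self, gamma=1.0):
        self.gamma = gamma

    def __call__(self, x, y):
        diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
        return float(np.exp(-self.gamma * np.dot(diff, diff)))


def euclidean(x, y):
    """Hypothetical stand-in for utils.distances.euclidean (L2 distance)."""
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.sqrt(np.dot(diff, diff)))
```

Under this assumed form, the kernel returns 1 for identical points and decays towards 0 as they move apart. Note that a larger gamma makes the kernel narrower, so passing bandwidth as gamma means a larger value shrinks the neighborhood.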

The bandwidth parameter controls the width of the Radial Basis Function kernel. Now, let’s assume that we have a trained model. This means that we have centers representing our clusters. Assigning a new point to a cluster comes down to finding the center closest to that point. In this case we will use the Euclidean distance.

    def _compute_labels(self, X, centers):

        labels = []

        for x in X:

            distances = np.array([euclidean(x, center) for center in centers])
            label = np.argmin(distances)
            labels.append(label)

        # Renumber the labels as consecutive integers starting at 0
        _, labels = np.unique(labels, return_inverse=True)
        return np.array(labels)

    def predict(self, X):

        labels = self._compute_labels(X, self.cluster_centers_)

        return labels
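
The nearest-center assignment can also be written without the Python loop. The following is a sketch only: nearest_center_labels is a hypothetical standalone helper, not part of the class, and it skips the relabeling step done with np.unique.

```python
import numpy as np


def nearest_center_labels(X, centers):
    """Return, for each row of X, the index of the closest center."""
    X = np.asarray(X, dtype=float)
    centers = np.asarray(centers, dtype=float)
    # Broadcasting gives an array of shape (n_points, n_centers, n_features)
    diffs = X[:, None, :] - centers[None, :, :]
    distances = np.linalg.norm(diffs, axis=2)
    return np.argmin(distances, axis=1)


X = np.array([[0.0, 0.1], [0.2, 0.0], [5.0, 5.1]])
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
print(nearest_center_labels(X, centers))  # [0 0 1]
```

The broadcasted version builds the full distance matrix at once, which is faster for small inputs but uses memory proportional to the number of points times the number of centers.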

Now, let’s look at how we can train our model. Given some data, we start by creating one center for each point in the data. Then, until convergence, we shift and merge centers.

    def fit(self, X):

        for labels, centers in self._fit(X):

            self.labels_ = labels
            self.cluster_centers_ = centers

        return self

    def _fit(self, X):

        old_centers = np.array([])
        new_centers = X
        labels = -np.ones(len(X))  # -1 represents an "orphan"

        while not self._has_converged(old_centers, new_centers):

            yield labels, new_centers

            old_centers = new_centers
            new_centers = []

            for center in old_centers:

                shifted_center = self._shift(center, X)
                new_centers.append(shifted_center)

            new_centers = self._merge_centers(new_centers)
            labels = self._compute_labels(X, new_centers)

        # Emit the converged state so that fit() stores the final result
        yield labels, new_centers

An important function is the _shift function. To shift a center, we evaluate the kernel between the center and every data point; these values act as density weights. The new center is then the weighted average of the data points under those weights. The difference in position between the old and new center is what is referred to as the shift.

    def _shift(self, x, X):

        densities = [self.kernel(x, x_) for x_ in X]

        shifted_center = np.average(X, weights=densities, axis=0)

        return shifted_center
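
To make the weighted average concrete, here is a small worked example of a single shift step. It is a sketch that assumes a Gaussian kernel of the form exp(-gamma * ||x - y||^2); shift_once is a hypothetical standalone helper, not a method of the class.

```python
import numpy as np


def shift_once(x, X, gamma=1.0):
    """One mean shift step for center x, assuming a Gaussian kernel."""
    X = np.asarray(X, dtype=float)
    sq_dists = np.sum((X - x) ** 2, axis=1)
    weights = np.exp(-gamma * sq_dists)  # nearby points weigh more
    return np.average(X, weights=weights, axis=0)


X = np.array([[0.0], [1.0], [10.0]])
center = np.array([0.0])

new_center = shift_once(center, X)
shift = new_center - center

# The weights are 1, e^-1 and e^-100, so the far point at 10 is
# effectively ignored and the center moves to about 1 / (1 + e).
print(new_center, shift)
```

The center is pulled towards its neighbor at 1 but barely notices the point at 10, which is why repeated shifts climb towards the nearest dense region rather than the global mean.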

Since many centers end up converging to the same location, we merge them to speed up computation. However, because of floating-point arithmetic, centers will rarely be at exactly the same position. Therefore, we redefine each center as the average of all centers that are within a certain high-density region around it. This way, we end up with identical centers, which we merge.

    def _merge_centers(self, centers):

        centers = np.unique(centers, axis=0)
        new_centers = []

        for c in centers:
            # Kernel values are similarities: close to 1 means nearly identical
            similarities = np.array([self.kernel(c, c_) for c_ in centers])
            new_centers.append(np.mean(centers[similarities > self.tol], axis=0))

        centers = np.unique(new_centers, axis=0)

        return centers
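
The merge step can be tried on its own with a toy input. This is a sketch under the same Gaussian-kernel assumption, exp(-gamma * ||x - y||^2); merge_centers is a hypothetical standalone version of the method, with gamma and tol passed explicitly.

```python
import numpy as np


def merge_centers(centers, gamma=1.0, tol=1e-7):
    """Average near-identical centers, then drop the duplicates."""
    centers = np.unique(np.asarray(centers, dtype=float), axis=0)
    # Pairwise squared distances, shape (n_centers, n_centers)
    diffs = centers[:, None, :] - centers[None, :, :]
    sq_dists = np.sum(diffs ** 2, axis=2)
    # Boolean matrix: True where two centers are in the same tiny region
    similar = np.exp(-gamma * sq_dists) > 1 - tol
    averaged = np.array([centers[row].mean(axis=0) for row in similar])
    return np.unique(averaged, axis=0)


centers = np.array([[0.0, 0.0], [1e-5, 0.0], [3.0, 3.0]])
merged = merge_centers(centers)
print(merged)  # two rows remain: the averaged pair and [3, 3]
```

The first two centers see each other as neighbors, so both are replaced by the same average and collapse into one row. The neighborhood relation is not transitive, so a chain of centers may need more than one pass to collapse fully.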

In our case, we define convergence as the moment where the shifted centers are “close enough” to the old centers.

    def _has_converged(self, old, new):

        if len(old) == len(new):

            for i in range(len(new)):
                if self.kernel(old[i], new[i]) < 1.0:
                    return False

            return True

        # A different number of centers means a merge just happened
        return False
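
As a variation, the check can be vectorized and given an explicit tolerance instead of comparing the kernel value against exactly 1.0. This is a sketch of an alternative, not the method above: has_converged is a hypothetical standalone function, the kernel is again assumed to be exp(-gamma * ||x - y||^2), and the threshold 1 - tol is a choice made for this example.

```python
import numpy as np


def has_converged(old, new, gamma=1.0, tol=1e-7):
    """True when every center moved so little that its kernel value stays above 1 - tol."""
    old = np.asarray(old, dtype=float)
    new = np.asarray(new, dtype=float)
    if old.shape != new.shape:
        # Centers were merged (or this is the first iteration)
        return False
    sq_dists = np.sum((old - new) ** 2, axis=1)
    return bool(np.all(np.exp(-gamma * sq_dists) >= 1 - tol))


print(has_converged(np.array([]), [[0.0, 0.0]]))      # False
print(has_converged([[0.0, 0.0]], [[0.0, 1e-6]]))     # True
print(has_converged([[0.0, 0.0]], [[0.0, 1.0]]))      # False
```

A looser tolerance stops the iterations earlier at the cost of slightly less precise centers.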