
# Chapter 2: Mean Shift

Let’s implement Mean Shift from scratch. First, we define a `MeanShift` class.

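```python
import numpy as np

from utils.kernels import RBF
from utils.distances import euclidean


class MeanShift:

    def __init__(self, bandwidth=1, tol=1E-7):
        self.bandwidth = bandwidth
        # Store the tolerance as a kernel-similarity threshold: centers whose
        # similarity exceeds 1 - tol are treated as the same center
        self.tol = 1 - tol
        self.kernel = RBF(gamma=self.bandwidth)
```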

The `bandwidth` parameter sets the `gamma` parameter of the Radial Basis Function (RBF) kernel. Now, let’s assume that we have a trained model, which means we already have centers representing our clusters. Assigning a new point to a cluster then comes down to assigning it to the cluster whose center is closest. Here, "closest" is measured with the Euclidean distance.
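The chapter imports `RBF` and `euclidean` from a local `utils` package that is not shown. The sketch below is a hypothetical stand-in, assuming `RBF(gamma)` returns a callable that computes the Gaussian kernel exp(-gamma * ||x - y||^2). That kernel equals 1 for identical points and decays toward 0 with distance. Under this assumed form, passing `bandwidth` as `gamma` means a larger bandwidth gives a *narrower* kernel. Many references instead use gamma = 1 / (2 * bandwidth**2), so that the bandwidth acts as a length scale.

```python
import numpy as np


def euclidean(x, y):
    """Euclidean distance (hypothetical stand-in for utils.distances.euclidean)."""
    return float(np.linalg.norm(np.asarray(x, dtype=float) - np.asarray(y, dtype=float)))


class RBF:
    """Gaussian RBF kernel k(x, y) = exp(-gamma * ||x - y||^2).

    Hypothetical stand-in for utils.kernels.RBF; the real module may differ.
    """

    def __init__(self, gamma=1.0):
        self.gamma = gamma

    def __call__(self, x, y):
        return float(np.exp(-self.gamma * euclidean(x, y) ** 2))


kernel = RBF(gamma=1.0)
print(kernel([0.0, 0.0], [0.0, 0.0]))  # identical points -> 1.0
print(kernel([0.0, 0.0], [1.0, 0.0]))  # exp(-1), about 0.3679
print(euclidean([0, 0], [3, 4]))       # 5.0
```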

```python
    def _compute_labels(self, X, centers):
        labels = []

        for x in X:
            # Distance from x to every center; the nearest center gives the label
            distances = np.array([euclidean(x, center) for center in centers])
            label = np.argmin(distances)
            labels.append(label)

        # Renumber the labels as consecutive integers 0..k-1
        _, labels = np.unique(labels, return_inverse=True)
        return np.array(labels, dtype=int)

    def predict(self, X):
        labels = self._compute_labels(X, self.cluster_centers_)

        return labels
```
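The per-point loop in `_compute_labels` is easy to read but slow for large datasets. As an optional alternative, the same nearest-center assignment can be vectorized with NumPy broadcasting. This is a sketch using a hypothetical free function `compute_labels`, not part of the original class. It materializes an array of shape (n_points, n_centers, n_features), so memory grows with all three.

```python
import numpy as np


def compute_labels(X, centers):
    """Assign each row of X to its nearest center (Euclidean), vectorized."""
    X = np.asarray(X, dtype=float)
    centers = np.asarray(centers, dtype=float)
    # Pairwise differences, shape (n_points, n_centers, n_features)
    diffs = X[:, None, :] - centers[None, :, :]
    distances = np.linalg.norm(diffs, axis=2)  # shape (n_points, n_centers)
    labels = np.argmin(distances, axis=1)
    # Renumber to consecutive 0..k-1, as the loop version does with np.unique
    _, labels = np.unique(labels, return_inverse=True)
    return labels.astype(int)


X = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]])
centers = np.array([[0.1, 0.05], [5.05, 4.95]])
print(compute_labels(X, centers))  # [0 0 1 1]
```

Because of the renumbering step, passing only the point `[5, 5]` returns `[0]` rather than `[1]`: labels are renumbered to cover only the centers that actually occur.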

Now, let’s look at how we can train our model. Given some data, we start by placing a center on each data point. Then, until convergence, we repeatedly shift the centers and merge those that end up in the same place.

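```python
    def fit(self, X):
        # _fit yields the intermediate state after every iteration;
        # the last one becomes the fitted model
        for labels, centers in self._fit(X):
            self.labels_ = labels
            self.cluster_centers_ = centers

        return self

    def _fit(self, X):
        old_centers = np.array([])
        new_centers = X  # start with one center per data point
        labels = -np.ones(len(X))  # -1 represents an "orphan"

        while not self._has_converged(old_centers, new_centers):
            yield labels, new_centers

            old_centers = new_centers
            new_centers = []

            # Shift every center toward the kernel-weighted mean of the data
            for center in old_centers:
                shifted_center = self._shift(center, X)
                new_centers.append(shifted_center)

            # Merge centers that ended up (almost) at the same position
            new_centers = self._merge_centers(new_centers)
            labels = self._compute_labels(X, new_centers)

        # Yield the converged state too, so fit() stores the final centers and labels
        yield labels, new_centers
```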
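The `fit`/`_fit` pair depends on the `utils` helpers and on methods defined later in this chapter. To see the whole shift-and-merge loop run end to end, here is a compact, self-contained sketch. It assumes a Gaussian kernel of the form exp(-gamma * ||x - y||^2). It also adds a `max_iter` safety cap that the original does not have, and it uses the hypothetical names `rbf` and `mean_shift`. It illustrates the idea; it is not the chapter's class.

```python
import numpy as np


def rbf(x, Y, gamma):
    """Gaussian RBF similarity exp(-gamma * ||x - y||^2) between x and each row of Y."""
    return np.exp(-gamma * np.sum((Y - x) ** 2, axis=1))


def mean_shift(X, gamma=1.0, tol=1e-7, max_iter=300):
    """Compact mean shift sketch: one center per point, then shift and merge until stable."""
    X = np.asarray(X, dtype=float)
    threshold = 1 - tol  # similarity above this means "same center"
    centers = X.copy()  # start with a center on every data point

    for _ in range(max_iter):
        # Shift: move each center to the kernel-weighted average of the data
        shifted = np.array(
            [np.average(X, axis=0, weights=rbf(c, X, gamma)) for c in centers]
        )
        # Merge: replace each center by the mean of the centers nearly identical to it
        shifted = np.unique(shifted, axis=0)
        merged = np.unique(
            np.array([shifted[rbf(c, shifted, gamma) > threshold].mean(axis=0)
                      for c in shifted]),
            axis=0,
        )
        # Converged: same number of centers and none of them moved noticeably
        converged = len(merged) == len(centers) and np.all(
            np.exp(-gamma * np.sum((merged - centers) ** 2, axis=1)) > threshold
        )
        centers = merged
        if converged:
            break

    # Label every point with the index of its nearest center (Euclidean distance)
    distances = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    return centers, np.argmin(distances, axis=1)


X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
centers, labels = mean_shift(X, gamma=1.0)
print(len(centers))  # 2
print(labels)        # [0 0 0 1 1 1]
```

The six starting centers collapse into two, one per group of points. Each surviving center sits near the middle of its group.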

The most important step is the `_shift` method. To shift a center, we evaluate the kernel between the center and every data point, and these kernel values act as weights. The new center is the weighted average of the data points, so points close to the center pull it more strongly than distant ones. The displacement from the old center to the new one is what is called the mean shift.

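```python
    def _shift(self, x, X):
        # Kernel similarity between the center x and every data point
        densities = [self.kernel(x, x_) for x_ in X]

        # The shifted center is the kernel-weighted average of the data
        shifted_center = np.average(X, weights=densities, axis=0)

        return shifted_center
```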
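To make the shift concrete, here is a single shift step on a tiny one-dimensional dataset. It is a standalone sketch that assumes the Gaussian form exp(-gamma * ||x - y||^2) for the kernel; `shift` is a hypothetical free function mirroring `_shift`. Starting from 0, the center is pulled toward the nearby points at 1 and 2. The outlier at 10 receives a weight of about exp(-100) and has practically no influence.

```python
import numpy as np


def shift(center, X, gamma=1.0):
    """One mean shift step: kernel-weighted average of the data around `center`."""
    center = np.asarray(center, dtype=float)
    X = np.asarray(X, dtype=float)
    weights = np.exp(-gamma * np.sum((X - center) ** 2, axis=1))  # Gaussian RBF weights
    return np.average(X, axis=0, weights=weights)


X = np.array([[0.0], [1.0], [2.0], [10.0]])
center = np.array([0.0])
new_center = shift(center, X)
shift_vector = new_center - center  # the mean shift itself
print(new_center, shift_vector)     # both about [0.292]
```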

Since many centers converge toward the same point, merging them reduces the number of centers we have to shift and therefore speeds up computation. However, because of floating-point arithmetic, centers that converge to the same point rarely end up at exactly the same position. Therefore, we redefine each center as the average of all centers whose kernel similarity to it exceeds the threshold `self.tol`. Centers in the same group then become identical, and `np.unique` merges them into one.

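```python
    def _merge_centers(self, centers):
        centers = np.unique(centers, axis=0)
        new_centers = []

        for c in centers:
            # Kernel similarity (not distance) between c and every center
            similarities = np.array([self.kernel(c, c_) for c_ in centers])
            # Replace c by the mean of all centers that are nearly identical to it
            new_centers.append(np.mean(centers[similarities > self.tol], axis=0))

        # Identical averaged centers collapse into one
        centers = np.unique(new_centers, axis=0)

        return centers
```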
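The following standalone sketch shows the merge step on centers that differ only by round-off-sized noise. It assumes the Gaussian kernel exp(-gamma * ||x - y||^2) and uses a hypothetical free function `merge_centers`. With that kernel, a similarity above 1 - tol corresponds to a squared distance below roughly tol / gamma. With the default tol of 1e-7, this means centers closer than about 3e-4 (for gamma = 1) are merged.

```python
import numpy as np


def merge_centers(centers, gamma=1.0, threshold=1 - 1e-7):
    """Collapse centers whose RBF similarity exceeds `threshold` into their mean."""
    centers = np.unique(np.asarray(centers, dtype=float), axis=0)
    merged = []
    for c in centers:
        similarities = np.exp(-gamma * np.sum((centers - c) ** 2, axis=1))
        merged.append(centers[similarities > threshold].mean(axis=0))
    # Centers averaged over the same group are identical and collapse here
    return np.unique(np.array(merged), axis=0)


centers = np.array([
    [1.0, 1.0],
    [1.0 + 1e-9, 1.0],  # differs from the first only by round-off-sized noise
    [4.0, 4.0],
])
print(merge_centers(centers))  # two centers remain
```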

In our case, we define convergence as the moment when the number of centers no longer changes and each shifted center is “close enough” to its old position, as measured by the kernel.

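```python
    def _has_converged(self, old, new):
        # A different number of centers means a merge happened: keep iterating
        if len(old) != len(new):
            return False

        for i in range(len(new)):
            # A center is "close enough" when its similarity to its old
            # position exceeds the same threshold used for merging
            if self.kernel(old[i], new[i]) < self.tol:
                return False

        return True
```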
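The same check can be written without the explicit loop. The sketch below is a hypothetical free function `has_converged`. It assumes the Gaussian kernel exp(-gamma * ||x - y||^2) and the merge threshold 1 - 1e-7. Like the method above, it assumes that the i-th row of `old` and the i-th row of `new` refer to the same center. Comparing shapes also handles the empty `np.array([])` used as the initial `old_centers`.

```python
import numpy as np


def has_converged(old, new, gamma=1.0, threshold=1 - 1e-7):
    """True when both center sets have the same size and every center barely moved."""
    old = np.asarray(old, dtype=float)
    new = np.asarray(new, dtype=float)
    if old.shape != new.shape:
        return False
    # Row-wise RBF similarity between each old center and its counterpart
    similarities = np.exp(-gamma * np.sum((old - new) ** 2, axis=1))
    return bool(np.all(similarities > threshold))


old = np.array([[0.0, 0.0], [5.0, 5.0]])
print(has_converged(old, old + 1e-6))   # True: each center moved about 1.4e-6
print(has_converged(old, old + 1e-2))   # False: each center moved about 1.4e-2
print(has_converged(old, old[:1]))      # False: a center was merged away
```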