Chapter 2: Mean Shift
Let’s implement Mean Shift from scratch. First, we define a MeanShift class.
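import numpy as np
from utils.kernels import RBF
from utils.distances import euclidean


class MeanShift:
    def __init__(self, bandwidth=1, tol=1E-7):
        self.bandwidth = bandwidth
        # Kernel-similarity threshold: two centers count as "the same"
        # when kernel(c1, c2) > 1 - tol.
        self.tol = 1 - tol
        # The bandwidth is passed to the kernel as its gamma parameter.
        self.kernel = RBF(gamma=self.bandwidth)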
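The `utils.kernels` and `utils.distances` modules are not shown in this chapter. To make the rest of the code easier to follow, here is a minimal sketch of what they might contain. It assumes that `RBF(gamma)` returns a callable computing k(x, y) = exp(-gamma * ||x - y||^2) and that `euclidean(x, y)` returns the Euclidean distance; both are hypothetical stand-ins, not the author's actual modules.

```python
import numpy as np


class RBF:
    """Hypothetical stand-in for utils.kernels.RBF: k(x, y) = exp(-gamma * ||x - y||^2)."""

    def __init__(self, gamma=1.0):
        self.gamma = gamma

    def __call__(self, x, y):
        diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
        # Squared Euclidean distance, scaled by gamma, inside the exponential.
        return float(np.exp(-self.gamma * np.dot(diff, diff)))


def euclidean(x, y):
    """Hypothetical stand-in for utils.distances.euclidean."""
    return float(np.linalg.norm(np.asarray(x, dtype=float) - np.asarray(y, dtype=float)))


kernel = RBF(gamma=1.0)
print(kernel([0, 0], [0, 0]))     # 1.0: identical points have maximal similarity
print(euclidean([0, 0], [3, 4]))  # 5.0
```

If `RBF` does follow this form, note that passing `bandwidth` as `gamma` means a larger bandwidth gives a narrower kernel, which is the opposite of the usual convention where the bandwidth is a length scale.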
The bandwidth parameter parameterizes the Radial Basis Function (RBF) kernel. Now, let’s assume that we already have a trained model, which means we have centers representing our clusters. Assigning a new point to a cluster then comes down to finding the center closest to it. Here, we use the Euclidean distance.
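    def _compute_labels(self, X, centers):
        # Assign each point to the index of its nearest center (Euclidean distance).
        labels = []
        for x in X:
            distances = np.array([euclidean(x, center) for center in centers])
            labels.append(np.argmin(distances))
        # argmin already returns indices into `centers`, so no relabeling is needed
        # (relabeling would break the link between a label and cluster_centers_).
        # `int` replaces the `np.int` alias, which was removed in NumPy 1.24.
        return np.array(labels, dtype=int)

    def predict(self, X):
        return self._compute_labels(X, self.cluster_centers_)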
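The per-point loop in `_compute_labels` can also be vectorized with NumPy broadcasting. The sketch below uses a hypothetical helper, `nearest_center_labels`, which is not part of the class. It computes every point-to-center distance at once, at the cost of an intermediate array of shape (n_points, n_centers, n_features).

```python
import numpy as np


def nearest_center_labels(X, centers):
    """Assign each row of X to the index of its closest center (Euclidean)."""
    X = np.asarray(X, dtype=float)
    centers = np.asarray(centers, dtype=float)
    # Broadcast to shape (n_points, n_centers, n_features), then reduce over features.
    diffs = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
    distances = np.linalg.norm(diffs, axis=2)
    # Index of the smallest distance in each row is the label of that point.
    return np.argmin(distances, axis=1)


centers = np.array([[0.0, 0.0], [10.0, 10.0]])
X_new = np.array([[1.0, 0.5], [9.0, 11.0], [-2.0, 1.0]])
print(nearest_center_labels(X_new, centers))  # [0 1 0]
```

As in the class, each label is an index into `centers`, so `centers[label]` is always the center a point was assigned to.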
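Now, let’s look at how to train the model. Given some data, we start by placing a center on every data point. Then, until convergence, we shift all centers and merge those that have collapsed onto each other. The loop is written as a generator, _fit, which yields the labels and centers after every iteration; fit simply keeps the last state it receives.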
    def fit(self, X):
        # _fit yields the state after every iteration; keep the last one.
        for labels, centers in self._fit(X):
            self.labels_ = labels
            self.cluster_centers_ = centers
        return self

    def _fit(self, X):
        old_centers = np.array([])
        new_centers = X  # start with one center per data point
        labels = -np.ones(len(X), dtype=int)  # -1 represents an "orphan"
        while not self._has_converged(old_centers, new_centers):
            yield labels, new_centers
            old_centers = new_centers
            # Shift every center, then merge centers that have (nearly) collided.
            new_centers = [self._shift(center, X) for center in old_centers]
            new_centers = self._merge_centers(new_centers)
            labels = self._compute_labels(X, new_centers)
        # Also yield the converged state, which the loop would otherwise skip.
        yield labels, new_centers
The key function is _shift. To shift a center, we evaluate the kernel between the center and every data point and use these values as weights. The new center is the weighted average of the data points, so points close to the center pull it harder than distant ones. The displacement between the old and the new center is the shift that gives the algorithm its name.
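Written out, and assuming the RBF kernel has the common form below with γ set to `bandwidth`, one shift step maps a center x to the kernel-weighted mean m(x) of the data points x_1, …, x_n:

```latex
m(\mathbf{x}) = \frac{\sum_{i=1}^{n} k(\mathbf{x}, \mathbf{x}_i)\,\mathbf{x}_i}{\sum_{i=1}^{n} k(\mathbf{x}, \mathbf{x}_i)},
\qquad
k(\mathbf{x}, \mathbf{x}_i) = \exp\!\left(-\gamma \,\lVert \mathbf{x} - \mathbf{x}_i \rVert^2\right),
\qquad
\text{shift} = m(\mathbf{x}) - \mathbf{x}
```

`np.average(X, weights=densities, axis=0)` computes exactly this ratio.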
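    def _shift(self, x, X):
        # The kernel value between the center and each data point acts as a weight.
        densities = [self.kernel(x, x_) for x_ in X]
        # The new center is the kernel-weighted mean of the data.
        shifted_center = np.average(X, weights=densities, axis=0)
        return shifted_center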
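Shifting centers one at a time calls the kernel n times per center. As a sketch, assuming the RBF form exp(-gamma * ||x - y||^2), every center can be shifted in a single vectorized step. `shift_all` is a hypothetical helper, not part of the class.

```python
import numpy as np


def shift_all(centers, X, gamma=1.0):
    """One mean-shift step for every center at once (assumes an RBF kernel)."""
    centers = np.asarray(centers, dtype=float)
    X = np.asarray(X, dtype=float)
    # Squared distances between every center and every point: shape (n_centers, n_points).
    sq_dists = np.sum((centers[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=2)
    weights = np.exp(-gamma * sq_dists)
    # Kernel-weighted mean of the data for each center (each row of weights normalized).
    return weights @ X / weights.sum(axis=1, keepdims=True)


X = np.array([[0.0], [1.0], [2.0]])
shifted = shift_all(X, X)
print(shifted.ravel())
```

On these three 1-D points, the outer centers at 0 and 2 move toward 1, while the middle center stays at 1 because its neighbors pull equally in both directions.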
Centers that start near the same mode of the density all converge to that mode, so merging them saves computation. Because of floating-point arithmetic, however, they rarely land on exactly the same position. Therefore, we replace each center with the average of all centers whose kernel similarity to it exceeds the threshold self.tol, that is, the centers lying almost on top of it. Centers that end up with identical averages are then merged into one.
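    def _merge_centers(self, centers):
        centers = np.unique(centers, axis=0)
        new_centers = []
        for c in centers:
            # Kernel similarity to every center (1.0 means identical positions).
            similarities = np.array([self.kernel(c, c_) for c_ in centers])
            # Replace c by the mean of all centers that lie nearly on top of it.
            new_centers.append(np.mean(centers[similarities > self.tol], axis=0))
        # Centers that received the same average collapse into one.
        return np.unique(new_centers, axis=0)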
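To get a feel for how tight the merge threshold is, solve exp(-gamma * d^2) = 1 - tol for the distance d. Assuming the RBF form used above, with gamma = 1 and tol = 1e-7, centers are merged only when they lie within roughly 3.16e-4 of each other. The sketch below uses hypothetical helpers, `merge_radius` and `merge_centers`, that mirror `_merge_centers` in vectorized form.

```python
import numpy as np


def merge_radius(gamma=1.0, tol=1e-7):
    """Distance below which two centers merge: solve exp(-gamma * d^2) = 1 - tol for d."""
    return np.sqrt(-np.log1p(-tol) / gamma)


def merge_centers(centers, gamma=1.0, tol=1e-7):
    """Replace each center by the mean of all centers with kernel similarity > 1 - tol."""
    centers = np.unique(np.asarray(centers, dtype=float), axis=0)
    sq_dists = np.sum((centers[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=2)
    similar = np.exp(-gamma * sq_dists) > 1 - tol
    # Row i of `similar` selects the neighbors of center i (including itself).
    averaged = np.array([centers[row].mean(axis=0) for row in similar])
    return np.unique(averaged, axis=0)


print(merge_radius())  # roughly 3.16e-4 for gamma=1, tol=1e-7
centers = [[0.0, 0.0], [0.0, 1e-5], [5.0, 5.0]]
print(merge_centers(centers))  # the two centers near the origin collapse into one
```

One merge pass is not transitive: in a chain a, b, c where only neighbors are within the radius, a and c receive different averages. The outer training loop runs the merge again on every iteration, which gives such chains further chances to collapse.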
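We define convergence as the moment when the number of centers stops changing and every shifted center is “close enough” to its old position, as measured by the kernel similarity between the old and the new center.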
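    def _has_converged(self, old, new):
        # Converged when the number of centers is unchanged and every center
        # stayed within the tolerance, i.e. kernel(old, new) >= 1 - tol.
        # (Comparing against 1.0 would demand exact equality and ignore tol.)
        if len(old) != len(new):
            return False
        for i in range(len(new)):
            if self.kernel(old[i], new[i]) < self.tol:
                return False
        return True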
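Putting the pieces together, here is a compact, vectorized sketch of the whole training loop: start with one center per point, shift, merge, and stop when the centers no longer move. It assumes the RBF form exp(-gamma * ||x - y||^2), uses the hypothetical names `mean_shift` and `_sq_dists`, and adds a `max_iter` safeguard that the class above does not have.

```python
import numpy as np


def _sq_dists(A, B):
    """Pairwise squared Euclidean distances, shape (len(A), len(B))."""
    return np.sum((A[:, np.newaxis, :] - B[np.newaxis, :, :]) ** 2, axis=2)


def mean_shift(X, gamma=1.0, tol=1e-7, max_iter=300):
    """Vectorized mean shift with an RBF kernel; returns (labels, centers)."""
    X = np.asarray(X, dtype=float)
    threshold = 1 - tol  # same similarity threshold as self.tol in the class
    centers = X.copy()  # one candidate center per data point
    for _ in range(max_iter):
        # Shift: kernel-weighted mean of the data around every center.
        weights = np.exp(-gamma * _sq_dists(centers, X))
        shifted = weights @ X / weights.sum(axis=1, keepdims=True)
        # Merge: average each center with the centers lying nearly on top of it.
        shifted = np.unique(shifted, axis=0)
        similar = np.exp(-gamma * _sq_dists(shifted, shifted)) > threshold
        shifted = np.unique(np.array([shifted[row].mean(axis=0) for row in similar]), axis=0)
        # Converged when the number of centers is unchanged and none moved noticeably.
        converged = len(shifted) == len(centers) and np.all(
            np.exp(-gamma * np.sum((shifted - centers) ** 2, axis=1)) >= threshold
        )
        centers = shifted
        if converged:
            break
    # Label each point with the index of its nearest center.
    labels = np.argmin(_sq_dists(X, centers), axis=1)
    return labels, centers


X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
labels, centers = mean_shift(X)
print(labels)   # [0 0 0 1 1 1]
print(centers)  # one mode near each blob
```

Shifting all centers with one matrix product trades memory (an n_centers by n_points weight matrix) for speed; for large datasets the per-center loop of the class uses less memory.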