G-Means for non-spherical clusters from scratch

Abhishek Biswas
7 min read · Mar 12, 2023


K-means, one of the most popular data mining techniques, has the benefit of fast clustering but the drawback of being less effective on non-spherical data. For such data, an improved variant of K-means, called G-means, has been proposed to increase clustering effectiveness.

Table of contents:

  • Spherical vs Non-spherical data:
  • Ways to handle Non-spherical data during clustering
  • G-means algorithm from scratch
  • Algorithm Explanation in Python
  • Plotting and interpretation
  • Future Scope of Improvements
  • Conclusion

Spherical vs Non-spherical data:

Spherical data refers to data that is distributed uniformly in all directions, such as points on the surface of a sphere. In contrast, non-spherical data refers to data that is not uniformly distributed in all directions.

For example, consider a dataset that consists of 2D points that are uniformly distributed on the surface of a circle. This dataset would be considered spherical, since the points are uniformly distributed in all directions around the center of the circle. In contrast, if the points were concentrated in a specific region of the circle, such as on one side, then the dataset would be considered non-spherical.

In general, non-spherical data can have a variety of shapes, such as elongated, stretched, or irregular. This can make clustering more challenging, as traditional clustering algorithms like K-means may not be able to capture the underlying structure of the data. In these cases, more specialized clustering algorithms like spectral clustering or DBSCAN may be more appropriate.

Let’s generate these two types of synthetic data and visualize them using python.

import numpy as np
import matplotlib.pyplot as plt

# Generate spherical data
num_points = 500
center = [0, 0]
radius = 1
spherical_data = np.random.normal(size=(num_points, 2))
spherical_data = radius * (spherical_data / np.linalg.norm(spherical_data, axis=1)[:, None])
spherical_data += center

# Generate non-spherical data
num_points = 500
cov = [[1, 0], [0, 10]]
non_spherical_data = np.random.multivariate_normal(mean=[0, 0], cov=cov, size=num_points)

# Plot the data
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].scatter(spherical_data[:, 0], spherical_data[:, 1], c='b')
axs[0].set_title('Spherical data')
axs[1].scatter(non_spherical_data[:, 0], non_spherical_data[:, 1], c='r')
axs[1].set_title('Non-spherical data')
plt.show()

Figure: output of the synthetic data, with spherical data on the left (blue) and non-spherical data on the right (red).

This code generates two sets of 2D data points. The first set, spherical_data, is generated to lie uniformly on a circle of radius 1 centered at the origin. The second set, non_spherical_data, is drawn from a multivariate normal distribution whose covariance has very unequal variances along its axes ([[1, 0], [0, 10]]), which gives it an elongated, non-spherical shape.

The code then plots both sets of data points side-by-side in a 2D scatterplot. The spherical data is plotted in blue, while the non-spherical data is plotted in red.
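As a quick sanity check, we can also compare the empirical covariance of the two datasets. This small, optional snippet assumes the spherical_data and non_spherical_data arrays from the code above are still in scope.

# Compare the spread of the two datasets along each axis
print("Spherical data covariance:\n", np.cov(spherical_data, rowvar=False))
print("Non-spherical data covariance:\n", np.cov(non_spherical_data, rowvar=False))
# The non-spherical covariance should be close to [[1, 0], [0, 10]],
# i.e. roughly 10x more variance along the y-axis than the x-axis.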

Ways to handle Non-spherical data during clustering:

  1. G-means is a more advanced version of X-means that assumes the data is generated from a mixture of Gaussian distributions and uses statistical hypothesis testing to determine the number of clusters. Specifically, it tests whether the points assigned to a cluster look Gaussian (for example with the Anderson-Darling test) and splits the cluster in two if they do not. This approach can handle variable-sized and non-spherical clusters and is robust to noise and outliers.
  2. Hierarchical X-means: This approach combines X-means with hierarchical clustering to obtain a tree-like structure of clusters. It starts with a single cluster and recursively splits it into subclusters until a stopping criterion is met. The advantage of this approach is that it can handle non-spherical clusters and can discover the underlying hierarchical structure of the data.

We will use the first method, G-means. A minimal sketch of the Gaussianity-based split test is shown below; the from-scratch implementation that follows replaces it with a simpler SSE-based split criterion.
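For illustration only, here is a small sketch of the split test described above, assuming a hypothetical helper named should_split: split a cluster with 2-means, project its points onto the direction joining the two child centers, and run an Anderson-Darling normality test on that projection. If the projection does not look Gaussian, the cluster is split. This is not part of the implementation below; it only shows the idea.

import numpy as np
from scipy.stats import anderson
from sklearn.cluster import KMeans

def should_split(cluster, significance_idx=2):
    """Illustrative G-means-style split test (hypothetical helper).

    Splits the cluster with 2-means, projects the points onto the axis
    joining the two child centers, and checks the projection for
    normality with the Anderson-Darling test. Returns True if the
    cluster does NOT look Gaussian, i.e. it should be split.
    """
    kmeans = KMeans(n_clusters=2, n_init=10).fit(cluster)
    c0, c1 = kmeans.cluster_centers_
    v = c1 - c0                               # direction connecting the two child centers
    projection = cluster @ v / np.dot(v, v)   # 1-D projection of the points
    result = anderson(projection)             # Anderson-Darling normality test
    # significance_idx=2 corresponds to the 5% level in scipy's critical values
    return result.statistic > result.critical_values[significance_idx]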

Here, we will write the G-means algorithm from scratch on top of K-means in Python:

import numpy as np
from sklearn.datasets import make_blobs
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# GMeans Algorithm

class GMeans:
    def __init__(self, k=None, max_iters=100, min_points=1, tol=1e-4):
        self.k = k                    # maximum number of clusters to search for
        self.max_iters = max_iters    # maximum number of splitting iterations
        self.min_points = min_points  # minimum points needed to attempt a split
        self.tol = tol                # convergence tolerance on centroid movement

    def fit(self, X):
        if self.k is None:
            self.k = 1
        centroids = [np.mean(X, axis=0)]
        clusters = [X]
        k = 1
        iters = 0
        while k < self.k and iters < self.max_iters:
            new_clusters = []
            new_centroids = []
            for i, cluster in enumerate(clusters):
                if len(cluster) > self.min_points:
                    # Try to split this cluster in two with 2-means
                    kmeans = KMeans(n_clusters=2, n_init=10)
                    kmeans.fit(cluster)
                    new_labels = kmeans.labels_
                    new_centers = kmeans.cluster_centers_
                    new_sse = np.sum((cluster[new_labels == 0] - new_centers[0])**2) + \
                              np.sum((cluster[new_labels == 1] - new_centers[1])**2)
                    old_sse = np.sum((cluster - centroids[i])**2)
                    # Keep the split only if it lowers the sum of squared errors
                    if new_sse < old_sse:
                        new_clusters.append(cluster[new_labels == 0])
                        new_clusters.append(cluster[new_labels == 1])
                        new_centroids.append(new_centers[0])
                        new_centroids.append(new_centers[1])
                    else:
                        new_clusters.append(cluster)
                        new_centroids.append(centroids[i])
                else:
                    new_clusters.append(cluster)
                    new_centroids.append(centroids[i])
            if len(new_clusters) > len(clusters):
                k = len(new_clusters)
            # Centroid movement is only comparable when no split happened
            if len(new_centroids) == len(centroids):
                shift = np.max(np.abs(np.array(new_centroids) - np.array(centroids)))
            else:
                shift = np.inf
            centroids = new_centroids
            clusters = new_clusters
            iters += 1
            print("Iteration:", iters)
            print("Centroids:", centroids)
            print("Max centroid shift:", shift)
            if shift < self.tol:
                break
        # Assign every point to its nearest final centroid
        self.cluster_centers_ = np.array(centroids)
        self.labels_ = np.argmin(cdist(X, self.cluster_centers_), axis=1)




# Generate non-spherical data with 3 clusters
X, y = make_blobs(n_samples=500, centers=3, n_features=3, cluster_std=[1.0, 2.5, 0.5])

# Cluster the data using GMeans
gmeans = GMeans(k=3, max_iters=100, min_points=1, tol=1e-4)
gmeans.fit(X)

# Get the predicted labels and cluster centers
predicted_labels = gmeans.labels_
predicted_centers = gmeans.cluster_centers_

# Debugging output
print("Predicted labels:", predicted_labels)
print("Predicted centers:", predicted_centers)

Algorithm Explanation:

The code implements the GMeans clustering algorithm, which is a variant of KMeans that automatically determines the number of clusters.

The GMeans class takes the following parameters:

  • k: the maximum number of clusters to search for. If not provided, defaults to 1.
  • max_iters: the maximum number of iterations to run the algorithm for. If not provided, defaults to 100.
  • min_points: the minimum number of points required to split a cluster. If not provided, defaults to 1.
  • tol: the convergence tolerance. If the maximum difference between old and new centroids is less than this value, the algorithm will terminate. If not provided, defaults to 1e-4.

The fit method of the GMeans class takes the data X as input and performs the clustering. It starts with one cluster and computes its centroid. If a cluster contains more than min_points points, the algorithm attempts to split it into two clusters with KMeans. If the sum of squared errors (SSE) of the resulting sub-clusters is lower than the SSE of the original cluster, the split is accepted and the new clusters and centroids replace the old ones; otherwise the original cluster and centroid are kept. This process repeats until the requested maximum number of clusters k is reached, the maximum number of iterations is exceeded, or the centroids move less than the convergence tolerance between iterations. Finally, each point is assigned to its nearest centroid, and the resulting labels and cluster centers are stored in the labels_ and cluster_centers_ attributes of the GMeans object.

Plotting in 3D:

#3D plot

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=predicted_labels, cmap='viridis')
ax.scatter(predicted_centers[:, 0], predicted_centers[:, 1], predicted_centers[:, 2],
           marker='*', s=200, c='r')

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.title("GMeans Clustering")
plt.show()

Plotting in 2D:

# plot the data points and the predicted centers
for label in np.unique(predicted_labels):
    plt.scatter(X[predicted_labels == label, 0], X[predicted_labels == label, 1],
                color=np.random.rand(3,), marker='o', label=f'Cluster {label}')

plt.scatter(predicted_centers[:, 0], predicted_centers[:, 1], marker='*', s=200, c='r', label='Cluster Centers')
plt.legend()
plt.title("GMeans Clustering")
plt.show()

Interpretation:

The code plots the data points colored by their predicted labels using plt.scatter(), and plots the predicted centers as red stars. This lets you see how well the algorithm clustered the data. Our synthetic data contains only 3 clusters, but the centroids are captured somewhat noisily and the algorithm detects 4 clusters.

To improve this, we need to tweak the model through hyperparameter tuning and data preprocessing.
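For example, one common preprocessing step is to standardize the features so that no single dimension dominates the distance computation. The snippet below is a minimal sketch using scikit-learn's StandardScaler together with the GMeans class defined above; whether it actually helps depends on the data.

from sklearn.preprocessing import StandardScaler

# Standardize features to zero mean and unit variance before clustering
X_scaled = StandardScaler().fit_transform(X)

gmeans_scaled = GMeans(k=3, max_iters=100, min_points=1, tol=1e-4)
gmeans_scaled.fit(X_scaled)
print("Centers (scaled space):", gmeans_scaled.cluster_centers_)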

Future Scope of Improvements:

To improve the performance of G-Means on this dataset, you could try adjusting the max_iters or min_points parameters, or the maximum number of clusters (k) passed to the GMeans constructor. Alternatively, you could try a clustering algorithm that is better suited for non-spherical clusters, such as DBSCAN or Mean Shift.
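As a point of comparison, here is a minimal sketch of running DBSCAN and Mean Shift from scikit-learn on the same data X; the eps and min_samples values are illustrative and would need tuning for a real dataset.

from sklearn.cluster import DBSCAN, MeanShift

# Density-based clustering: no k required, but eps/min_samples need tuning
db_labels = DBSCAN(eps=1.5, min_samples=10).fit_predict(X)
print("DBSCAN found", len(set(db_labels) - {-1}), "clusters (label -1 is noise)")

# Mean Shift: bandwidth controls the kernel size and is estimated if omitted
ms_labels = MeanShift().fit_predict(X)
print("Mean Shift found", len(set(ms_labels)), "clusters")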

Note:

Possible hyperparameters that can be tuned in the GMeans algorithm are:

k: The maximum number of clusters to form. If k is not specified, it defaults to 1 and no splitting is attempted, so in practice you should pass the largest number of clusters you are willing to accept.

max_iters: The maximum number of iterations to perform.

min_points: The minimum number of points that each resulting cluster should have. This is a stopping criterion for the algorithm to stop splitting clusters when the number of points in a cluster is below this threshold.

tol: The tolerance level for convergence. The algorithm will stop if the change in centroids is less than this value.
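One simple way to tune k is to sweep a few values and keep the one with the best silhouette score. This is a sketch assuming the GMeans class and the data X defined above, and it uses scikit-learn's silhouette_score.

import numpy as np
from sklearn.metrics import silhouette_score

best_k, best_score = None, -1.0
for k in range(2, 7):
    model = GMeans(k=k, max_iters=100, min_points=1, tol=1e-4)
    model.fit(X)
    # Silhouette is only defined when there are at least 2 distinct labels
    if len(np.unique(model.labels_)) > 1:
        score = silhouette_score(X, model.labels_)
        if score > best_score:
            best_k, best_score = k, score
print("Best k:", best_k, "silhouette:", round(best_score, 3))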

Conclusion:

We crafted the algorithm from scratch to understand its functionality and behavior on a random synthetic dataset. Python also has external libraries that make the implementation easier.
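For instance, the pyclustering package ships a ready-made G-means implementation. The sketch below is based on its documented interface (install with pip install pyclustering) and reuses the data X from above.

from pyclustering.cluster.gmeans import gmeans

# pyclustering expects a list of points; process() runs the algorithm
gmeans_instance = gmeans(X.tolist())
gmeans_instance.process()

clusters = gmeans_instance.get_clusters()  # list of index lists, one per cluster
centers = gmeans_instance.get_centers()    # list of cluster centers
print("pyclustering G-means found", len(clusters), "clusters")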
