💡K-Means Clustering

K-Means is an unsupervised clustering algorithm. The idea of the algorithm is as follows:

  1. Randomly choose K points to be the initial cluster centers (the centroids).

  2. Assign each data point the label of its closest centroid.

  3. Compute the mean of the points in each cluster, then use that mean as the cluster's new centroid.

  4. Repeat steps 2 and 3 until convergence, or for a given number of iterations, to find the final locations of the centroids.

class KMeans:
    """K-Means clustering with uniformly random centroid initialisation.

    Parameters
    ----------
    n_clusters : int, default 5
        Number of clusters (centroids) to fit.
    epoch : int, default 10
        Maximum number of assign/update iterations performed by ``fit``.

    Attributes
    ----------
    centroids : list of numpy.ndarray or None
        One centroid per cluster after ``fit``; ``None`` before fitting.
    """

    def __init__(self, n_clusters=5, epoch=10):
        self.n_clusters = n_clusters
        self.epoch = epoch
        self.centroids = None

    # The constructor was originally misspelled ``_init_``; keep that name as
    # an alias so code that called it explicitly keeps working.
    _init_ = __init__

    def fit(self, X_train):
        """Fit the centroids to ``X_train`` and return ``self``.

        ``X_train`` is an array-like whose first axis indexes observations.
        Initial centroids are drawn uniformly at random from the per-feature
        ``[min, max]`` range of the data, so results depend on NumPy's global
        random state.
        """
        self.X_train = np.asarray(X_train)
        # number of training observations
        self.n_train = self.X_train.shape[0]
        min_val = np.min(self.X_train, axis=0)
        max_val = np.max(self.X_train, axis=0)

        # randomly initialize centroids
        centroids = [
            np.random.uniform(min_val, max_val) for _ in range(self.n_clusters)
        ]

        for _ in range(self.epoch):
            updated = self.update_centroids(centroids)
            # Identical centroids produce identical assignments next round,
            # so stopping here cannot change the final result.
            converged = all(
                np.array_equal(old, new) for old, new in zip(centroids, updated)
            )
            centroids = updated
            if converged:
                break
        self.centroids = centroids
        return self

    def _assign_labels(self, X, centroids):
        """Return, for each row of ``X``, the index of its nearest centroid.

        The result is a float array (created with ``np.zeros``), matching the
        label dtype ``predict`` has always returned.
        """
        labels = np.zeros(X.shape[0])
        for i, p in enumerate(X):
            dis = [self.euclidean(p, centroid) for centroid in centroids]
            labels[i] = np.argmin(dis)
        return labels

    def update_centroids(self, centroids):
        """Run one K-Means iteration and return the new list of centroids.

        Each training point in ``self.X_train`` is assigned to its closest
        centroid, then every centroid is replaced by the mean of the points
        assigned to it. A centroid that attracts no points keeps its previous
        position instead of becoming NaN. The input list is not modified.
        """
        # assign each training data to closest centroid
        labels = self._assign_labels(self.X_train, centroids)
        # update centroids by averaging
        updated = []
        for j, previous in enumerate(centroids):
            members = self.X_train[labels == j]
            if len(members) == 0:
                # the mean of an empty cluster is undefined; keep the old one
                updated.append(previous)
            else:
                updated.append(members.mean(axis=0))
        return updated

    def euclidean(self, p, p_train):
        """Return the Euclidean distance between ``p`` and ``p_train``."""
        return np.sqrt(np.sum(np.square(p - p_train)))

    def predict(self, X_test):
        """Return the index of the closest fitted centroid for each test row.

        ``fit`` must have been called first so that ``self.centroids`` is set.
        Labels are returned as a float array of length ``X_test.shape[0]``.
        """
        self.X_test = np.asarray(X_test)
        # number of test observations
        self.n_test = self.X_test.shape[0]
        return self._assign_labels(self.X_test, self.centroids)

Code snippets are adapted from GeeksforGeeks.

Last updated