K-means

Requirements

Python 3.6

matplotlib

numpy

scikit-learn (imported as sklearn)

scipy

Introduction

Cluster analysis is a basic technique in data mining. K-means is one of the most widely used clustering algorithms: it partitions the data points into k groups according to the distances between them, placing each point in the cluster whose centroid is nearest.

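Steps:

  1. Set initial centroids
  2. Assign each point to its nearest centroid
  3. Move each centroid to the mean of the points assigned to it
  4. Repeat steps 2 and 3 until the assignments stop changing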
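
As a rough illustration of what happens once the initial centroids are set, the sketch below implements the standard K-means (Lloyd) loop in plain NumPy. The function name `kmeans`, its parameters, and the choice of random samples as initial centroids are assumptions made for this example; scikit-learn's `KMeans` uses a more careful initialization by default.

```python
import numpy as np

def kmeans(data, k, n_iter=100, seed=0):
    """Plain Lloyd iteration (illustrative sketch); returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # Pick k distinct samples as the initial centroids (an assumption of this sketch)
    centroids = data[rng.choice(len(data), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign every point to its nearest centroid (Euclidean distance)
        dists = np.linalg.norm(data[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its points;
        # an empty cluster keeps its previous centroid
        new_centroids = np.array([
            data[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids
```

Calling `kmeans(data, k=3)` on an `(n, 2)` float array returns one cluster index per row plus a `(3, 2)` array of centroids, which mirrors the `labels_` and `cluster_centers_` attributes used in the scikit-learn example.
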
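# Cluster 100 random 2-D points into 3 groups with scikit-learn
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans

data = np.random.rand(100, 2)            # 100 samples, uniform in the unit square
estimator = KMeans(n_clusters=3)         # k = 3 clusters
estimator.fit(data)
labels = estimator.labels_               # cluster index (0, 1 or 2) of each sample
centroids = estimator.cluster_centers_   # array of shape (3, 2)
inertia = estimator.inertia_             # sum of squared distances to the nearest centroid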


fig = plt.figure()

# Left panel: the raw points, before clustering
ax1 = fig.add_subplot(121)
ax1.set_title('Scatter Plot (Before)')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
# A single scatter call draws every point; the label feeds the legend
ax1.scatter(data[:, 0], data[:, 1], c='#b22c46', marker='o', alpha=0.9, label='x')
ax1.legend(loc='lower right')

# Right panel: the same points, colored by cluster
ax2 = fig.add_subplot(122)
ax2.set_title('Scatter Plot (After)')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
# Draw each cluster in its own color; labels == k is a boolean mask over the samples
colors = ['#ffd400', '#94d6da', '#f3704b']
for k, color in enumerate(colors):
    members = data[labels == k]
    ax2.scatter(members[:, 0], members[:, 1], c=color, label='class%d' % (k + 1))

ax2.legend(loc='lower right')
fig.savefig('./km.svg')
plt.show()
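
The script reads `estimator.inertia_` but never uses it. In scikit-learn this attribute is the sum of squared distances from each sample to its closest cluster center, so lower values mean tighter clusters. The sketch below computes the same quantity by hand in NumPy; the helper name `within_cluster_ss` and the tiny data set are made up for illustration.

```python
import numpy as np

def within_cluster_ss(data, labels, centroids):
    """Sum of squared Euclidean distances from each point to its assigned centroid."""
    diffs = data - centroids[labels]   # each point minus its own centroid
    return float(np.sum(diffs ** 2))

# Four points in two obvious clusters, each point exactly 1 unit from its centroid
points = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
point_labels = np.array([0, 0, 1, 1])
centers = np.array([[0.0, 1.0], [10.0, 1.0]])
print(within_cluster_ss(points, point_labels, centers))  # 4.0
```

Applied to the fitted model above, `within_cluster_ss(data, labels, centroids)` should agree with `inertia` up to floating-point rounding, which can be checked with `np.isclose`.
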

