1 # -*- coding: utf-8 -*-
2 """
3 Created on Wed Jan 10 19:18:56 2018
4
5 @author: markli
6 """
7 import numpy as np;
8 '''
9 kmeans 算法实现
10 算法原理
11 1、随机选择k个点作为聚类中心点,进行聚类
12 2、求出聚类后的各类的 中心点
13 3、由中心点作为新的聚类中心点,再次进行聚类
14 4、比较前后两次的聚类中心点是否发生变化,若没有变化则停止,否则重复2,3,4
15 '''
16
17 def Kmeans(X,k,maxiter):
18 '''
19 使用Kmeans均值聚类对数据集Data进行聚类
20 X 数据集
21 k 聚类中心个数
22 maxiter 最大迭代次数
23 '''
24 m,n = X.shape;
25 #向数据集中添加一列,用来存放类别号
26 Dataset = np.zeros((m,n+1));
27 Dataset[:,:-1] = X;
28
29 #随机选取k 个聚类中心
30 randomCenterIndex = np.random.randint(m,size=k);
31 center = Dataset[randomCenterIndex];
32 center[:,-1] = range(1,k+1);
33
34 #初始聚类
35 oldCenter = np.copy(center);
36 DataClass(Dataset,center);
37 center = getCenter(Dataset,k);
38
39 itertor = 1;
40 while not isStop(oldCenter,center,itertor,maxiter):
41 oldCenter = np.copy(center);
42 DataClass(Dataset,center);
43 center = getCenter(Dataset,k);
44 itertor = itertor + 1;
45 print("数据集聚类结果",Dataset);
46 print("聚类中心点",center);
47
48
49 def DataClass(Dataset,center):
50 '''
51 对数据集进行聚类或者类标签更新
52 Dataset 数据集
53 center 聚类中心点 最后一列为聚类中心点的分类号
54 '''
55 n = Dataset.shape[0];
56 k = center.shape[0];
57 for i in range(n):
58 lable = center[0,-1];
59 mindistance = np.linalg.norm(Dataset[i,:-1]-center[0,:-1],ord=2);
60 for j in range(1,k):
61 distance = np.linalg.norm(Dataset[i,:-1]-center[j,:-1],ord=2);
62 if(distance < mindistance):
63 mindistance = distance;
64 lable = center[j,-1];
65 Dataset[i,-1] = lable;
66
67 def getCenter(Dataset,k):
68 '''
69 获得数据集的k个聚类中心,数据集的最后一列是当前的分类号
70 Dataset 数据集
71 k 聚类中心点个数
72 '''
73 center = np.ones((k,Dataset.shape[1]));
74 for i in range(1,k+1):
75 DataSubset = Dataset[Dataset[:,-1] == i,:];
76 center[i-1] = np.mean(DataSubset,axis=0);
77 return center;
78
79 def isStop(oldCenter,newCenter,itertor,maxiter):
80 '''
81 判断是否停止
82 oldCenter 前一次聚类的聚类中心
83 newCenter 新产生的聚类中心
84 itertor 当前迭代次数
85 maxitor 最大迭代次数
86 '''
87
88 if(itertor >= maxiter):
89 return True;
90
91 return np.array_equal(oldCenter,newCenter);
92
93
94 X = np.array([[1,1],[2,1],[4,3],[5,4]]);
95 print(X.shape);
96 Kmeans(X,2,10);