k-均值聚类算法

最近因为工作的原因,接触了一点机器学习中的算法,在此记录下来,以供学习。

import numpy as np
import copy
import matplotlib.pyplot as plt

# Load the image and flatten it into one row per pixel so that each pixel's
# channel values form a sample for clustering.
pic = plt.imread('apple.png')
plt.imshow(pic)
print(pic.shape)  # a bare `pic.shape` shows nothing outside a notebook
# Derive the row count from the image itself instead of hard-coding four
# channels: this yields (H*W, C) for colour images and (H*W, 1) for grayscale.
data = pic.reshape(pic.shape[0] * pic.shape[1], -1)

def kmeans_wave(n, k, data):
    """Cluster the rows of ``data`` with k-means (Lloyd's iterations).

    Args:
        n: Number of assignment/update iterations to run.
        k: Number of clusters (cluster centers).
        data: 2-D array-like of shape (num_samples, num_features).

    Returns:
        A new float array of shape (num_samples, num_features + 1): the input
        samples with the cluster label of each sample (0 .. k-1) appended as
        the last column. The input is not modified.

    Raises:
        ValueError: If ``data`` is not 2-D or ``k`` is not in
            ``1 .. num_samples``.
    """
    points = np.asarray(data, dtype=float)
    if points.ndim != 2:
        raise ValueError("data must be a 2-D array of shape (samples, features)")
    num_points = points.shape[0]
    if not 1 <= k <= num_points:
        raise ValueError("k must be between 1 and the number of samples")

    # Pick k distinct samples as the initial centers; fancy indexing copies,
    # so updating the centers never touches the samples.
    center_index = np.random.choice(num_points, k, replace=False)
    center = points[center_index]

    # Labels start at 1 so that a call with n == 0 returns the same
    # placeholder column as before.
    labels = np.ones(num_points)

    for _ in range(n):
        # Squared Euclidean distance of every sample to every center, computed
        # on the feature columns only (the label column must not take part).
        # The square root is skipped because it does not change the argmin.
        distance = np.array(
            [np.sum(np.square(points - center[j]), axis=1) for j in range(k)]
        )

        # Assign each sample to its nearest center.
        labels = np.argmin(distance, axis=0)

        # Move each center to the mean of the samples assigned to it.
        for j in range(k):
            members = points[labels == j]
            # Keep the previous center for an empty cluster; the mean of an
            # empty selection would be NaN.
            if len(members):
                center[j] = members.mean(axis=0)

    return np.column_stack((points, labels))

if __name__ == '__main__':
    # Cluster the pixels into 6 groups using 100 iterations.
    data_new = kmeans_wave(100, 6, data)
    print(data_new.shape)
    # The last column holds the cluster label of each pixel; fold it back into
    # the image's height x width (taken from the loaded picture rather than
    # hard-coded) to display the segmentation.
    pic_new = data_new[:, -1].reshape(pic.shape[0], pic.shape[1])
    plt.imshow(pic_new)
    plt.show()

下面是运行结果:

 

posted @ 2018-08-10 08:46  y-xs  阅读(176)  评论(0)  编辑  收藏  举报