import numpy as np # 矩阵计算函数库
import matplotlib.pyplot as plt # 可视化图像
from mpl_toolkits.mplot3d import Axes3D # 3维图
from sklearn.cluster import KMeans # KMeans聚类算法
from sklearn import datasets # 鸢尾花数据集
np.random.seed(5) # 设置随机种子,5个数,用于K-means聚类算法的初始化
centers = [[1, 1], [-1, -1], [1, -1]] # 聚类中心
iris = datasets.load_iris() # 获取数据集
X = iris.data # 训练所需的数据集
y = iris.target # 数据集对应的分类标签,属于监督学习
estimators = {'k_means_iris_3': KMeans(n_clusters=3),
'k_means_iris_8': KMeans(n_clusters=8),
'k_means_iris_bad_init': KMeans(n_clusters=3, n_init=1,
init='random')}
# 设置K-means的参数,n_clusters是需要计算出的集群数,n_init使用不同centroid seeds运行K-means的时间,init是初始化方法
fignum = 1
for name, est in estimators.items():
fig = plt.figure(fignum, figsize=(4, 3)) # figsize指定图像的纵向高度和横向宽度
plt.clf() # 清空当前图像操作,此处可以不加
ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134) # 返回3D图形对象
plt.cla() # 清空当前坐标操作,此处可以不加
est.fit(X) # 用数据对算法进行拟合操作
labels = est.labels_ # 得到每一数据点的分类结果
# 绘制散点图
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=labels.astype(np.float))
# scatter是绘制散点图的函数,前面3个参数对应数据在x,y,z轴的坐标,c代表色彩颜色序列
# 设置x,y,z轴的刻度标签,[]代表不描绘刻度
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
# 设置x,y,z轴的标签
ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
fignum = fignum + 1
# Plot the ground truth
fig = plt.figure(fignum, figsize=(4, 3))
plt.clf()
ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
plt.cla()
for name, label in [('Setosa', 0),
('Versicolour', 1),
('Virginica', 2)]: # 在数据集中心绘制 分类标签的名字
ax.text3D(X[y == label, 3].mean(),
X[y == label, 0].mean() + 1.5,
X[y == label, 2].mean(), name,
horizontalalignment='center', # center代表text向中间水平对齐
bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))
# bbox用于设置ext背景框 alpha为透明度,edgecolor为边框颜色(w为white之意),facecolor为背景框内部颜色
# Reorder the labels to have colors matching the cluster results
y = np.choose(y, [1, 2, 0]).astype(np.float)
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y) # 绘制散点图
# 设置x,y,z轴的刻度标签,[]代表不描绘刻度
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
# 设置x,y,z轴的标签
ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
plt.show() # 显示图像