对于已经存在分类的数据,用KNN算法预测新数据的分类
#coding=utf-8
#导入相应的模块包
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
#创建训练样本,n_samples样本数,数据类别centers
data = make_blobs(n_samples=500,centers=5,random_state=8)
X,y = data
#训练模型
clf = KNeighborsClassifier()
clf.fit(X,y)
#绘制网格图,将所有数据点包含在内(-1和+1是为了给画布中的四周留部分空白)
x_min,x_max = X[:,0].min()-1,X[:,0].max()+1
y_min,y_max = X[:,1].min()-1,X[:,1].max()+1
#生成网格点坐标矩阵
xx,yy = np.meshgrid(np.arange(x_min,x_max,.02),np.arange(y_min,y_max,.02))
#预测网格点的分类值
Z = clf.predict(np.c_[xx.ravel(),yy.ravel()]) #np.c_按行连接两个矩阵,np.r按列连接两个矩阵,ravel()展平数组
Z = Z.reshape(xx.shape)
#画出网格点的颜色
plt.pcolormesh(xx,yy,Z,cmap=plt.cm.Pastel1)#plt.colormesh的作用在于能够直观表现出分类边界
#画出原数据点并标记颜色
plt.scatter(X[:,0],X[:,1],c=y,cmap=plt.cm.spring,edgecolors='k')
#画出边界线(x轴与y轴的的上下边界)
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
###处理新数据点
#把新的数据点用五角星表示出来
plt.scatter(6.75,5.82,marker='*',c='red',s=200)
#对数据点分类进行判断
print('新数据点的分类是:',clf.predict([[6.75,5.82]]))
#模型正确率
print('模型正确率:{:.2f}'.format(clf.score(X,y)))
#将画过的内容展示出来
plt.show()