scikit_learn--kNN算法

一.kNN算法简介

模式识别领域中,最近邻居法KNN算法,又译K-近邻算法)是一种用于分类回归非参数统计方法[1]。在这两种情况下,输入包含特征空间中的k个最接近的训练样本。

  • k-NN分类中,输出是一个分类族群。一个对象的分类是由其邻居的“多数表决”确定的,k个最近邻居(k为正整数,通常较小)中最常见的分类决定了赋予该对象的类别。若k = 1,则该对象的类别直接由最近的一个节点赋予。
  • k-NN回归中,输出是该对象的属性值。该值是其k个最近邻居的值的平均值。

最近邻居法采用向量空间模型来分类,概念为相同类别的案例,彼此的相似度高,而可以借由计算与已知类别案例之相似度,来评估未知类别案例可能的分类。

K-NN是一种基于实例的学习,或者是局部近似和将所有计算推迟到分类之后的惰性学习。k-近邻算法是所有的机器学习算法中最简单的之一。(维基百科)

二.使用sklearn来实现kNN算法

1.示例一:采用约会网站的数据来测试

import numpy as np
import matplotlib.pyplot as plt


#约会网站测试数据的分别表示1.每年飞行的里程数,2.玩游戏和看视频占的时间的比,3.每周消费冰淇淋的公升数
raw_data_X=[[54483,6.317292,0.018209],
            [18475,12.664194,0.595653],
            [33926,2.906644,0.581657],
            [43865,2.388241,0.913938],
            [26547,6.024471,0.486215],
            [44404,7.226764,1.255329],
            [16674,4.183997,1.275290],
            [8123,11.850211,1.096981],
            [42747,11.661797,1.167935]]

#样本标签中1表示不喜欢,2表示魅力一般,3表示极具魅力
raw_data_y=[1,3,1,1,3,3,2,3,3]

X_train=np.array(raw_data_X)
y_train=np.array(raw_data_y)

x=np.array([56054,3.574967,0.494666])


from sklearn.neighbors import KNeighborsClassifier #导入sklearn的相应的模块

kNN_classifier=KNeighborsClassifier(n_neighbors=3) #括号中的3表示k的值

kNN_classifier.fit(X_train,y_train)  #传入训练数据进行拟合

x_predict=x.reshape(1,-1)#对测试的数据先进行修改格式

kNN_classifier.predict(x_predict) #传入测试,对结果进行进行预测

#array([1])  查看数据中的结果,发现测试数据测试正确

三.改善并测试算法的精确度

1.示例一:采用鸢尾花的数据集

import numpy as np 
from sklearn import datasets #导入sklearn中的数据集
iris = datasets.load_iris()  #获取鸢尾花数据集

X=iris.data  #获得鸢尾花的特征参数数据
y=iris.target  #获取鸢尾花的标签(分类)


#为了改善算法的精确度,我们采用把数据集进行分割的方法,一部分用来当做训练数据集,一部分用来当做测试集(train_test_split的方法)
from sklearn.model_selection import train_test_split  #导入相应的模块

X_train,x_test,y_train,y_test=train_test_split(X,y,test_size=0.2)

#test_size=0.2表示的是测试数据占20%,训练数据占80%

from sklearn.neighbors import KNeighborsClassifier
kNN_classifier=KNeighborsClassifier(n_neighbors=3)
kNN_classifier.fit(X_train,y_train)

y_predict=kNN_classifier.predict(x_test)

#y_predict=array([0, 0, 1, 1, 1, 1, 2, 1, 0, 1, 2, 1, 1, 0, 1, 2, 0, 1, 1, 2, 2, 0,2, 1, 1, 0, 2, 0, 0, 1])

#测试算法的精确度

sum(y_predict==y_test)/len(y_test) #用相等的个数除以整个的的个数

#0.9333333333333333

#由上面可以知道,该算法的精准度维0.93

              

posted @ 2018-05-15 16:10  明-少  阅读(280)  评论(0)    收藏  举报