scikit_learn--kNN算法
一.kNN算法简介
在模式识别领域中,最近邻居法(KNN算法,又译K-近邻算法)是一种用于分类和回归的非参数统计方法[1]。在这两种情况下,输入包含特征空间中的k个最接近的训练样本。
-
- 在k-NN分类中,输出是一个分类族群。一个对象的分类是由其邻居的“多数表决”确定的,k个最近邻居(k为正整数,通常较小)中最常见的分类决定了赋予该对象的类别。若k = 1,则该对象的类别直接由最近的一个节点赋予。
-
- 在k-NN回归中,输出是该对象的属性值。该值是其k个最近邻居的值的平均值。
最近邻居法采用向量空间模型来分类,概念为相同类别的案例,彼此的相似度高,而可以借由计算与已知类别案例之相似度,来评估未知类别案例可能的分类。
K-NN是一种基于实例的学习,或者是局部近似和将所有计算推迟到分类之后的惰性学习。k-近邻算法是所有的机器学习算法中最简单的之一。(维基百科)
二.使用sklearn来实现kNN算法
1.示例一:采用约会网站的数据来测试
import numpy as np import matplotlib.pyplot as plt #约会网站测试数据的分别表示1.每年飞行的里程数,2.玩游戏和看视频占的时间的比,3.每周消费冰淇淋的公升数 raw_data_X=[[54483,6.317292,0.018209], [18475,12.664194,0.595653], [33926,2.906644,0.581657], [43865,2.388241,0.913938], [26547,6.024471,0.486215], [44404,7.226764,1.255329], [16674,4.183997,1.275290], [8123,11.850211,1.096981], [42747,11.661797,1.167935]] #样本标签中1表示不喜欢,2表示魅力一般,3表示极具魅力 raw_data_y=[1,3,1,1,3,3,2,3,3] X_train=np.array(raw_data_X) y_train=np.array(raw_data_y) x=np.array([56054,3.574967,0.494666]) from sklearn.neighbors import KNeighborsClassifier #导入sklearn的相应的模块 kNN_classifier=KNeighborsClassifier(n_neighbors=3) #括号中的3表示k的值 kNN_classifier.fit(X_train,y_train) #传入训练数据进行拟合 x_predict=x.reshape(1,-1)#对测试的数据先进行修改格式 kNN_classifier.predict(x_predict) #传入测试,对结果进行进行预测 #array([1]) 查看数据中的结果,发现测试数据测试正确
三.改善并测试算法的精确度
1.示例一:采用鸢尾花的数据集
import numpy as np from sklearn import datasets #导入sklearn中的数据集 iris = datasets.load_iris() #获取鸢尾花数据集 X=iris.data #获得鸢尾花的特征参数数据 y=iris.target #获取鸢尾花的标签(分类) #为了改善算法的精确度,我们采用把数据集进行分割的方法,一部分用来当做训练数据集,一部分用来当做测试集(train_test_split的方法) from sklearn.model_selection import train_test_split #导入相应的模块 X_train,x_test,y_train,y_test=train_test_split(X,y,test_size=0.2) #test_size=0.2表示的是测试数据占20%,训练数据占80% from sklearn.neighbors import KNeighborsClassifier kNN_classifier=KNeighborsClassifier(n_neighbors=3) kNN_classifier.fit(X_train,y_train) y_predict=kNN_classifier.predict(x_test) #y_predict=array([0, 0, 1, 1, 1, 1, 2, 1, 0, 1, 2, 1, 1, 0, 1, 2, 0, 1, 1, 2, 2, 0,2, 1, 1, 0, 2, 0, 0, 1]) #测试算法的精确度 sum(y_predict==y_test)/len(y_test) #用相等的个数除以整个的的个数 #0.9333333333333333 #由上面可以知道,该算法的精准度维0.93

浙公网安备 33010602011771号