Python Knn算法原理及实现
一、Knn为伪代码及优缺点
(1)
- (1)计算已知类别数据集中的点与当前点之间的距离
- (2)按照距离递增次序排序
- (3)选取与与当前点距离最小的k个点
- (4)确定前k个点所在类别的出现频率
- (5)返回前k个点出现频率最高的类别作为当前点的预测分类
(2)优点:精度高、对异常数据不敏感、无数据输入假定
缺点:计算复杂度高、空间复杂度高
二、代码实现
1 def Knn_class(inx,dataset,labels,k): 2 """ 3 inx用于分类的输入向量;dataset已知类别的特征数据训练集; 4 label已知类别的标签;k确定前k个出现的类别频率 5 """ 6 import numpy as np 7 ##计算距离 8 datasetsize = dataset.shape[0] 9 diffmat = np.tile(inx,(datasetsize,1))-dataset #拓展函数,将arr整体拓展为m行n列才能与dataset同维度相减 10 sqdiffmat = diffmat**2 11 distance = sqdiffmat.sum(axis=1) #axis=1行相加 12 13 14 sortdistance = distance.argsort() ##对距离进行升序排序并返回索引值 15 16 ##对排序完的距离前k个进行投票筛选(即频率高的标签) 17 dic = {} 18 try: 19 for i in range(k): 20 label = labels[sortdistance[i]] #匹配升序后的距离的标签 21 dic[label] = dic.get(label,0)+1 #统计每类标签出现的频率 22 lt = list(dic.items()) 23 print(lt) 24 lt.sort(key = lambda x:x[1],reverse=True) #对字典内的标签进行降序 25 return lt[0][0] #返回频率最高的标签,即预测类别 26 except: 27 print('"k"值最大为:',len(labels))
三、简单测试代码
1 if __name__=='__main__': 2 import numpy as np 3 np.random.seed(123) 4 inx = np.array([1,23,45,6,7]) 5 dataset = np.random.random(40).reshape(8,5) 6 labels = np.array([0,1,1,3,2,0,3,0]) 7 k=3 8 result = Knn_class(inx,dataset,labels,k) 9 print(result)
四、手写数字识别
1 dir = r'C:\Users\Administrator.PC-20160806EWJL\Desktop\大杂烩\Untitled Folder\testandtraindata\traindata' 2 #加载数据 3 def datatoarray(fname): 4 arr = [] 5 f = open(fname) 6 for i in range(32): 7 thisline = f.readline() 8 for j in range(len(thisline)-1): 9 data = int(thisline[j]) 10 arr.append(data) 11 return arr 12 13 #建立一个函数取文件名前缀 14 def seplabel(fname): 15 label=int(fname.split("_")[0]) 16 return label 17 18 #建立训练数据集 19 def traindata(): 20 import os 21 label =[] 22 trainfile = os.listdir(dir) 23 num = len(trainfile) 24 trainarr = np.zeros((num,1024)) 25 for i in range(num): 26 trainfname = trainfile[i] 27 label.append(seplabel(trainfname)) 28 trainarr[i,:] = datatoarray("testandtraindata\\traindata\\%s"%trainfname) 29 30 return trainarr,label 31 32 #预测数据并返回准确率 33 def testdata(): 34 import os 35 testfile = os.listdir('testandtraindata\\testdata') 36 num = len(testfile) 37 errcount = 0 38 for i in range(num): 39 testfname = testfile[i] 40 testlabel=(seplabel(testfname)) 41 testarr=datatoarray("testandtraindata\\testdata\\%s"%testfname) 42 inx = testarr 43 dataset,labels = traindata() 44 k=5 45 accure = Knn_class(inx,dataset,labels,k) 46 if accure!=testlabel: 47 errcount +=1 48 print("comeback label:",accure) 49 print("the real answer is:",testlabel) 50 print("the total number or errors is %d "%errcount) 51 print("the total errors trate is %d "%(errcount/float(num))) 52 53 testdata()
参考图书:机器学习实战【美】Peter Harrington 著 李锐 李鹏 曲亚东 王斌 译
浙公网安备 33010602011771号