Machine Learning in Action with Leo: K-Nearest Neighbors (kNN)
Get the code on GitHub:
https://github.com/LeoLeos/MachineLearningLeo
Overview (core idea)
For a sample to be classified, find the K labeled samples closest to it. The label that occurs most often among those K neighbors is the class assigned to the new sample.
Example
If K = 3, the green point's three nearest neighbors are 2 red triangles and 1 blue square. By majority vote, the green point is assigned to the red-triangle class.
If K = 5, the green point's five nearest neighbors are 2 red triangles and 3 blue squares, so by the same majority vote it is assigned to the blue-square class.
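To make the voting step concrete, here is a minimal sketch of the majority vote; the neighbor labels below are made up to mirror the two cases above and are not part of the dating data:

from collections import Counter

# hypothetical labels of the nearest neighbors in the two cases above
neighbors_k3 = ['red_triangle', 'red_triangle', 'blue_square']                      # K = 3
neighbors_k5 = ['red_triangle', 'red_triangle',
                'blue_square', 'blue_square', 'blue_square']                        # K = 5

print(Counter(neighbors_k3).most_common(1)[0][0])   # -> red_triangle
print(Counter(neighbors_k5).most_common(1)[0][0])   # -> blue_square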
Advantages
High accuracy, insensitive to outliers, and no assumptions about the input data.
Disadvantages
High computational and space complexity, since every query is compared against the whole training set.
In practice: improving the match results of a dating site with kNN
Preparing the data
from numpy import *
import matplotlib.pyplot as plt

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())      # number of lines in the file
    returnMat = zeros((numberOfLines, 3))    # zero matrix to hold the three features
    classLabelVector = []                    # list to hold the labels
    fr = open(filename)                      # reopen; readlines() above consumed the file
    index = 0
    # convert the file into a matrix, one row per line
    for line in fr.readlines():
        line = line.strip()                  # strip leading/trailing whitespace and newlines
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat, classLabelVector

import kNN                                   # assuming the functions above live in kNN.py
datingDataMat, datingLabels = kNN.file2matrix('datingTestSet2.txt')
#print(returnMat,classLabelVector)
# create the figure
fig = plt.figure()
# select a subplot
ax = fig.add_subplot(111)
# scatter the features, using the label to vary marker size and colour
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1], 15.0*array(datingLabels), 15.0*array(datingLabels))
# show the plot
plt.show()

Algorithm core (the classifier)
import operator
from numpy import *

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # tile repeats inX into a (dataSetSize, 1) grid so it can be subtracted from
    # every point in dataSet at once, giving the coordinate differences
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat**2                     # squared differences per coordinate
    sqDistances = sqDiffMat.sum(axis=1)        # sum along each row (one row per training point)
    distances = sqDistances**0.5               # Euclidean distance to every point
    sortedDistIndicies = distances.argsort()   # indices that sort the distances ascending
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]                    # label of the i-th nearest neighbor
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1    # count votes per label among the k nearest
    # sort the labels by vote count, descending
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]              # the label with the most votes
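As a quick sanity check, classify0 can be run on a tiny made-up dataset; the points and labels below are illustrative only and are not part of the dating data:

from numpy import array

group = array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
# the three nearest neighbors of (0, 0) are two 'B' points and one 'A' point
print(classify0([0.0, 0.0], group, labels, 3))   # -> 'B'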
Normalization formula
The features live on very different scales (flight miles in the tens of thousands versus litres of ice cream), so the larger feature would dominate the Euclidean distance. Each feature is therefore rescaled into [0, 1]:
new = (oldValue - minValue) / (maxValue - minValue)
def autoNorm(dataSet):
    minVals = dataSet.min(0)                 # column-wise minima
    maxVals = dataSet.max(0)                 # column-wise maxima
    ranges = maxVals - minVals
    # zero matrix of the same shape as the input
    normDataSet = zeros(shape(dataSet))
    # number of rows
    m = dataSet.shape[0]
    # subtract the column minima from the input matrix
    normDataSet = dataSet - tile(minVals, (m, 1))
    # divide by the range (maxVals - minVals), element-wise
    normDataSet = normDataSet / tile(ranges, (m, 1))
    # return the normalized matrix, the ranges, and the minima
    return normDataSet, ranges, minVals
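A minimal sketch of what autoNorm returns, using a made-up feature matrix:

from numpy import array

sample = array([[10.0, 400.0, 0.5],
                [20.0, 800.0, 1.5],
                [30.0, 200.0, 1.0]])
normed, ranges, minVals = autoNorm(sample)
print(normed)     # every column is now scaled into [0, 1]
print(ranges)     # [ 20. 600.   1.]
print(minVals)    # [ 10. 200.   0.5]

The ranges and minVals are returned because the same scaling must later be applied to any new query point before classification.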
Testing the algorithm
def datingClassTest():
    # fraction of the data held out for testing
    hoRatio = 0.50                           # hold out 50%
    # load the data from file
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    # normalize
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    # number of test samples
    numTestVecs = int(m*hoRatio)
    # initialize the error counter
    errorCount = 0.0
    for i in range(numTestVecs):
        # classify test sample normMat[i,:] against the remaining samples, with k = 3
        classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)
Prediction
def classifyPerson():
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input('percentage of time spent playing video games?'))
    ffMiles = float(input('frequent flier miles earned per year?'))
    iceCream = float(input('liters of ice cream consumed per year?'))
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # pack the three inputs into an array (typing braces here by mistake earlier produced the wrong type and an error)
    inArr = array([ffMiles, percentTats, iceCream])
    # normalize the query with the training set's minima/ranges, then classify it with kNN
    classifierResult = classify0((inArr - minVals)/ranges, normMat, datingLabels, 3)
    print('You will probably like this person:', resultList[classifierResult - 1])
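If you would rather skip the interactive prompts, the same prediction can be reproduced with hard-coded numbers; the query values below are made up, and the key point is that the query is scaled with the training set's minVals and ranges before classify0 is called:

datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
# made-up query: frequent flier miles, gaming time percentage, litres of ice cream
inArr = array([26000.0, 10.0, 0.8])
# scale the query exactly like the training data, then classify with k = 3
print(classify0((inArr - minVals)/ranges, normMat, datingLabels, 3))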
Handwritten digit recognition
Workflow
Preparing the data: converting an image into a vector
def img2vector(filename):
    # initialize a 1 x 1024 zero vector
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            # write each pixel value into the vector, row by row
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect                        # the image as a 1 x 1024 vector
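A quick check of img2vector; the file name below is only a hypothetical example, substitute any 32x32 digit file from the dataset:

testVector = img2vector('testDigits/0_13.txt')   # hypothetical example file
print(testVector.shape)        # (1, 1024)
print(testVector[0, 0:31])     # the first 31 pixels of the first image row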
Testing the algorithm: recognizing handwritten digits with k-NN
from os import listdir

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')       # load the training set
    m = len(trainingFileList)                           # number of training samples
    trainingMat = zeros((m, 1024))                      # one zero vector per sample, to hold each sample's image vector
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]             # take off .txt
        classNumStr = int(fileStr.split('_')[0])        # the digit label is encoded in the file name
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')                # iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]             # take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)   # vectorize each test sample
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)   # classify it
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0       # count misclassifications
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))   # error rate
Original author: 热衷开源的宝藏Boy
License: free to redistribute for non-commercial use, no derivatives, keep attribution | CC BY-NC-ND 3.0