lichao_normal

导航

KNN学习

1 KNN估计

  机器学习中参数估计学习的好处在于,它把估计概率密度、判别式或回归问题归结为少量的参数估计问题。通常需要假设样本服从某种分布,如高斯分布、Beta分布等。但是,考虑到我们通常面对的样本数据都是未知的,这些假设并不是总是成立的或者是不够准确。因此,假设错误或是不够准确,都难以实现准确的估计,直接导致计算结果出现较大的误差。

  在非参数估计中,不考虑数据的分布信息,依赖于输入的数据是否存在相似性进行判定。这种依赖是假设相似的事物,其数据特征也具有相似性,通过判定数据特征的相似性,进行分类识别。由于非参数估计方法并不对先验概率进行假设,其计算的复杂性仅仅与样本数据集的大小有关。

  KNN作为一种常见的费参数估计方法,与直方图相类似,通过估计局部的数据概率密度实现。主要区别在于,直方图是通过移动固定的窗,统计窗内的概率密度,而KNN估计则是不固定窗的大小,而固定窗内的数据点的个数k,并以此计算窗的宽度。密度高对应窗的宽度要小,密度低则对应窗的宽度要大。

  KNN中近邻描述的是两个数据点之间的相似性,可以通过各种相似性度量值计算,如:余弦相似度、马氏距离、欧式距离、曼哈顿距离等。其中,k表示近邻数目,其应当小鱼样本大小N。KNN的密度估计可以表示为:

$\widehat{p}(x)=\frac{k}{2Nd_{k}(x)}$

   其中,$d_{k}(x)$表示x的第k个最近距离的数据点与其之间的距离.

 

 

 

  

 

    可以看出KNN概率估计并不光滑,主要原因是其使用的是硬截断,可以使用核函数进行平滑,如高斯核函数。


 

2 KNN分类

  KNN用作分类时,使用类条件概率估计$p(\textbf{x}|C_{i})$去估计条件概率$p(C_{i}|\textbf{x})$,类条件密度估计为:

$\widehat{p}(\textbf{x}|C_{i})=\frac{1}{N_{i}h}\sum_{t=1}^{N}K(\frac{\textbf{x}-\textbf{x}^{t}}{h})r_{i}^{t}$

  其中,如果$\textbf{x}^{t}$属于$C_{i}$,则$r_{i}^{t}=1$。$N_{i}$是属于$C_{i}$的所有样本的数量。并且先验概率密度$\widehat{p}(C_{i}=\frac{N_{i}}{N})$。因此,根据贝叶斯公式就可以得出后验概率:

$\widehat{p}(C_{i}|\textbf{x})=\widehat{p}(\textbf{x}|C_{i})\widehat{p}(C_{i})=\frac{1}{Nh}\sum_{t=1}^{N}K(\frac{\textbf{x}-\textbf{x}^{t}}{h})r_{i}^{t}$

  公式含义表明,$\textbf{x}$属于某个类别的概率最大时,则$\textbf{x}$属于$C_{i}$。且,每个样本都会参与到$\textbf{x}$是否属于$C_{i}$的判定过程中,当然只有最小的k个样本才会对$\textbf{x}$起到决定性的作用($k$是核函数的截断阈值,第$K$个以后的样本贡献全都被判定为0,其中$K(\cdot)$其中加权的作用,对于不同样本的贡献,给于不同的权值响应。

  公式的另一种表现形式:

$\widehat{p}(\textbf{x}|C_{i})=\frac{k_{i}}{N_{i}V^{k}(\textbf{x})}$

  其中,$k_{i}$表示前$k$个样本中属于$C_{i}$的个数,而$V^{k}(x)$则是$x$的邻近超球的体积,相当于一维情况下窗的面积,特别值得注意的是超球的半径由距离$\textbf{x}$的第$k$个最短距离半径决定。所以,

$\widehat{p}(C_{i}|\textbf{x})=\frac{k_{i}}{k}$

  KNN分类判定方式:将输入划分到某个类别,其包含$k$个距离最近距离样本最多。所有的k个近邻都有相同的贡献度,贡献最多的类别,就是$\textbf{x}$的所属类别。

  算法过程:

  1.   计算已知类别数据集中的点与当前点之间的距离;
  2.   按照距离递增次序排序;选取与当前点距离最小的k个点;
  3.   确定前k个点所在类别的出现频率;
  4.   返回前k个点出现频率最高的类别作为当前点的预测分类。

 

 

3 代码学习

python

import operator
import matplotlib.pyplot as plt
import numpy as np
def CreateDataSet():
    group = np.array([[1.0,1.1], [1.0,1.0], [0,0], [0,0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels
def disp_view(group, labels):
    pos = [i for i in range(len(labels)) if labels[i]=='A']
    neg = [i for i in range(len(labels)) if labels[i]=='B']
    s = group[pos]
    s1 = group[neg]
    x1 = s[:, 0]
    y1 = s[:, 1]
    x2 = s1[:, 0]
    y2 = s1[:, 1]
    plt.scatter(x1, y1, c='r')
    plt.scatter(x2, y2)
    plt.show()
    return
def classfiery(inx, dataset, labels, k):
    datasetSize = dataset.shape[0]
    diffMat = np.tile(inx, (datasetSize, 1)) - dataset
    sqDiffMat = diffMat ** 2
    sqDistance = sqDiffMat.sum(axis = 1)
    distance = sqDistance ** 0.5
    sortedDistIndicies = distance.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]
def file2mat(filename):
    fp = open(filename)
    arrayLines = fp.readlines()
    numberOfLines = len(arrayLines)
    returnMat = np.zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrayLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0 : 3]
        classLabelVector.append(listFromLine[-1])
        index += 1
    return returnMat, classLabelVector
def autoNorm(dataset):
    minVal = dataset.min(0)
    maxVal = dataset.min(0)
    diff = maxVal - minVal
    normDataSet = np.zeros(np.shape(dataset))
    m = np.shape(dataset)[0]
    normDataSet = dataset - np.tile(minVal, (m, 1))
    normDataSet = normDataSet / np.tile(diff, (m, 1))
    return normDataSet, diff, minVal
def datingClassTest():
    hoRatio = 0.1
    datingDataMat, datingLabel = file2mat('F:\workshop\KNN\datingtestset.txt')
    normMat, diff, minVal = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVec = int(m * hoRatio)
    errorCount = 0
    for i in range(numTestVec):
        res = classfiery(normMat[i, :], normMat[numTestVec:m,],\
                         datingLabel[numTestVec : m],3)
        if(res != datingLabel[i]):
            errorCount += 1
    print("错误识别概率为:%f"%(errorCount/m))
group,labels = CreateDataSet()
disp_view(group, labels)
print(classfiery([1,1], group, labels, 3))
datingClassTest()
View Code

 

tile(A, reps)
返回一个重复A多次的多为数组

参数
----------
A : 数组
reps : 重复形式,如(2,3)[[A, A,A],[A,A,A]] 

 

例子

----------
a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],
[[0, 1, 2, 0, 1, 2]]])

b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]])

c = np.array([1,2,3,4])
np.tile(c,(4,1))
array([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])

 

argsort(a, axis=-1, kind='quicksort', order=None)

返回排序后的对应索引的序列

a:输入

kind:排序方法

order:升序或者降序

例子

---------------

x = np.array([3, 1, 2])
np.argsort(x)
array([1, 2, 0])

Two-dimensional array:

x = np.array([[0, 3], [2, 2]])
x
array([[0, 3],
[2, 2]])

np.argsort(x, axis=0)
array([[0, 1],
[1, 0]])

np.argsort(x, axis=1)
array([[0, 1],
[0, 1]])

Sorting with keys:

x = np.array([(1, 0), (0, 1)], dtype=[('x', '<i4'), ('y', '<i4')])
x
array([(1, 0), (0, 1)],
dtype=[('x', '<i4'), ('y', '<i4')])

np.argsort(x, order=('x','y'))
array([1, 0])

np.argsort(x, order=('y','x'))
array([0, 1])

Matlab

测试

 [group, labels] = createDataSet();
 DrawScatter(group, labels);
 disp(classify0([0,0], group, labels, 3));
%% 导入文本数据
[fileMat,labels] = importfile('datingtestset.txt');
DrawScatter(fileMat(:,1:2), labels);
%% 数据归一化处理
max_Val = max(fileMat);
min_Val = min(fileMat);
new_Mat = (fileMat - repmat(min_Val, length(labels), 1)) ./ (repmat(max_Val - min_Val, length(labels), 1));

%% 测试分类器性能
testRatio = 0.1;
error = 0;
testNum = testRatio * length(labels);
for i = 1 : testNum
   res = classify0(new_Mat(i,:), new_Mat(testNum + 1 : length(labels),:), labels(testNum + 1 :length(labels), :), 3, 1);
   if(res ~= labels(i))
     error = error + 1;
   end
end
sprintf('错误匹配率为:%f',error/testNum) 

导入数据

function [fileMat,label_num] = importfile(filename, startRow, endRow)
%IMPORTFILE 将文本文件中的数值数据作为列矢量导入。
%   [VARNAME1,VARNAME2,VARNAME3,LARGEDOSES] = IMPORTFILE(FILENAME) 读取文本文件
%   FILENAME 中默认选定范围的数据。
%
%   [VARNAME1,VARNAME2,VARNAME3,LARGEDOSES] = IMPORTFILE(FILENAME,
%   STARTROW, ENDROW) 读取文本文件 FILENAME 的 STARTROW 行到 ENDROW 行中的数据。
%
% Example:
%   [VarName1,VarName2,VarName3,largeDoses] = importfile('datingtestset.txt',1, 1000);
%
%    另请参阅 TEXTSCAN。

% 由 MATLAB 自动生成于 2017/12/09 12:07:08

%% 初始化变量。
delimiter = '\t';
if nargin<=2
    startRow = 1;
    endRow = inf;
end

%% 将数据列作为字符串读取:
% 有关详细信息,请参阅 TEXTSCAN 文档。
formatSpec = '%s%s%s%s%[^\n\r]';

%% 打开文本文件。
fileID = fopen(filename,'r');

%% 根据格式字符串读取数据列。
% 该调用基于生成此代码所用的文件的结构。如果其他文件出现错误,请尝试通过导入工具重新生成代码。
dataArray = textscan(fileID, formatSpec, endRow(1)-startRow(1)+1, 'Delimiter', delimiter, 'HeaderLines', startRow(1)-1, 'ReturnOnError', false);
for block=2:length(startRow)
    frewind(fileID);
    dataArrayBlock = textscan(fileID, formatSpec, endRow(block)-startRow(block)+1, 'Delimiter', delimiter, 'HeaderLines', startRow(block)-1, 'ReturnOnError', false);
    for col=1:length(dataArray)
        dataArray{col} = [dataArray{col};dataArrayBlock{col}];
    end
end

%% 关闭文本文件。
fclose(fileID);

%% 将包含数值字符串的列内容转换为数值。
% 将非数值字符串替换为 NaN。
raw = repmat({''},length(dataArray{1}),length(dataArray)-1);
for col=1:length(dataArray)-1
    raw(1:length(dataArray{col}),col) = dataArray{col};
end
numericData = NaN(size(dataArray{1},1),size(dataArray,2));

for col=[1,2,3]
    % 将输入元胞数组中的字符串转换为数值。已将非数值字符串替换为 NaN。
    rawData = dataArray{col};
    for row=1:size(rawData, 1);
        % 创建正则表达式以检测并删除非数值前缀和后缀。
        regexstr = '(?<prefix>.*?)(?<numbers>([-]*(\d+[\,]*)+[\.]{0,1}\d*[eEdD]{0,1}[-+]*\d*[i]{0,1})|([-]*(\d+[\,]*)*[\.]{1,1}\d+[eEdD]{0,1}[-+]*\d*[i]{0,1}))(?<suffix>.*)';
        try
            result = regexp(rawData{row}, regexstr, 'names');
            numbers = result.numbers;
            
            % 在非千位位置中检测到逗号。
            invalidThousandsSeparator = false;
            if any(numbers==',');
                thousandsRegExp = '^\d+?(\,\d{3})*\.{0,1}\d*$';
                if isempty(regexp(thousandsRegExp, ',', 'once'));
                    numbers = NaN;
                    invalidThousandsSeparator = true;
                end
            end
            % 将数值字符串转换为数值。
            if ~invalidThousandsSeparator;
                numbers = textscan(strrep(numbers, ',', ''), '%f');
                numericData(row, col) = numbers{1};
                raw{row, col} = numbers{1};
            end
        catch me
        end
    end
end


%% 将数据分割为数值列和元胞列。
rawNumericColumns = raw(:, [1,2,3]);
rawCellColumns = raw(:, 4);


%% 将导入的数组分配给列变量名称
VarName1 = cell2mat(rawNumericColumns(:, 1));
VarName2 = cell2mat(rawNumericColumns(:, 2));
VarName3 = cell2mat(rawNumericColumns(:, 3));
largeDoses = rawCellColumns(:, 1);


%% 将返回值调整为合适的形式
label = unique(largeDoses);
label_num = zeros(length(largeDoses), 1);
for i = 1 :  length(label)
    [r, c] = find(strcmp(largeDoses, largeDoses{i, 1}));
    label_num(r,:) = i;
end
fileMat = [VarName1 VarName2 VarName3];

 KNN分类

function flag = classify0(inx, dataSet, labels, k, flagType)
%% KNN
    if nargin <= 4
        flagType = 0;
    end
    [m, n] = size(dataSet);
    inx = repmat(inx, m, 1);
    disMat = inx - dataSet;
    sqDistance = sum(disMat .^ 2, 2);
    Distance = sqrt(sqDistance);
    [Dis, I] = sort(Distance);
    if flagType == 1
        classCount = containers.Map('KeyType', 'double', 'ValueType', 'double');
    else
        classCount = containers.Map();
    end
    for i = 1 : k
        voteIlabel = labels(I(i));
        if isKey(classCount, voteIlabel)
            classCount(voteIlabel) = classCount(voteIlabel) + 1;
        else
            classCount(voteIlabel) = 1;
        end
    end
    keySet = cell2mat(keys(classCount));
    valueSet = cell2mat(values(classCount));
    [valueSet, I] = sort(valueSet,'descend');
    keySet = keySet(I);
    flag = keySet(1);
end

 

4 小结

1、优点

简单,易于理解,易于实现,无需参数估计的假设约束;

适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型);

对异常值不敏感(可用做异常值检测算法)

特别适合于多分类问题(multi-modal,对象具有多个类别标签)。


2、缺点

计算量大,内存开销大;

可解释性较差,无法给出决策树那样的规则。

posted on 2017-12-09 16:18  lichao_normal  阅读(381)  评论(0)    收藏  举报