夏天/isummer

Sun of my life !Talk is cheap, Show me the code! 追风赶月莫停留,平芜尽处是春山~

博客园 首页 新随笔 联系 管理

K最近邻(k-Nearest Neighbour,KNN)分类算法

1.K最近邻(k-Nearest Neighbour,KNN)

  K最近邻(k-Nearest Neighbour,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。用官方的话来说,所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居), 这K个实例的多数属于某个类,就把该输入实例分类到这个类中。

 

 

 

 

2.算法原理

  

  如上图所示,有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。也就是说,现在, 我们不知道中间那个绿色的数据是从属于哪一类(蓝色小正方形or红色小三角形),下面,我们就要解决这个问题:给这个绿色的圆分类。
  我们常说,物以类聚,人以群分,判别一个人是一个什么样品质特征的人,常常可以从他/她身边的朋友入手,所谓观其友,而识其人。我们不是要判别上图中那个绿色的圆是属于哪一类数据么,好说,从它的邻居下手。但一次性看多少个邻居呢?从上图中,你还能看到:

  • 如果K=3,绿色圆点的最近的3个邻居是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类。
  • 如果K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类。
  KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN 算法本身简单有效,它是一种 lazy-learning 算法,分类器不需要使用训练集进行训练,训练时间复杂度为0。KNN 分类的计算复杂度和训练集中的文档数目成正比,也就是说,如果训练集中文档总数为 n,那么 KNN 的分类时间复杂度为O(n)。
  KNN方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
  
  K 近邻算法使用的模型实际上对应于对特征空间的划分。K 值的选择,距离度量和分类决策规则是该算法的三个基本要素:
  (1)K 值的选择会对算法的结果产生重大影响。K值较小意味着只有与输入实例较近的训练实例才会对预测结果起作用,但容易发生过拟合;如果 K 值较大,优点是可以减少学习的估计误差,但缺点是学习的近似误差增大,这时与输入实例较远的训练实例也会对预测起作用,是预测发生错误。在实际应用中,K 值一般选择一个较小的数值,通常采用交叉验证的方法来选择最优的 K 值。随着训练实例数目趋向于无穷和 K=1 时,误差率不会超过贝叶斯误差率的2倍,如果K也趋向于无穷,则误差率趋向于贝叶斯误差率。
  (2)该算法中的分类决策规则往往是多数表决,即由输入实例的 K 个最临近的训练实例中的多数类决定输入实例的类别
  (3)距离度量一般采用 Lp 距离,当p=2时,即为欧氏距离,在度量之前,应该将每个属性的值规范化,这样有助于防止具有较大初始值域的属性比具有较小初始值域的属性的权重过大。

  KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight), 如权值与距离成反比。 该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本 时,该样本的K个邻居中大容量类的样本占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行 结果。可以采用权值的方法(和该样本距离小的邻居权值大)来改进。

3.算法不足

  该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。
  实现 K 近邻算法时,主要考虑的问题是如何对训练数据进行快速 K 近邻搜索,这在特征空间维数大及训练数据容量大时非常必要。

4.程序代码

  以下两个片段一个是算法的主题,一个是主程序调用

  1 /************************************************************************/
  2 /*KNN分类算法C++实现版本    2015/11/5                                             */
  3 /************************************************************************/
  4 #include<iostream>
  5 #include<fstream>
  6 
  7 #include<map>
  8 #include<vector>
  9 #include <string>
 10 #include<cmath>
 11 #include<algorithm>
 12 
 13 
 14 using namespace std;
 15 
 16 typedef string tLabel;    //标签
 17 typedef double tData;    //数据
 18 typedef pair<int,double>  PAIR;    //位置和使用率(概率)
 19 const int MaxColLen = 4;
 20 const int MaxRowLen = 1000;
 21 const int COLNUM = 3;
 22 ifstream fin;
 23 ofstream fout;
 24 
 25 class KNN
 26 {
 27 private:
 28     tData dataSet[MaxRowLen][MaxColLen];    //数据集
 29     tLabel labels[MaxRowLen];    //分类结果的标签
 30     tData testData[MaxColLen];//保留测试数据归一化的结果,或者原来的测试数据
 31     int rowLen;
 32     int k;    //临近的K值
 33     int test_data_num;    //测试数据数量
 34     map<int,double> map_index_dis;    //训练数据集中各个数据到被测试数据的距离,并以训练数据的序号为索引
 35     map<tLabel,int> map_label_freq;    //KNN中的K个距离最近的特征分类结果与对应的频率
 36     double get_distance(tData *d1,tData *d2);
 37 public:
 38     KNN();
 39     //预处理,获得相应的距离
 40     void get_all_distance();
 41     //获取最大的标签使用频率作为最大的概率
 42     tLabel get_max_freq_label();
 43     //归一化训练数据
 44     void auto_norm_data();
 45     //测试误差率
 46     void get_error_rate();
 47     //构造比较器,内部类型
 48     struct CmpByValue
 49     {
 50         bool operator() (const PAIR& lhs,const PAIR& rhs)
 51         {
 52             return lhs.second < rhs.second;
 53         }
 54     };
 55 
 56     ~KNN();    
 57 };
 58 
 59 KNN::~KNN()
 60 {
 61     fin.close();
 62     fout.close();
 63     map_index_dis.clear();
 64     map_label_freq.clear();
 65 }
 66 /************************************************************************/
 67 /*  KNN算法构造函数                                                     */
 68 /************************************************************************/
 69 KNN::KNN()
 70 {
 71     this->rowLen = 1000;
 72     //this->colLen = col;
 73     this->k = 7;
 74     test_data_num = 50;
 75     //训练数据的读取
 76     fin.open("D:\\VC_WorkSpace\\KNNAlg\\datingTestTrainingData.txt");
 77     //输出的结果到一个文件中保存
 78     fout.open("D:\\VC_WorkSpace\\KNNAlg\\result.txt");
 79 
 80     if( !fin || !fout )
 81     {
 82         cout << "can not open the file"<<endl;
 83         exit(0);
 84     }
 85 
 86     for(int i = 0; i < rowLen; i++)
 87     {
 88         for(int j = 0; j < COLNUM; j++)
 89         {
 90             //cout << dataSet[i][j] << "_";
 91             fin >> dataSet[i][j];
 92             fout << dataSet[i][j] << " ";
 93         }
 94         //输入样本中的每个值向量对应的分类结果到“分类结果空间”
 95         fin >> labels[i];
 96         fout << labels[i]<< endl;
 97         //cout << endl;
 98     }
 99 
100 }
101 
102 void KNN:: get_error_rate()
103 {
104     int i,j,count = 0;
105     tLabel label;
106     cout << "please input the number of test data : "<<endl;
107     cin >> test_data_num;//测试数据的数量
108     //以训练数据的前test_data_num作为测试样本
109     for( i = 0; i < test_data_num; i++ )
110     {
111         for(j = 0; j < COLNUM; j++)
112         {
113             testData[j] = dataSet[i][j];        
114         }
115         //训练test_data_num之后的数据作为训练数据
116         get_all_distance();
117         label = get_max_freq_label();//返回分类结果
118         cout << "*******   the lable = " << label << endl;
119         if( label != labels[i] )
120             count++;//分类失败统计器
121         map_index_dis.clear();
122         map_label_freq.clear();
123     }
124     //计算误差
125     cout<<"the error rate is = "<<(double)count/(double)test_data_num<<endl;
126 }
127 /************************************************************************/
128 /* 以欧式距离来表示                                                      */
129 /************************************************************************/
130 double KNN:: get_distance(tData *d1,tData *d2)
131 {
132     double sum = 0;
133     for(int i=0;i<COLNUM;i++)
134     {
135         sum += pow((d1[i] - d2[i]) , 2);
136     }
137     //输出结果显示
138     //cout<<"the sum is = "<< sum << endl;
139     return sqrt(sum);
140 }
141 
142 /************************************************************************/
143 /* 以测试test_data_num序号之后的样本作为训练数据                         */
144 /************************************************************************/
145 void KNN:: get_all_distance()
146 {
147     double distance;
148     int i;
149     for(i = test_data_num; i < rowLen; i++)
150     {
151         distance = get_distance(dataSet[i],testData);
152         map_index_dis[i] = distance;
153     }
154 
155     //  打开注释,可以查看索引距离集合中的内容
156     //    map<int,double>::const_iterator it = map_index_dis.begin();
157     //    while(it!=map_index_dis.end())
158     //    {
159     //        cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;
160     //        it++;
161     //    }
162 
163 }
164 
165 tLabel KNN:: get_max_freq_label()
166 {
167     vector<PAIR> vec_index_dis(map_index_dis.begin(), map_index_dis.end());
168     //对结果进行排序操作,由小到大的顺序排列
169     sort(vec_index_dis.begin(), vec_index_dis.end(), CmpByValue());
170     //取前K个距离最近的特征标签作为分类的参考依据
171     for(int i = 0; i < k; i++)
172     {
173         cout << "Index = " << vec_index_dis[i].first << " Distance = " << vec_index_dis[i].second << " Label = " << labels[vec_index_dis[i].first] << " \nCoordinate(";
174         int j;
175         for(j=0; j < COLNUM - 1; j++)
176         {
177             cout << dataSet[vec_index_dis[i].first][j]<<",";
178         }
179         cout << dataSet[vec_index_dis[i].first][j] << " )" << endl;
180         //统计K邻域的使用频率
181         map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;
182     }
183 
184     map<tLabel,int>::const_iterator map_it = map_label_freq.begin();
185     tLabel label;    //保留频率最大的分类结果信息
186     int max_freq = 0;    //最大的使用频率
187     while( map_it != map_label_freq.end())
188     {
189         if( map_it->second > max_freq)
190         {
191             max_freq = map_it->second;
192             label = map_it->first;
193         }
194         map_it++;
195     }
196     cout << "The test data belongs to the : " << label << " label" << endl;
197     return label;
198 }
199 /************************************************************************/
200 /* 归一化处理                                                           */
201 /************************************************************************/
202 void KNN::auto_norm_data()
203 {
204     tData maxa[COLNUM];    //
205     tData mina[COLNUM];
206     tData range[COLNUM];
207     int i,j;
208     //遍历训练数据,找出数据向量的各个极值,为后续归一化处理
209     for( i = 0; i < COLNUM; i++ )
210     {
211         maxa[i] = max(dataSet[0][i],dataSet[1][i]);
212         mina[i] = min(dataSet[0][i],dataSet[1][i]);
213     }
214 
215     for( i = 2; i < rowLen; i++ )
216     {
217         for(j = 0; j < COLNUM; j++)
218         {
219             if( dataSet[i][j]>maxa[j] )
220             {
221                 maxa[j] = dataSet[i][j];
222             }
223             else if( dataSet[i][j]<mina[j] )
224             {
225                 mina[j] = dataSet[i][j];
226             }
227         }
228     }
229 
230     for( i = 0; i < COLNUM; i++ )
231     {
232         range[i] = maxa[i] - mina[i] ; 
233         //归一化测试数据
234         testData[i] = ( testData[i] - mina[i] )/range[i] ;
235     }
236 
237     //归一化训练数据的各个分量
238     for(i=0;i<rowLen;i++)
239     {
240         for(j=0;j<COLNUM;j++)
241         {
242             dataSet[i][j] = ( dataSet[i][j] - mina[j] )/range[j];
243         }
244     }
245 }

  主程序调用

 1 // KNNAlg.cpp : 定义控制台应用程序的入口点。
 2 #include "stdafx.h"
 3 #include <iostream>
 4 #include "knnbody.h"
 5 using namespace std;
 6 
 7 int _tmain(int argc, _TCHAR* argv[])
 8 {
 9 
10     cout << "the KNN algothm is running ... " << endl;
11     //生成KNN算法对象
12     KNN knn = KNN();
13     //训练数据的预处理
14     knn.auto_norm_data();
15     //对测试样本进行分类操作以及进行错误率统计
16     knn.get_error_rate();
17 
18     system("pause");
19     return 0;
20 }

  备注:以上代码主要是呈现KNN算法的大致结构,程序中有些许细节不尽完善,后续需要更改。

  参考网址:

(1)http://baike.baidu.com/link?url=B_MWlciVjI4Oz2UJQaa09C3xkdkOHHH5OOg3uxE1UiMtG_P4Eq3dMlVQiRqqTqpASFZV8sk1jjiS2gmQDPvZ__
(2)http://blog.csdn.net/lavorange/article/details/16924705
posted on 2015-11-05 20:55  夏天/isummer  阅读(1244)  评论(0编辑  收藏  举报