KNN算法

KNN算法的英文名称是K-Nearest Neighbour,即k最近邻算法。它的基本思想是先把所有的训练样例存储起来,当有需要分类的实例时,将该实例与所有训练样例逐一进行相似度比较,找出最相似的k个训练样例,再查看这k个训练样例中哪个分类占比最高,就把该分类赋给需要分类的实例。实例之间的相似度一般用距离来度量,距离越小越相似。

优点:算法简单易懂、准确性高、对数据没有特殊要求

缺点:对离群值敏感、计算量大、内存需求大

适用于数值型(定量)数据和标称型(定性)数据

代码如下:

public class KNN
{
	/**classify an instance
	 * 
	 * @param instance need to classify
	 * @param trainingSet the training instances
	 * @param labels the labels of the training instances
	 * @param k the number of need to calculate k nearest neighbors 
	 * @return the class of instance
	 */
	public int classify( double[] instance, double[][] trainingSet, int[] labels, int k )
	{
		int trainingInstancesSize = trainingSet.length;
		double[] distances = new double[trainingInstancesSize];
		
		RealVector instanceVector = MatrixUtils.createRealVector( instance );
		for( int i = 0; i < trainingInstancesSize; i++ )
		{
			RealVector trainingVetor = MatrixUtils.createRealVector( trainingSet[i] );
			distances[i] = instanceVector.getDistance( trainingVetor );
		}
		
		//sorted all distances from small to large
		TreeMap<Double, Integer> sortedDistances = new TreeMap<Double, Integer>();
		for( int i = 0; i < trainingInstancesSize; i++ )
		{
			sortedDistances.put( distances[i], i );
		}
		
		Map<Integer, Integer> resultMap = new HashMap<Integer, Integer>();
		//get the k smallest distances
		for( int i = 0; i < k; i++ )
		{
			Entry<Double, Integer> entry = sortedDistances.pollFirstEntry();
			int index = entry.getValue();
			
			int label = labels[index];
			
			if( null == resultMap.get( label ) )
			{
				resultMap.put( label, 0 );
			}
			
			resultMap.put( label, resultMap.get( label ) + 1 );
			
		}
		
//		vote to get the class of instance
		int classfier = findMaxValueKey( resultMap );
		
		return classfier;
	}
	
	/**
	 * find the max value's key
	 * @param map
	 * @return
	 */
	private <T> T findMaxValueKey( Map<T, Integer> map )
	{
		int max = Integer.MIN_VALUE;
		T maxValueKey = null;
		
		for( Entry<T, Integer> entry : map.entrySet() )
		{
			int value = entry.getValue();
			if( max < value )
			{
				max = value;
				maxValueKey = entry.getKey();
			}
		}
		
		return maxValueKey;
	}
        public static void main( String[] args )
	{
		double[][] examples = { {1, 1.1}, {1,1}, {0,0}, {0,0.1} };
		int[] labels = { 0, 0, 1, 1 };
		
		double[] instance = { 0.3, 0.4 };
		
		System.err.println( new KNN().classify( instance, examples, labels, 3 ) );
		
	}        

 

posted @ 2013-07-14 23:01  叶莞尔  阅读(642)  评论(0)    收藏  举报