KNN算法
KNN算法的英文全称是K-Nearest Neighbour,即k最近邻算法。它的基本思想是:先把所有训练样例存储起来;当有实例需要分类时,将该实例与所有训练样例逐一比较相似度,找出最相似的k个训练样例,再统计这k个训练样例中哪个类别占比最高,就把该类别赋给待分类的实例。实例之间的相似度一般用距离来度量,距离越小表示越相似。
优点:算法简单易懂、准确性高、对数据没有特殊要求
缺点:对离群值敏感;计算量大、内存需求大(需要保存全部训练样例,且每次分类都要计算待分类实例与所有训练样例的距离)
适用的数据类型:数值型(定量)数据和标称型(定性)数据
代码如下(向量距离的计算使用了Apache Commons Math库中的RealVector):
public class KNN
{
/**
 * Classifies an instance with the k-nearest-neighbour rule.
 *
 * <p>The distance from {@code instance} to every training row is computed with
 * {@code RealVector.getDistance} (Euclidean distance in Apache Commons Math), the
 * {@code k} closest rows are selected, and their labels vote; the label with the
 * most votes is returned. Vote ties are resolved by {@code findMaxValueKey}.
 *
 * @param instance the feature vector to classify
 * @param trainingSet the training feature vectors, one per row
 * @param labels the label of each training row; {@code labels[i]} belongs to {@code trainingSet[i]}
 * @param k the number of nearest neighbours that vote; must be in {@code [1, trainingSet.length]}
 * @return the majority label among the {@code k} nearest neighbours
 * @throws IllegalArgumentException if {@code k} is outside {@code [1, trainingSet.length]}
 */
public int classify( double[] instance, double[][] trainingSet, int[] labels, int k )
{
    int trainingInstancesSize = trainingSet.length;
    // Fail fast: k < 1 leaves nothing to vote on, and k > n asks for more neighbours than exist.
    if( k < 1 || k > trainingInstancesSize )
    {
        throw new IllegalArgumentException( "k must be between 1 and " + trainingInstancesSize
                + " (the number of training instances), but was " + k );
    }
    final double[] distances = new double[trainingInstancesSize];
    Integer[] byDistance = new Integer[trainingInstancesSize];
    RealVector instanceVector = MatrixUtils.createRealVector( instance );
    for( int i = 0; i < trainingInstancesSize; i++ )
    {
        RealVector trainingVector = MatrixUtils.createRealVector( trainingSet[i] );
        distances[i] = instanceVector.getDistance( trainingVector );
        byDistance[i] = i;
    }
    // Sort training indices by distance, nearest first. Indices are sorted (instead of using
    // distances as TreeMap keys) so that training instances at the same distance are all kept
    // rather than overwriting each other. Arrays.sort on objects is stable, so equidistant
    // instances stay in training-set order.
    java.util.Arrays.sort( byDistance, new java.util.Comparator<Integer>()
    {
        @Override
        public int compare( Integer a, Integer b )
        {
            return Double.compare( distances[a], distances[b] );
        }
    } );
    // Count how many of the k nearest neighbours carry each label.
    Map<Integer, Integer> votes = new HashMap<Integer, Integer>();
    for( int i = 0; i < k; i++ )
    {
        int label = labels[byDistance[i]];
        Integer count = votes.get( label );
        votes.put( label, count == null ? 1 : count + 1 );
    }
    // Majority vote decides the class.
    return findMaxValueKey( votes );
}
/**
 * Returns the key that is mapped to the largest value in {@code map}.
 *
 * <p>When several keys share the largest value, the one encountered first in the map's
 * iteration order wins (for a {@code HashMap} that order is unspecified).
 *
 * @param map the map to search; a {@code null} value causes a {@code NullPointerException}
 * @param <T> the key type
 * @return the key with the largest value, or {@code null} if {@code map} is empty
 */
private <T> T findMaxValueKey( Map<T, Integer> map )
{
    T maxValueKey = null;
    int max = 0;
    boolean found = false;
    for( Entry<T, Integer> entry : map.entrySet() )
    {
        int value = entry.getValue();
        // The first entry always seeds the maximum, so a map whose values are all
        // Integer.MIN_VALUE still yields a key; afterwards only a strictly larger value
        // replaces it, which keeps the first key on ties.
        if( !found || value > max )
        {
            found = true;
            max = value;
            maxValueKey = entry.getKey();
        }
    }
    return maxValueKey;
}
/**
 * Small demo: classifies the point (0.3, 0.4) against four labelled 2-D training points
 * with k = 3 and prints the predicted label to standard output.
 *
 * @param args unused
 */
public static void main( String[] args )
{
    double[][] examples = { { 1, 1.1 }, { 1, 1 }, { 0, 0 }, { 0, 0.1 } };
    int[] labels = { 0, 0, 1, 1 };
    double[] instance = { 0.3, 0.4 };
    // With Euclidean distance the three nearest neighbours are (0, 0.1), (0, 0) and (1, 1),
    // labelled 1, 1 and 0, so the expected output is 1.
    // The prediction is the program's result, so it goes to stdout rather than stderr.
    System.out.println( new KNN().classify( instance, examples, labels, 3 ) );
}
浙公网安备 33010602011771号