k近邻算法的Java实现

k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系。输入没有标签的新数据之后,将新数据的每个特征和样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据的分类标签作为新数据的标签。一般来说,我们只选取样本数据中前k个最相似的数据。

Java实现:

KNNData.java

package KNN;

public class KNNData implements Comparable<KNNData>{
    double c1;
    double c2;
    double c3;
    double distance;
    String type;
    
    public KNNData(double c1, double c2, double c3, String type) {
        this.c1 = c1;
        this.c2 = c2;
        this.c3 = c3;
        this.type = type;
    }
    
    @Override
    public int compareTo(KNNData arg0) {
        return Double.valueOf(this.distance).compareTo(Double.valueOf(arg0.distance));
    }    
}

KNN.java

package KNN;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class KNN {
    
    //训练集
    private List<KNNData> KNNDS = null;
    
    public KNN(List<KNNData> KNNDS) {
        this.KNNDS = KNNDS;
    }
    
    //欧式距离
    private static double disCal(KNNData i, KNNData td) {
        return Math.sqrt((i.c1 - td.c1)*(i.c1 - td.c1)+(i.c2 - td.c2)*(i.c2 - td.c2)+
                (i.c3 - td.c3)*(i.c3 - td.c3));
    }
    
    private static String getMaxValueKey(int k, List<KNNData> ts){
        //只保留前k个元素
        
        while(ts.size() != k) {
            ts.remove(k);
        }
                
        String sKey;
        //保存key以及出现次数
        HashMap<String,Integer> keySet = new HashMap<String,Integer>();
        keySet.put(ts.get(0).type,1);
        for (int x = 1; x < ts.size(); x++) {
            sKey = ts.get(x).type;
            if (keySet.containsKey(sKey)) {
                keySet.put(sKey, keySet.get(sKey)+1);
            } else {
                keySet.put(sKey, 1);
            }
        }
        Set<Map.Entry<String,Integer>> set = keySet.entrySet();
        Iterator<Map.Entry<String,Integer>> iter = set.iterator(); 
        
        int mValue = 0;
        String mType = "";
        while (iter.hasNext()){
            Map.Entry<String,Integer> map = iter.next();
            if (mValue < map.getValue()) {
                mType = map.getKey();
                mValue = map.getValue();
            }
        }
        
        return mType;
    }
    
    public static String knnCal(int k, KNNData i, List<KNNData> ts) {
        //保存距离
        for (KNNData td : ts) {
            td.distance = disCal(i, td);
        }
        Collections.sort(ts);    
        return getMaxValueKey(k, ts);
    }
}

KNNTest.java

package KNN;

import java.util.ArrayList;
import java.util.List;

public class KNNTest {

    public static void main(String[] args) {
        
        List<KNNData> kd = new ArrayList<KNNData>();
        //训练集
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(7,1.5,0.1,"B"));
        kd.add(new KNNData(6,1.2,0.1,"B"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(100,1.1,0.1,"D"));

        System.out.println(KNN.knnCal(3, new KNNData(1.1,1.1,0.1,"N/A"), kd));
    }
}

 

posted @ 2016-03-03 10:00  finalboss1987  阅读(2329)  评论(0编辑  收藏  举报