K-Means算法实现[原创]

之前在网上找到过一个 K-Means 算法的实现,但运行后始终得不到正确的结果,于是静下心来自己写了一个。

程序分成三个类:Kmeans(聚类主流程)、Cluster(单个聚类)和 Point(二维点):

package com.zhoujianxiang;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Simple two-dimensional K-Means clustering.
 *
 * <p>Input is read from a text file (see {@link #prepare()}). The first line
 * holds the number of clusters K; every following line holds one point as
 * two numbers, x then y. Example data file (data2.txt):
 * <pre>
 * 3
 * 7 8
 * 12 1
 * 13 6
 * 13 13
 * 13 19
 * 14 5
 * 17 16
 * 19 20
 * 20 7
 * 8 13
 * </pre>
 *
 * @author Administrator
 */
public class Kmeans {

    /** Total center movement (sum of |dx| + |dy|) below which iteration stops. */
    private static final double CONVERGENCE_THRESHOLD = 0.0001;

    int K; // number of clusters to build
    int Generation=100; // maximum number of iterations; one of the stop conditions
    double E=7.1; // error parameter; declared as a stop condition but not read by this class
    static ArrayList<Point> allPoints = new ArrayList<Point>(); // every point taking part in clustering
    int totalNumber = 0; // number of points loaded so far; also used as the next point id
    Map<Integer, Point> centerMap = new HashMap<Integer, Point>(); // record of moving centers; not read by this class
    static Point[] start_points = null; // initial centers, filled by prepare()

    /**
     * Loads the data file, appending every point to {@link #allPoints} and
     * using the first K points as the initial centers in {@link #start_points}.
     *
     * <p>The first point read is stored at index K-1, the second at K-2, and
     * so on. Blank lines are skipped and coordinates may be separated by any
     * run of whitespace.
     *
     * @throws IOException if the file cannot be read, the cluster count is
     *         missing or invalid, a data line is malformed, or the file holds
     *         fewer than K points
     */
    public void prepare() throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(new File("C:\\data2.txt")));
        try {
            String header = br.readLine();
            if (header == null) {
                throw new IOException("data file is empty: missing cluster count on first line");
            }
            int clusterCount;
            try {
                clusterCount = Integer.parseInt(header.trim());
            } catch (NumberFormatException e) {
                throw new IOException("invalid cluster count on first line: " + header, e);
            }
            if (clusterCount < 0) {
                throw new IOException("cluster count must not be negative: " + clusterCount);
            }
            K = clusterCount;
            start_points = new Point[K];

            int seedsMissing = K;
            int lineNumber = 1;
            String data;
            while ((data = br.readLine()) != null) {
                lineNumber++;
                String trimmed = data.trim();
                if (trimmed.length() == 0) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] fields = trimmed.split("\\s+");
                if (fields.length < 2) {
                    throw new IOException("line " + lineNumber + ": expected \"x y\" but got: " + data);
                }
                Double x;
                Double y;
                try {
                    x = Double.valueOf(fields[0]);
                    y = Double.valueOf(fields[1]);
                } catch (NumberFormatException e) {
                    throw new IOException("line " + lineNumber + ": invalid coordinate in: " + data, e);
                }
                allPoints.add(new Point(totalNumber++, x, y));
                if (seedsMissing > 0) {
                    // Seeds are separate Point objects so moving a center never touches the data.
                    seedsMissing--;
                    start_points[seedsMissing] = new Point(seedsMissing, x, y);
                }
            }
            if (seedsMissing > 0) {
                // Null seeds would only fail later, inside clusterProcess, with a NullPointerException.
                throw new IOException("data file holds " + (K - seedsMissing)
                        + " point(s) but " + K + " cluster centers are required");
            }
        } finally {
            br.close();
        }
    }

    /**
     * Creates one empty cluster per given center.
     *
     * @param center the centers to wrap
     * @return a list with one cluster per center, in the set's iteration order
     */
    public ArrayList<Cluster> beforCP(Set<Point> center) {
        ArrayList<Cluster> cluster = new ArrayList<Cluster>();
        for (Point seed : center) {
            Cluster c = new Cluster();
            c.setCenter(seed);
            cluster.add(c);
        }
        return cluster;
    }

    /**
     * Runs the clustering loop starting from the given centers.
     *
     * <p>Each round prints the current centers, assigns every point in
     * {@link #allPoints} to its nearest center and then moves the centers.
     * The loop ends when the total center movement drops below
     * {@link #CONVERGENCE_THRESHOLD} or after {@link #Generation} rounds,
     * whichever comes first; at least one round is always executed.
     *
     * @param center the initial centers; the array is not modified
     * @return the clusters of the last round, each carrying its updated center
     */
    public ArrayList<Cluster> clusterProcess(Point[] center) {
        Point[] current = center;
        int round = 0;
        while (true) {
            round++;
            printCenters(current);
            ArrayList<Cluster> clusters = assignToNearest(current);
            Point[] moved = moveCenters(clusters);
            // The round cap also ends runs whose shift is NaN and therefore never "< threshold".
            if (totalShift(current, moved) < CONVERGENCE_THRESHOLD || round >= Generation) {
                return clusters;
            }
            current = moved;
        }
    }

    /** Prints the current centers followed by a separator line. */
    private void printCenters(Point[] centers) {
        for (int i = 0; i < centers.length; i++) {
            System.out.println("points"+i+":"+centers[i].getX()+","+centers[i].getY());
        }
        System.out.println("---------------");
    }

    /** Builds one cluster per center and adds every point to the cluster of its nearest center. */
    private ArrayList<Cluster> assignToNearest(Point[] centers) {
        ArrayList<Cluster> clusters = new ArrayList<Cluster>(centers.length);
        for (int i = 0; i < centers.length; i++) {
            Cluster c = new Cluster();
            c.setCenter(centers[i]);
            clusters.add(c);
        }
        if (centers.length == 0) {
            return clusters; // nothing to assign points to
        }
        for (Point candidate : allPoints) {
            double best = Double.MAX_VALUE;
            int nearest = 0;
            for (int j = 0; j < centers.length; j++) {
                double dx = candidate.getX() - centers[j].getX();
                double dy = candidate.getY() - centers[j].getY();
                double distance = Math.sqrt(dx * dx + dy * dy);
                if (distance < best) { // strict '<': ties go to the lowest index
                    best = distance;
                    nearest = j;
                }
            }
            // Index the cluster directly; matching by coordinates would put the
            // point into every cluster whose center has the same coordinates.
            clusters.get(nearest).addPoints(candidate);
        }
        return clusters;
    }

    /**
     * Computes the next center of every cluster and stores it in the cluster.
     * The previous center is counted as one extra sample, so the divisor is
     * never zero and a cluster without points keeps its center.
     */
    private Point[] moveCenters(ArrayList<Cluster> clusters) {
        Point[] moved = new Point[clusters.size()];
        for (int i = 0; i < clusters.size(); i++) {
            Cluster c = clusters.get(i);
            ArrayList<Point> members = c.getOfCluster();
            double sumX = c.getCenter().getX();
            double sumY = c.getCenter().getY();
            for (Point member : members) {
                sumX += member.getX();
                sumY += member.getY();
            }
            int samples = members.size() + 1;
            moved[i] = new Point(i, sumX / samples, sumY / samples);
            c.setCenter(moved[i]);
        }
        return moved;
    }

    /** Sums |dx| + |dy| over all centers; doubles are compared by tolerance, not equality. */
    private double totalShift(Point[] before, Point[] after) {
        double shift = 0.0;
        for (int i = 0; i < after.length; i++) {
            shift += Math.abs(after[i].getX() - before[i].getX())
                    + Math.abs(after[i].getY() - before[i].getY());
        }
        return shift;
    }

    /**
     * Tells whether no element of {@code ps} equals {@code p}.
     *
     * @param p the point to look for
     * @param ps the points to search; must not contain null elements
     * @return {@code false} if some element equals {@code p}, otherwise {@code true}
     */
    public boolean Nocontain(Point p, Point[] ps){
        for (int i = 0; i < ps.length; i++) {
            if (ps[i].equals(p)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Prints every cluster: its center on one line, then its points on the next.
     *
     * @param cs the clusters to print
     */
    public static void print(ArrayList<Cluster> cs) {
        for (Cluster c : cs) {
            System.out.println("重心: " +c.getCenter());
            for (Point member : c.getOfCluster()) {
                System.out.print(member);
            }
            System.out.println();
        }
    }

    /**
     * Loads the data file, clusters the points and prints the result.
     *
     * @param args ignored
     * @throws IOException if the data file cannot be loaded
     */
    public static void main(String[] args) throws IOException {
        Kmeans kmeans = new Kmeans();
        kmeans.prepare();
        ArrayList<Cluster> cs = kmeans.clusterProcess(start_points);
        print(cs);
    }
}

 

package com.zhoujianxiang;

import java.util.ArrayList;

/**
 * One cluster: a center point plus the points currently assigned to it.
 */
public class Cluster {
	// Center of this cluster (a point, not an id).
	Point center;
	// Points that belong to this cluster.
	ArrayList<Point> ofCluster = new ArrayList<Point>();

	/** @return the current center of this cluster */
	public Point getCenter() {
		return this.center;
	}

	/** @param center the new center of this cluster */
	public void setCenter(Point center) {
		this.center = center;
	}

	/** @return the live list of points assigned to this cluster */
	public ArrayList<Point> getOfCluster() {
		return this.ofCluster;
	}

	/** @param ofCluster the list that replaces the current members */
	public void setOfCluster(ArrayList<Point> ofCluster) {
		this.ofCluster = ofCluster;
	}

	/**
	 * Adds a point unless the member list already contains an equal one.
	 *
	 * @param point the point to add
	 */
	public void addPoints(Point point) {
		boolean alreadyMember = this.ofCluster.contains(point);
		if (alreadyMember) {
			return;
		}
		this.ofCluster.add(point);
	}
}

 

package com.zhoujianxiang;

public class Point {
    Double x;//坐标
    Double y;
    int id; //名称
    public Point(int id, Double x, Double y) {
        this.x = x;
        this.y = y;
        this.id = id;
    }
    public Point(){
    }
    public Double getX() {
        return x;
    }
    public void setX(Double x) {
        this.x = x;
    }
    public Double getY() {
        return y;
    }
    public void setY(Double y) {
        this.y = y;
    }
    public int getId() {
        return id;
    }
    public void setId(int id) {
        this.id = id;
    }
    @Override
    public String toString() {
        return "P"+id+"("+x+","+y+")";
    }    
    @Override
    public boolean equals(Object obj) {
        if(obj!=null){
        	Point p = (Point)obj;
        if(this.x.equals(p.x)&&this.y.equals(p.y)){
            return true;
        }
        else{
            return false;
        }
        }
        return false;
    }  
}

 

posted @ 2011-11-15 19:13  biscuitlife  阅读(6298)  评论(0)    收藏  举报