K-Means算法实现[原创]
之前在网上找到过一个 K-Means 算法的实现，但始终得不到正确的结果，于是静下心来自己写了一个！
程序分成三个类：Kmeans（聚类主流程）、Cluster（单个聚类）和 Point（二维点）。
package com.zhoujianxiang;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
/**
 * K-Means clustering of 2-D points read from a text file.
 *
 * <p>Data file format (example data2.txt): the first line is K, the number of
 * clusters; every following non-blank line is one point "x y" (whitespace separated).
 * <pre>
 * 3
 * 7 8
 * 12 1
 * 13 6
 * 13 13
 * 13 19
 * 14 5
 * 17 16
 * 19 20
 * 20 7
 * 8 13
 * </pre>
 * The first K points of the file become the initial centers.
 *
 * @author Administrator
 */
public class Kmeans {
    int K; // number of clusters to form (first line of the data file)
    int Generation = 100; // maximum number of center-update passes; one of the stop conditions
    double E = 7.1; // error parameter; NOTE(review): not read anywhere in this class
    static ArrayList<Point> allPoints = new ArrayList<Point>(); // all points being clustered; static, so shared by every instance
    int totalNumber = 0; // number of points read so far; also the id given to the next point
    Map<Integer, Point> centerMap = new HashMap<Integer, Point>(); // NOTE(review): not read anywhere in this class
    static Point[] start_points = null; // initial centers, filled by prepare()

    /** Data file read by {@link #prepare()}. */
    private static final String DEFAULT_DATA_FILE = "C:\\data2.txt";

    /** Clustering stops once the summed L1 movement of all centers in one pass is below this. */
    private static final double CONVERGENCE_THRESHOLD = 0.0001;

    /**
     * Reads {@value #DEFAULT_DATA_FILE} and adds every point in it to {@link #allPoints}.
     *
     * @throws IOException if the file cannot be read or is malformed
     */
    public void prepare() throws IOException {
        prepare(DEFAULT_DATA_FILE);
    }

    /**
     * Reads a data file: the first line sets {@link #K}; each further non-blank line
     * "x y" becomes a Point appended to {@link #allPoints}. The first K points are
     * also stored (in reverse order) in {@link #start_points} as initial centers.
     *
     * @param path path of the data file
     * @throws IOException if the file cannot be read, is empty, has a line that is not
     *     "x y", or contains fewer than K points
     */
    public void prepare(String path) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(new File(path)));
        try {
            String header = br.readLine();
            if (header == null) {
                throw new IOException("data file is empty: " + path);
            }
            K = Integer.parseInt(header.trim());
            start_points = new Point[K];
            int remaining = K; // initial centers still to be taken
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.length() == 0) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] fields = line.split("\\s+");
                if (fields.length < 2) {
                    throw new IOException("expected \"x y\" but got: " + line);
                }
                Double x = Double.valueOf(fields[0]);
                Double y = Double.valueOf(fields[1]);
                allPoints.add(new Point(totalNumber++, x, y));
                if (remaining > 0) {
                    remaining--;
                    // fills the array back to front, same as the original
                    start_points[remaining] = new Point(remaining, x, y);
                }
            }
            if (remaining > 0) {
                // otherwise start_points would contain nulls and clustering would NPE later
                throw new IOException("data file has fewer than K=" + K + " points: " + path);
            }
        } finally {
            br.close();
        }
    }

    /**
     * Creates one empty Cluster per given center (in the set's iteration order).
     *
     * @param center the centers
     * @return a list of clusters, each with its center set and no members
     */
    public ArrayList<Cluster> beforCP(Set<Point> center) {
        ArrayList<Cluster> cluster = new ArrayList<Cluster>();
        for (Point p : center) {
            Cluster c = new Cluster();
            c.setCenter(p);
            cluster.add(c);
        }
        return cluster;
    }

    /**
     * Runs the K-Means iterations starting from {@code center} over {@link #allPoints}.
     *
     * <p>Each pass prints the current centers, assigns every point to its nearest center
     * (Euclidean distance, first center wins ties), then moves each center. The loop stops
     * when the summed L1 movement of the centers drops below {@value #CONVERGENCE_THRESHOLD}
     * or after {@link #Generation} passes.
     *
     * @param center initial centers; must be non-null, non-empty, without null entries
     * @return the clusters of the last pass; their members were assigned using the centers
     *     of that pass, and their centers are the updated ones
     * @throws IllegalArgumentException if {@code center} is null or empty
     */
    public ArrayList<Cluster> clusterProcess(Point[] center) {
        if (center == null || center.length == 0) {
            throw new IllegalArgumentException("at least one center is required");
        }
        Point[] current = center;
        for (int generation = 1; ; generation++) {
            printCenters(current);
            ArrayList<Cluster> clusters = assign(current);
            Point[] next = updateCenters(clusters);
            // iterative (not recursive) so a slow convergence cannot overflow the stack
            if (totalShift(current, next) < CONVERGENCE_THRESHOLD || generation >= Generation) {
                return clusters;
            }
            current = next;
        }
    }

    // Prints the centers of one pass (shows how the centers move).
    private static void printCenters(Point[] center) {
        for (int i = 0; i < center.length; i++) {
            System.out.println("points" + i + ":" + center[i].getX() + "," + center[i].getY());
        }
        System.out.println("---------------");
    }

    // Builds one cluster per center (cluster i <-> center i) and assigns every point to its nearest one.
    private ArrayList<Cluster> assign(Point[] centers) {
        ArrayList<Cluster> clusters = new ArrayList<Cluster>(centers.length);
        for (int i = 0; i < centers.length; i++) {
            Cluster c = new Cluster();
            c.setCenter(centers[i]);
            clusters.add(c);
        }
        for (Point p : allPoints) {
            // index the cluster directly: matching by center.equals() would put the point
            // into several clusters whenever two centers coincide; adding to the list
            // directly keeps distinct samples that share coordinates
            clusters.get(nearest(p, centers)).getOfCluster().add(p);
        }
        return clusters;
    }

    // Index of the center closest to p (Euclidean); the first center wins ties.
    private static int nearest(Point p, Point[] centers) {
        double best = Double.MAX_VALUE;
        int bestIndex = 0;
        for (int j = 0; j < centers.length; j++) {
            double dx = p.getX() - centers[j].getX();
            double dy = p.getY() - centers[j].getY();
            double dist = Math.sqrt(dx * dx + dy * dy);
            if (dist < best) {
                best = dist;
                bestIndex = j;
            }
        }
        return bestIndex;
    }

    // Moves every cluster center and returns the new centers (new center i has id i).
    private static Point[] updateCenters(ArrayList<Cluster> clusters) {
        Point[] next = new Point[clusters.size()];
        for (int i = 0; i < clusters.size(); i++) {
            Cluster c = clusters.get(i);
            ArrayList<Point> members = c.getOfCluster();
            // The previous center is counted as one extra sample: new = (old + sum) / (n + 1).
            // This damps the move, has the same fixed point as the plain mean (old = sum / n),
            // and leaves an empty cluster's center where it was instead of dividing by zero.
            double sumX = c.getCenter().getX();
            double sumY = c.getCenter().getY();
            for (Point p : members) {
                sumX += p.getX();
                sumY += p.getY();
            }
            int count = members.size() + 1;
            next[i] = new Point(i, sumX / count, sumY / count);
            c.setCenter(next[i]);
        }
        return next;
    }

    // Summed L1 distance between corresponding old and new centers. Doubles are compared
    // with a tolerance because exact equality is unreliable after repeated arithmetic.
    private static double totalShift(Point[] before, Point[] after) {
        double error = 0.0;
        for (int i = 0; i < after.length; i++) {
            error += Math.abs(after[i].getX() - before[i].getX())
                    + Math.abs(after[i].getY() - before[i].getY());
        }
        return error;
    }

    /**
     * Tells whether no element of {@code ps} is equal (per Point.equals) to {@code p}.
     *
     * @return true if {@code p} is not contained in {@code ps}
     */
    public boolean Nocontain(Point p, Point[] ps) {
        for (Point q : ps) {
            if (q.equals(p)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Prints each cluster's center on one line, followed by a line listing its members.
     *
     * @param cs the clusters to print
     */
    public static void print(ArrayList<Cluster> cs) {
        for (Cluster c : cs) {
            System.out.println("重心: " + c.getCenter());
            for (Point p : c.getOfCluster()) {
                System.out.print(p);
            }
            System.out.println();
        }
    }

    /**
     * Entry point: reads the data file ({@code args[0]} if given, otherwise
     * {@value #DEFAULT_DATA_FILE}), clusters it and prints the result.
     */
    public static void main(String[] args) throws IOException {
        Kmeans kmeans = new Kmeans();
        if (args.length > 0) {
            kmeans.prepare(args[0]);
        } else {
            kmeans.prepare();
        }
        ArrayList<Cluster> cs = kmeans.clusterProcess(start_points);
        print(cs);
    }
}
package com.zhoujianxiang;
import java.util.ArrayList;
/**
 * One cluster: its current center and the points assigned to it.
 */
public class Cluster {
    Point center; // current center of this cluster (a Point, not an id)
    ArrayList<Point> ofCluster = new ArrayList<Point>(); // points that belong to this cluster

    /** @return the current center */
    public Point getCenter() {
        return center;
    }

    /** @param center the new center */
    public void setCenter(Point center) {
        this.center = center;
    }

    /** @return the live (mutable) member list */
    public ArrayList<Point> getOfCluster() {
        return ofCluster;
    }

    /** @param ofCluster replaces the member list */
    public void setOfCluster(ArrayList<Point> ofCluster) {
        this.ofCluster = ofCluster;
    }

    /**
     * Adds a point to this cluster unless this exact Point object is already a member.
     *
     * <p>Membership is checked by reference, not with equals(): Point.equals compares
     * coordinates, so an equals-based check would silently drop distinct samples that
     * share the same coordinates and under-weight them when the center is recomputed.
     *
     * @param point the point to add
     */
    public void addPoints(Point point) {
        for (Point member : ofCluster) {
            if (member == point) {
                return;
            }
        }
        ofCluster.add(point);
    }
}
package com.zhoujianxiang;
public class Point {
Double x;//坐标
Double y;
int id; //名称
public Point(int id, Double x, Double y) {
this.x = x;
this.y = y;
this.id = id;
}
public Point(){
}
public Double getX() {
return x;
}
public void setX(Double x) {
this.x = x;
}
public Double getY() {
return y;
}
public void setY(Double y) {
this.y = y;
}
public int getId() {
return id;
}
public void setId(int id) {
this.id = id;
}
@Override
public String toString() {
return "P"+id+"("+x+","+y+")";
}
@Override
public boolean equals(Object obj) {
if(obj!=null){
Point p = (Point)obj;
if(this.x.equals(p.x)&&this.y.equals(p.y)){
return true;
}
else{
return false;
}
}
return false;
}
}

浙公网安备 33010602011771号