K-Means聚类算法

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

算法介绍

K-Means又名为K均值算法,他是一个聚类算法,这里的K就是聚簇中心的个数,代表数据中存在多少数据簇。K-Means在聚类算法中算是非常简单的一个算法了。有点类似于KNN算法,都用到了距离矢量度量,用欧式距离作为小分类的标准。

算法步骤

(1)、设定数字k,从n个初始数据中随机的设置k个点为聚类中心点。

(2)、针对n个点的每个数据点,遍历计算到k个聚类中心点的距离,最后按照离哪个中心点最近,就划分到那个类别中。

(3)、对每个已经划分好类别的n个点,对同个类别的点求均值,作为此类别新的中心点。

(4)、循环(2),(3)直到最终中心点收敛。

以上的计算过程将会在下面我的程序实现中有所体现。

算法的代码实现

输入数据:

3 3
4 10
9 6
14 8
18 11
21 7
主实现类:

package DataMining_KMeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;

/**
 * k均值算法工具类
 * 
 * @author lyq
 * 
 */
public class KMeansTool {
	// 输入数据文件地址
	private String filePath;
	// 分类类别个数
	private int classNum;
	// 类名称
	private ArrayList<String> classNames;
	// 聚类坐标点
	private ArrayList<Point> classPoints;
	// 所有的数据左边点
	private ArrayList<Point> totalPoints;

	public KMeansTool(String filePath, int classNum) {
		this.filePath = filePath;
		this.classNum = classNum;
		readDataFile();
	}

	/**
	 * 从文件中读取数据
	 */
	private void readDataFile() {
		File file = new File(filePath);
		ArrayList<String[]> dataArray = new ArrayList<String[]>();

		try {
			BufferedReader in = new BufferedReader(new FileReader(file));
			String str;
			String[] tempArray;
			while ((str = in.readLine()) != null) {
				tempArray = str.split(" ");
				dataArray.add(tempArray);
			}
			in.close();
		} catch (IOException e) {
			e.getStackTrace();
		}

		classPoints = new ArrayList<>();
		totalPoints = new ArrayList<>();
		classNames = new ArrayList<>();
		for (int i = 0, j = 1; i < dataArray.size(); i++) {
			if (j <= classNum) {
				classPoints.add(new Point(dataArray.get(i)[0],
						dataArray.get(i)[1], j + ""));
				classNames.add(i + "");
				j++;
			}
			totalPoints
					.add(new Point(dataArray.get(i)[0], dataArray.get(i)[1]));
		}
	}

	/**
	 * K均值聚类算法实现
	 */
	public void kMeansClustering() {
		double tempX = 0;
		double tempY = 0;
		int count = 0;
		double error = Integer.MAX_VALUE;
		Point temp;

		while (error > 0.01 * classNum) {
			for (Point p1 : totalPoints) {
				// 将所有的测试坐标点就近分类
				for (Point p2 : classPoints) {
					p2.computerDistance(p1);
				}
				Collections.sort(classPoints);

				// 取出p1离类坐标点最近的那个点
				p1.setClassName(classPoints.get(0).getClassName());
			}

			error = 0;
			// 按照均值重新划分聚类中心点
			for (Point p1 : classPoints) {
				count = 0;
				tempX = 0;
				tempY = 0;
				for (Point p : totalPoints) {
					if (p.getClassName().equals(p1.getClassName())) {
						count++;
						tempX += p.getX();
						tempY += p.getY();
					}
				}
				tempX /= count;
				tempY /= count;

				error += Math.abs((tempX - p1.getX()));
				error += Math.abs((tempY - p1.getY()));
				// 计算均值
				p1.setX(tempX);
				p1.setY(tempY);

			}
			
			for (int i = 0; i < classPoints.size(); i++) {
				temp = classPoints.get(i);
				System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
						(i + 1), temp.getX(), temp.getY()));
			}
			System.out.println("----------");
		}

		System.out.println("结果值收敛");
		for (int i = 0; i < classPoints.size(); i++) {
			temp = classPoints.get(i);
			System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
					(i + 1), temp.getX(), temp.getY()));
		}

	}

}
坐标点类:

package DataMining_KMeans;

/**
 * 坐标点类
 * 
 * @author lyq
 * 
 */
public class Point implements Comparable<Point>{
	// 坐标点横坐标
	private double x;
	// 坐标点纵坐标
	private double y;
	//以此点作为聚类中心的类的类名称
	private String className;
	// 坐标点之间的欧式距离
	private Double distance;

	public Point(double x, double y) {
		this.x = x;
		this.y = y;
	}
	
	public Point(String x, String y) {
		this.x = Double.parseDouble(x);
		this.y = Double.parseDouble(y);
	}
	
	public Point(String x, String y, String className) {
		this.x = Double.parseDouble(x);
		this.y = Double.parseDouble(y);
		this.className = className;
	}

	/**
	 * 距离目标点p的欧几里得距离
	 * 
	 * @param p
	 */
	public void computerDistance(Point p) {
		if (p == null) {
			return;
		}

		this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
				* (this.y - p.y);
	}

	public double getX() {
		return x;
	}

	public void setX(double x) {
		this.x = x;
	}

	public double getY() {
		return y;
	}

	public void setY(double y) {
		this.y = y;
	}
	
	public String getClassName() {
		return className;
	}

	public void setClassName(String className) {
		this.className = className;
	}

	public double getDistance() {
		return distance;
	}

	public void setDistance(double distance) {
		this.distance = distance;
	}

	@Override
	public int compareTo(Point o) {
		// TODO Auto-generated method stub
		return this.distance.compareTo(o.distance);
	}
	
}
调用类:

/**
 * K-means(K均值)算法调用类
 * @author lyq
 *
 */
public class Client {
	public static void main(String[] args){
		String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
		//聚类中心数量设定
		int classNum = 3;
		
		KMeansTool tool = new KMeansTool(filePath, classNum);
		tool.kMeansClustering();
	}
}

测试输出结果:

聚类中心点1,x=15.5,y=8
聚类中心点2,x=4,y=10
聚类中心点3,x=3,y=3
----------
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3
----------
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3
----------
结果值收敛
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3

K-Means算法的优缺点

1、首先优点当然是算法简单,快速,易懂,没有涉及到特别复杂的数据结构。

2、缺点1是最开始K的数量值以及K个聚类中心点的设置不好定,往往开始时不同的k个中心点的设置对后面迭代计算的走势会有比较大的影响,这时候可以考虑根据类的自动合并和分裂来确定这个k。

3、缺点2由于计算是迭代式的,而且计算距离的时候需要完全遍历一遍中心点,当数据规模比较大的时候,开销就显得比较大了。

posted @ 2020-01-12 19:09  回眸,境界  阅读(283)  评论(0编辑  收藏  举报