package com.pachira.d;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
public class Kmeans {
/**
* Kmeans聚类算法
* 基本思想:
* 以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
*
* 过程描述:
* 输入:k, data[n], eps;
* (1) 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];
* (2) 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;
* (3) 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;
* (4) 重复(2)(3),直到所有c[i]值的变化小于给定阈值。
*
* 其他说明:
* 1、Kmeans的变种,其距离计算不是欧基米德距离,有可能会出现问题;
* 2、海量数据聚类,欧基米德距离要比余弦相似性好(Inderjit S.Dhillon James FAN 和 Yuqiang Guan论文)
*
* data[n]的每个元素往往是一个向量;
*
*/
/**
* 初始化聚类中心点
* @param k 中心点数
* @param data 带聚类的数据集
* @return 中心点集
*/
public static double[] getPoints(int k, int[] data){
double[] points = new double[k];
for (int i = 0; i < k; i++) {
points[i] = (double)data[i];
}
return points;
}
/**
* 计算元素和每个中心点的距离,将该元素归为最小距离的中心点中
* @param points 中心点集
* @param data 元素集
* @return 聚类结果
*/
public static LinkedHashMap<Double, List<Integer>> culcate(double[] points, int[] data){
LinkedHashMap<Double, List<Integer>> map = new LinkedHashMap<Double, List<Integer>>();
for (int i = 0; i < data.length; i++) {
//get one point to culcate the distance
int d = data[i];
double minDistance = Double.MAX_VALUE;
double key = -1;
for(int j = 0; j < points.length; j++){
//欧基米德距离
double tmp = Math.sqrt(Math.pow((d - points[j]), 2));
if(tmp < minDistance){
minDistance = tmp;
key = points[j];
}
}
// System.out.println(key);
if(map.containsKey(key)){
List<Integer> cus = map.get(key);
cus.add(d);
}else{
List<Integer> cus = new ArrayList<Integer>();
cus.add(d);
map.put(key, cus);
}
}
return map;
}
/**
* 重置中心点
* @param 聚类结果
* @return 重置后的中心点集
*/
public static double[] resetPoint(HashMap<Double, List<Integer>> map){
double[] tmp = new double[map.keySet().size()];
int index = 0;
for(double key: map.keySet()){
List<Integer> val = map.get(key);
double total = 0;
for (int i = 0; i < val.size(); i++) {
total += val.get(i);
}
if(val.size() == 0){
tmp[index++] = key;
}else{
key = total / val.size();
tmp[index++] = key;
}
}
return tmp;
}
/**
* Kmeans
* @param data 待聚类元素集合
* @param k 类别数目(中心点数)
* @param eps 收敛阈值
* @return 聚类结果
*/
public static LinkedHashMap<Double, List<Integer>> kmeans(int[] data, int k, double eps){
double[] points = getPoints(k, data);
LinkedHashMap<Double, List<Integer>> tmp = null;
while(true){
tmp = culcate(points, data);
show(tmp);
double[] tpoints = resetPoint(tmp);
boolean flag = true;
for (int i = 0; i < tpoints.length; i++) {
if(Math.abs(points[i] - tpoints[i]) > eps){
flag = false;
break;
}
}
if(flag)break;
points = tpoints;
}
return null;
}
/**
* 显示聚类结果
* @param map
*/
public static void show(LinkedHashMap<Double, List<Integer>> map){
for (double key: map.keySet()) {
System.out.println(String.format("%.2f", key) + "\t" + map.get(key));
}
System.out.println("=================================");
}
public static void main(String[] args) {
int k = 10;
double eps = 0.001;
int[] data = {45, 26, 45, 65, 49, 27, 44, 26, 40, 63, 35, 63, 47, 24, 65, 62, 38, 8, 43, 65, 34, 36, 80, 34, 62, 60, 54, 66, 86, 47, 73, 15, 40, 7, 12, 35, 88, 5, 9, 20, 94, 28, 70, 78, 87, 78, 43, 80, 25, 88, 46, 21, 52, 49, 36, 64, 52, 59, 24, 56, 54, 10, 81, 78, 66, 28, 53, 48, 2, 89, 44, 79, 16, 55, 27, 6, 0, 46, 76, 87, 30, 90, 40, 51, 98, 97, 55, 72, 32, 79, 61, 39, 74, 58, 55, 58, 32, 4, 76, 19};
kmeans(data, k, eps);
}
}