Java 实现 KMeans 无监督聚类算法,使用 k-means++ 初始化聚类中心
k-means++ 初始化步骤:
1.种子随机: 第一个点是随机的。
2.远近为纲: 后续每个新质心的选择,都依赖于到当前所有已选质心的距离。
3.距离平方加权: 使用平方距离度量“远近”并计算概率。
4.概率性偏袒远方: 每个点的被选概率,与其到最近已有质心的距离平方成正比。距离越远,概率越高。
5.动态更新: 每选出一个新质心,立即重新评估所有点到新质心集合的距离和概率。
6.迭代选点: 重复计算距离和依概率选择的过程,直到选够 K 个点。
package org.core;
import java.util.*;
/**
 * A point in n-dimensional space, stored as one {@code double} per dimension.
 */
class Point {
    private final double[] coordinates;

    /**
     * Creates a point from the given coordinates.
     *
     * <p>The array is copied, so modifying the argument afterwards does not change this point.
     *
     * @param coordinates one value per dimension
     * @throws NullPointerException if {@code coordinates} is null
     */
    public Point(double[] coordinates) {
        this.coordinates = Objects.requireNonNull(coordinates, "coordinates").clone();
    }

    /**
     * @return the number of dimensions
     */
    public int getLength() {
        return coordinates.length;
    }

    /**
     * @param index zero-based dimension
     * @return the coordinate stored for that dimension
     */
    public double getValue(int index) {
        return coordinates[index];
    }

    /**
     * Overwrites the coordinate of one dimension.
     *
     * @param index zero-based dimension
     * @param val   new coordinate value
     */
    public void setValue(int index, double val) {
        coordinates[index] = val;
    }

    /**
     * @return the sum of all coordinates
     */
    public double sum() {
        return Arrays.stream(coordinates).sum();
    }

    /**
     * @return the coordinates formatted as {@code [a, b, c]}
     */
    @Override
    public String toString() {
        return Arrays.toString(coordinates);
    }
}
/**
 * One cluster: a numeric label, the current centroid and the points assigned to it.
 */
class Cluster {
    /** Cluster label. */
    private final int id;

    /** Points currently assigned to this cluster. */
    private final List<Point> points = new ArrayList<>();

    /** Current center; remains {@code null} until a {@code setCentroid} overload is called. */
    private Point centroid;

    /**
     * Creates an empty cluster without a centroid.
     *
     * @param id label of the cluster
     */
    public Cluster(int id) {
        this.id = id;
    }

    /** @return the label given at construction */
    public int getId() {
        return id;
    }

    /** @return the current centroid, or {@code null} if none has been set */
    public Point getCentroid() {
        return centroid;
    }

    /** Replaces the centroid with the given point. */
    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    /** Replaces the centroid with a new point built from the given coordinates. */
    public void setCentroid(double[] centroid) {
        Point wrapped = new Point(centroid);
        this.centroid = wrapped;
    }

    /** @return the internal list of assigned points itself, not a copy */
    public List<Point> getPoints() {
        return points;
    }

    /** Appends a point to this cluster. */
    public void addPoint(Point point) {
        points.add(point);
    }

    /** Wraps the coordinates in a new point and appends it to this cluster. */
    public void addPoint(double[] point) {
        Point wrapped = new Point(point);
        points.add(wrapped);
    }

    /** Removes every assigned point; the centroid is left untouched. */
    public void clearPoints() {
        points.clear();
    }
}
/**
 * K-means clustering over {@code double[]} feature vectors.
 *
 * <p>Initial centroids come from one of three sources: clusters passed to the constructor,
 * distinct random data points ({@code "random"}), or deterministic farthest-point seeding
 * ({@code "k-means++"}, the default).
 *
 * <p>NOTE(review): the {@code "k-means++"} option does not sample proportionally to squared
 * distance as the textbook k-means++ does. It starts from the point with the smallest coordinate
 * sum and then always takes the point farthest from the centroids chosen so far. That behavior
 * is kept as-is so existing results do not change.
 *
 * <p>After {@link #fit(List)}, {@link #acc_dist} and {@code labels} describe the last assignment
 * step, which runs before the last centroid update.
 */
public class accKMeans {
    private static final String INIT_KMEANS_PP = "k-means++";
    private static final String INIT_RANDOM = "random";

    /** Number of clusters to create when no preset clusters are supplied. */
    private final int k;
    /** Data set of the most recent {@link #fit(List)} call. */
    private List<Point> points;
    /** Clusters, either preset by the caller or created on the first {@link #fit(List)} call. */
    private final List<Cluster> clusters;
    /** Upper bound on assignment/update rounds. */
    private final int maxIterations;
    /** Initialization strategy: {@code "k-means++"} or {@code "random"}. */
    private final String contensType;
    /** Iteration stops once the summed centroid movement of a round is at most this value. */
    private final double tol;
    private final Random random = new Random();

    /** Sum of the distances from every point to its assigned centroid in the last assignment step. */
    public double acc_dist;
    /** Cluster id per input vector, in input order; {@code null} until an assignment step has run. */
    List<Integer> labels;

    /**
     * Uses tolerance {@code 0.0} and the {@code "k-means++"} initialization.
     *
     * @param k             number of clusters
     * @param maxIterations maximum number of assignment/update rounds
     */
    public accKMeans(int k, int maxIterations) {
        this(k, maxIterations, 0.0, INIT_KMEANS_PP, new ArrayList<Cluster>());
    }

    /**
     * Uses the {@code "k-means++"} initialization.
     *
     * @param k             number of clusters
     * @param maxIterations maximum number of assignment/update rounds
     * @param tol           convergence tolerance on the summed centroid movement
     * @throws NullPointerException if {@code tol} is null
     */
    public accKMeans(int k, int maxIterations, Double tol) {
        this(k, maxIterations, Objects.requireNonNull(tol, "tol"), INIT_KMEANS_PP,
                new ArrayList<Cluster>());
    }

    /**
     * @param k             number of clusters
     * @param maxIterations maximum number of assignment/update rounds
     * @param tol           convergence tolerance on the summed centroid movement
     * @param contensType   {@code "k-means++"} or {@code "random"}; checked when {@code fit} runs
     * @throws NullPointerException if {@code tol} is null
     */
    public accKMeans(int k, int maxIterations, Double tol, String contensType) {
        this(k, maxIterations, Objects.requireNonNull(tol, "tol"), contensType,
                new ArrayList<Cluster>());
    }

    /**
     * Starts from caller-supplied clusters instead of computing initial centroids.
     *
     * @param k             number of clusters; only used if {@code clusters} is null or empty
     * @param maxIterations maximum number of assignment/update rounds
     * @param tol           convergence tolerance on the summed centroid movement
     * @param clusters      clusters whose centroids are already set; the list is used as-is,
     *                      not copied. Null or empty falls back to {@code "k-means++"}.
     * @throws NullPointerException if {@code tol} is null
     */
    public accKMeans(int k, int maxIterations, Double tol, List<Cluster> clusters) {
        this(k, maxIterations, Objects.requireNonNull(tol, "tol"), INIT_KMEANS_PP,
                clusters != null ? clusters : new ArrayList<Cluster>());
    }

    /** Single place where all fields are assigned; the public constructors delegate here. */
    private accKMeans(int k, int maxIterations, double tol, String contensType,
            List<Cluster> clusters) {
        this.k = k;
        this.maxIterations = maxIterations;
        this.tol = tol;
        this.contensType = contensType;
        this.clusters = clusters;
    }

    /** Creates the initial clusters unless some already exist (preset or from an earlier fit). */
    private void initClusters() {
        if (!this.clusters.isEmpty()) {
            return;
        }
        if (this.k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + this.k);
        }
        if (INIT_KMEANS_PP.equals(this.contensType)) {
            minMaxCluters();
        } else if (INIT_RANDOM.equals(this.contensType)) {
            randomClusters();
        } else {
            // Previously an unknown type left the cluster list empty and failed later with an NPE.
            throw new IllegalArgumentException(
                    "Unknown initialization type: " + this.contensType);
        }
    }

    /** Builds a cluster with the given id whose centroid is the given point. */
    private static Cluster newCluster(int id, Point centroid) {
        Cluster cluster = new Cluster(id);
        cluster.setCentroid(centroid);
        return cluster;
    }

    /**
     * Picks k random data points as centroids. Points are drawn without repetition so that two
     * clusters do not start on the same point; indices only wrap around when k exceeds the
     * number of points.
     */
    private void randomClusters() {
        int n = this.points.size();
        List<Integer> order = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order, this.random);
        for (int id = 0; id < this.k; id++) {
            this.clusters.add(newCluster(id, this.points.get(order.get(id % n))));
        }
    }

    /**
     * Deterministic farthest-point seeding: the first centroid is the point with the smallest
     * coordinate sum; every further centroid is the point whose distance to its nearest
     * already-chosen centroid is largest.
     */
    private void minMaxCluters() {
        int n = this.points.size();

        int firstIndex = 0;
        double minSum = this.points.get(0).sum();
        for (int i = 1; i < n; i++) {
            double candidate = this.points.get(i).sum();
            if (candidate < minSum) {
                minSum = candidate;
                firstIndex = i;
            }
        }
        Point latest = this.points.get(firstIndex);
        this.clusters.add(newCluster(0, latest));

        // nearest[i] = distance from point i to the closest centroid chosen so far.
        double[] nearest = new double[n];
        for (int i = 0; i < n; i++) {
            nearest[i] = Double.MAX_VALUE;
        }

        for (int id = 1; id < this.k; id++) {
            int farthest = -1;
            double farthestDistance = 0.0;
            for (int i = 0; i < n; i++) {
                // Only the newest centroid can lower the nearest-centroid distance.
                double distance = calculateDistance(this.points.get(i), latest);
                if (distance < nearest[i]) {
                    nearest[i] = distance;
                }
                if (nearest[i] > farthestDistance) {
                    farthestDistance = nearest[i];
                    farthest = i;
                }
            }
            if (farthest < 0) {
                // Every point coincides with a centroid already; any point is as good as another.
                farthest = this.random.nextInt(n);
            }
            latest = this.points.get(farthest);
            this.clusters.add(newCluster(id, latest));
        }
    }

    /** Euclidean distance between two points, iterating over the dimensions of {@code p1}. */
    private double calculateDistance(Point p1, Point p2) {
        double sum = 0.0;
        for (int i = 0; i < p1.getLength(); i++) {
            double diff = p1.getValue(i) - p2.getValue(i);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Dissimilarity in [0, 1]: one minus the fraction of dimensions in which both points hold the
     * same strictly positive value. Values are compared with exact {@code ==}.
     * Not called anywhere in this class at the moment.
     */
    private double accDist(Point p1, Point p2) {
        double sum = 0.0;
        int n = p1.getLength();
        for (int i = 0; i < n; i++) {
            if (p1.getValue(i) == p2.getValue(i) && p1.getValue(i) > 0.0) {
                sum += 1;
            }
        }
        return 1 - sum / n;
    }

    /** Empties the point list of every cluster; centroids stay as they are. */
    private void clearnPoints() {
        for (Cluster cluster : this.clusters) {
            cluster.clearPoints();
        }
    }

    /**
     * Assigns every point to the cluster with the nearest centroid (the first one wins ties) and
     * rebuilds {@link #acc_dist} and {@code labels}.
     */
    private void assignPoints() {
        this.acc_dist = 0.0;
        this.labels = new ArrayList<>(this.points.size());
        for (Point point : this.points) {
            double minDistance = Double.MAX_VALUE;
            Cluster closestCluster = null;
            for (Cluster cluster : this.clusters) {
                double distance = calculateDistance(point, cluster.getCentroid());
                if (distance < minDistance) {
                    minDistance = distance;
                    closestCluster = cluster;
                }
            }
            if (closestCluster == null) {
                // Happens when every distance is NaN or not below Double.MAX_VALUE.
                throw new IllegalStateException(
                        "No centroid is at a comparable distance from point " + point);
            }
            // Add to the cluster object itself: its id need not equal its position in the list.
            closestCluster.addPoint(point);
            this.acc_dist += minDistance;
            this.labels.add(closestCluster.getId());
        }
    }

    /**
     * Moves each non-empty cluster's centroid to the mean of its points. Empty clusters keep
     * their centroid.
     *
     * @return the sum over all moved centroids of the distance between old and new position
     */
    private double updateCentroids() {
        double totalShift = 0.0;
        for (Cluster cluster : this.clusters) {
            List<Point> members = cluster.getPoints();
            if (members.isEmpty()) {
                continue;
            }
            Point oldCentroid = cluster.getCentroid();
            double[] mean = new double[oldCentroid.getLength()];
            for (Point point : members) {
                for (int i = 0; i < point.getLength(); i++) {
                    mean[i] += point.getValue(i);
                }
            }
            for (int i = 0; i < mean.length; i++) {
                mean[i] /= members.size();
            }
            Point newCentroid = new Point(mean);
            totalShift += calculateDistance(oldCentroid, newCentroid);
            cluster.setCentroid(newCentroid);
        }
        return totalShift;
    }

    /**
     * Wraps every coordinate array in a {@link Point}.
     *
     * @param data coordinate arrays
     * @return a new list with one point per array, in the same order
     */
    public List<Point> dataTransform(List<double[]> data) {
        List<Point> converted = new ArrayList<>(data.size());
        for (double[] item : data) {
            converted.add(new Point(item));
        }
        return converted;
    }

    /**
     * Checks that the data is non-empty and that all vectors share one dimension.
     *
     * @return the common dimension
     */
    private static int checkData(List<double[]> data) {
        if (data == null || data.isEmpty()) {
            throw new IllegalArgumentException("points must contain at least one vector");
        }
        int dimension = -1;
        for (int i = 0; i < data.size(); i++) {
            double[] vector = data.get(i);
            if (vector == null) {
                throw new IllegalArgumentException("points[" + i + "] is null");
            }
            if (dimension < 0) {
                dimension = vector.length;
            } else if (vector.length != dimension) {
                throw new IllegalArgumentException("points[" + i + "] has " + vector.length
                        + " dimensions, expected " + dimension);
            }
        }
        return dimension;
    }

    /** Checks that every cluster has a centroid with the same dimension as the data. */
    private void checkCentroids(int dimension) {
        for (Cluster cluster : this.clusters) {
            Point centroid = cluster.getCentroid();
            if (centroid == null) {
                throw new IllegalStateException(
                        "Cluster " + cluster.getId() + " has no centroid");
            }
            if (centroid.getLength() != dimension) {
                throw new IllegalArgumentException("Centroid of cluster " + cluster.getId()
                        + " has " + centroid.getLength() + " dimensions, expected " + dimension);
            }
        }
    }

    /**
     * Runs the clustering. Clusters that already exist (preset through the constructor or left
     * over from an earlier call) are reused as the starting point; otherwise they are initialized
     * first. The initial centroids are printed to standard output. Iteration stops after
     * {@code maxIterations} rounds or as soon as the summed centroid movement of a round is at
     * most {@code tol}.
     *
     * @param points vectors to cluster; must be non-empty and all of the same length
     * @throws IllegalArgumentException if the data is null, empty, contains a null vector or
     *                                  vectors of different lengths, if a centroid's dimension
     *                                  differs from the data, or if the initialization type or k
     *                                  is invalid
     * @throws IllegalStateException    if a cluster has no centroid, or no centroid is at a
     *                                  comparable distance from some point (e.g. NaN coordinates)
     */
    public void fit(List<double[]> points) {
        int dimension = checkData(points);
        this.points = dataTransform(points);
        initClusters();
        checkCentroids(dimension);
        for (Cluster cluster : this.clusters) {
            System.out.println("Cluster " + cluster.getId() + ":");
            System.out.println("Centroid: " + cluster.getCentroid());
            System.out.println("size: " + cluster.getPoints().size());
            System.out.println("Points: " + cluster.getPoints());
            System.out.println("---------初始化质心-------------");
        }
        for (int i = 0; i < this.maxIterations; i++) {
            clearnPoints();
            assignPoints();
            double shift = updateCentroids();
            if (shift <= this.tol) {
                return;
            }
        }
    }

    /** Prints every cluster, the accumulated distance and the labels to standard output. */
    public void printResults() {
        for (Cluster cluster : this.clusters) {
            System.out.println("Cluster " + cluster.getId() + ":");
            System.out.println("Centroid: " + cluster.getCentroid());
            System.out.println("size: " + cluster.getPoints().size());
            System.out.println("Points: " + cluster.getPoints());
            System.out.println("----------------------");
        }
        System.out.println("accDist: " + this.acc_dist);
        System.out.println("labels: " + this.labels);
    }

    /** Small demo on seven 2-D points. */
    public static void main(String[] args) {
        // Sample data.
        List<double[]> points = new ArrayList<>();
        points.add(new double[]{1.0, 1.0});
        points.add(new double[]{1.5, 2.0});
        points.add(new double[]{3.0, 4.0});
        points.add(new double[]{5.0, 7.0});
        points.add(new double[]{3.5, 5.0});
        points.add(new double[]{4.5, 5.0});
        points.add(new double[]{6.5, 5.5});
        System.out.println(new Point(points.get(6)).sum());

        // To start from hand-picked centers instead, build a List<Cluster>, call
        // setCentroid(double[]) on each cluster and use the
        // (k, maxIterations, tol, clusters) constructor.
        accKMeans kmeans = new accKMeans(4, 300, 0.001, "k-means++");
        kmeans.fit(points);
        kmeans.printResults();
    }
}
自动化学习。

浙公网安备 33010602011771号