Java 实现 K-Means 无监督聚类算法,并使用 k-means++ 初始化聚类中心

k-means++ 初始化步骤

1. 种子随机:第一个质心从所有数据点中均匀随机选取。

2. 远近为纲:后续每个新质心的选择,都依赖于各数据点到当前所有已选质心的距离。

3. 距离平方加权:记 $D(x)$ 为点 $x$ 到最近已选质心的欧氏距离,使用平方距离 $D(x)^2$ 度量“远近”并据此计算概率。

4. 概率性偏袒远方:每个点被选中的概率与其到最近已选质心的距离平方成正比,即 $P(x) = \frac{D(x)^2}{\sum_{x'} D(x')^2}$。距离越远,概率越高;已被选为质心的点满足 $D(x)=0$,因此不会被重复选中。

5. 动态更新:每选出一个新质心,立即重新计算所有点到最近质心的距离及其被选概率。

6. 迭代选点:重复“计算距离—按概率选点”的过程,直到选够 K 个质心。

package org.core;
import java.util.*;

/**
 * A dense point in n-dimensional space backed by a {@code double[]}.
 *
 * <p>The array handed to the constructor is kept by reference, not copied, so writes made
 * through {@link #setValue(int, double)} are visible to the array's owner and vice versa.
 */
class Point {
    // Coordinate storage, shared with whoever created this point.
    private double[] coords;

    /** Wraps {@code coordinates} (by reference) as a point. */
    public Point(double[] coordinates) {
        coords = coordinates;
    }

    /** Returns the number of dimensions of this point. */
    public int getLength() {
        final int dims = coords.length;
        return dims;
    }

    /** Returns the coordinate stored at {@code index}. */
    public double getValue(int index) {
        final double value = coords[index];
        return value;
    }

    /** Overwrites the coordinate at {@code index} with {@code val}. */
    public void setValue(int index, double val) {
        coords[index] = val;
    }

    /** Returns the sum of all coordinates (stream-based, compensated summation). */
    public double sum() {
        final double total = Arrays.stream(coords).sum();
        return total;
    }

    /** Formats the point as {@code [c0, c1, ...]}. */
    @Override
    public String toString() {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (double c : coords) {
            joiner.add(String.valueOf(c));
        }
        return joiner.toString();
    }
}

/**
 * One cluster: an integer id (used as the label), its centroid, and the points currently
 * assigned to it.
 */
class Cluster {
    private int id;            // cluster label
    private Point centroid;    // current centre; null until a setter is called
    private List<Point> points; // members from the latest assignment step

    /** Creates an empty cluster with the given id and no centroid. */
    public Cluster(int id) {
        this.points = new ArrayList<>();
        this.id = id;
    }

    /** Appends an existing point to this cluster's members. */
    public void addPoint(Point point) {
        this.points.add(point);
    }

    /** Wraps raw coordinates (by reference, not copied) and appends them as a member. */
    public void addPoint(double[] point) {
        Point wrapped = new Point(point);
        this.points.add(wrapped);
    }

    /** Removes every member point; the centroid is left unchanged. */
    public void clearPoints() {
        this.points.clear();
    }

    /** Returns the cluster id. */
    public int getId() {
        return this.id;
    }

    /** Returns the current centroid, or null if none has been set. */
    public Point getCentroid() {
        return this.centroid;
    }

    /** Returns the live (mutable) list of member points. */
    public List<Point> getPoints() {
        return this.points;
    }

    /** Replaces the centroid with the given point. */
    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    /** Replaces the centroid with a point wrapping {@code centroid} (by reference). */
    public void setCentroid(double[] centroid) {
        Point wrapped = new Point(centroid);
        this.centroid = wrapped;
    }
}



/**
 * K-Means clustering (Lloyd's algorithm) over dense {@code double[]} vectors.
 *
 * <p>Initial centroids come either from the caller (the {@code List<Cluster>} constructor) or
 * from one of the seeding strategies selected by {@code contensType}:
 * <ul>
 *   <li>{@code "k-means++"} (default): the first centroid is a uniformly random data point;
 *       every further centroid is sampled with probability proportional to its squared
 *       distance to the nearest already-chosen centroid;</li>
 *   <li>{@code "random"}: {@code k} distinct data points chosen uniformly at random;</li>
 *   <li>{@code "min-max"}: deterministic farthest-point seeding starting from the point with
 *       the smallest coordinate sum (the behaviour "k-means++" used to have here).</li>
 * </ul>
 * Use {@link #setSeed(long)} for reproducible seeding.
 */
public class accKMeans {
    private int k; // number of clusters produced by automatic seeding
    private List<Point> points; // data of the most recent fit
    private List<Cluster> clusters; // current clusters; each cluster's id is used as its label
    private int maxIterations; // upper bound on Lloyd iterations
    private String contensType = "k-means++"; // seeding strategy: "random", "k-means++" or "min-max"
    private double tol = 0.0; // stop once the summed centroid shift of one iteration is <= tol
    private boolean presetClusters; // true when the caller supplied non-empty initial clusters
    private Random random = new Random(); // randomness used by the "random" / "k-means++" seeding
    /** Sum of Euclidean (not squared) distances from each point to its assigned centroid. */
    public double acc_dist;
    /** Cluster id of every input point, in input order; set by {@link #fit}. */
    List<Integer> labels;

    /** Creates a k-means++-seeded model that iterates until the centroids stop moving. */
    public accKMeans(int k, int maxIterations) {
        this(k, maxIterations, 0.0);
    }

    /**
     * Creates a k-means++-seeded model.
     *
     * @param tol convergence threshold on the summed centroid shift; {@code null} means 0
     */
    public accKMeans(int k, int maxIterations, Double tol) {
        this(k, maxIterations, tol, "k-means++");
    }

    /**
     * Creates a model with an explicit seeding strategy.
     *
     * @param tol         convergence threshold on the summed centroid shift; {@code null} means 0
     * @param contensType {@code "k-means++"}, {@code "random"} or {@code "min-max"}; any other
     *                    value makes {@link #fit} throw {@link IllegalArgumentException}
     */
    public accKMeans(int k, int maxIterations, Double tol, String contensType) {
        this.k = k;
        this.maxIterations = maxIterations;
        this.tol = (tol == null) ? 0.0 : tol;
        this.contensType = contensType;
        this.clusters = new ArrayList<>();
    }

    /**
     * Creates a model that starts from caller-supplied clusters (each must have a centroid).
     * The list is used directly, so the caller's list and {@link Cluster} objects observe the
     * results. An empty list falls back to k-means++ seeding.
     */
    public accKMeans(int k, int maxIterations, Double tol, List<Cluster> clusters) {
        this(k, maxIterations, tol);
        this.clusters = Objects.requireNonNull(clusters, "clusters");
        this.presetClusters = !clusters.isEmpty();
    }

    /** Makes the random seeding reproducible. */
    public void setSeed(long seed) {
        this.random = new Random(seed);
    }

    // Creates the initial clusters (or validates the preset ones) for data of dimension dim.
    private void initClusters(int dim) {
        if (presetClusters) {
            for (Cluster cluster : this.clusters) {
                Point c = cluster.getCentroid();
                if (c == null || c.getLength() != dim) {
                    throw new IllegalStateException("Preset cluster " + cluster.getId()
                            + " needs a centroid with " + dim + " dimensions");
                }
            }
            return;
        }
        int n = this.points.size();
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("k must be in [1, " + n + "], got " + k);
        }
        if ("k-means++".equals(contensType)) {
            kMeansPlusPlusClusters();
        } else if ("random".equals(contensType)) {
            randomClusters();
        } else if ("min-max".equals(contensType)) {
            minMaxClusters();
        } else {
            throw new IllegalArgumentException("Unknown initialisation type: " + contensType
                    + " (expected \"k-means++\", \"random\" or \"min-max\")");
        }
    }

    // Seeds k centroids from k distinct data points chosen uniformly at random.
    private void randomClusters() {
        int n = this.points.size();
        List<Integer> order = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order, random);
        for (int i = 0; i < this.k; i++) {
            addCentroid(this.points.get(order.get(i)));
        }
    }

    // k-means++ seeding (Arthur & Vassilvitskii): random first centroid, then D^2 sampling.
    private void kMeansPlusPlusClusters() {
        int n = this.points.size();
        // Step 1: the first centroid is a uniformly random data point.
        Point first = this.points.get(random.nextInt(n));
        addCentroid(first);

        // minSq[i] = squared distance of point i to its nearest chosen centroid, kept up to
        // date incrementally so each new centroid costs O(n) instead of O(n * #centroids).
        double[] minSq = new double[n];
        Arrays.fill(minSq, Double.POSITIVE_INFINITY);
        updateMinSquaredDistances(minSq, first);

        // Steps 2-6: sample proportionally to D(x)^2 until k centroids exist.
        while (this.clusters.size() < this.k) {
            Point next = this.points.get(sampleProportional(minSq));
            addCentroid(next);
            updateMinSquaredDistances(minSq, next);
        }
    }

    // Deterministic farthest-point seeding: start at the point with the smallest coordinate
    // sum, then repeatedly add the point farthest from its nearest chosen centroid.
    private void minMaxClusters() {
        int n = this.points.size();
        int first = 0;
        double minSum = this.points.get(0).sum();
        for (int i = 1; i < n; i++) {
            double s = this.points.get(i).sum();
            if (s < minSum) {
                minSum = s;
                first = i;
            }
        }
        addCentroid(this.points.get(first));

        double[] minSq = new double[n];
        Arrays.fill(minSq, Double.POSITIVE_INFINITY);
        updateMinSquaredDistances(minSq, this.points.get(first));

        while (this.clusters.size() < this.k) {
            int farthest = -1;
            double best = 0.0;
            for (int i = 0; i < n; i++) {
                if (minSq[i] > best) {
                    best = minSq[i];
                    farthest = i;
                }
            }
            if (farthest < 0) {
                // Every point coincides with a chosen centroid (fewer distinct points than k).
                farthest = random.nextInt(n);
            }
            addCentroid(this.points.get(farthest));
            updateMinSquaredDistances(minSq, this.points.get(farthest));
        }
    }

    // Appends a new cluster whose id equals its index in the cluster list.
    private void addCentroid(Point centroid) {
        Cluster cluster = new Cluster(this.clusters.size());
        cluster.setCentroid(centroid);
        this.clusters.add(cluster);
    }

    // Lowers minSq[i] to the squared distance between point i and centroid where smaller.
    private void updateMinSquaredDistances(double[] minSq, Point centroid) {
        for (int i = 0; i < minSq.length; i++) {
            double d = squaredDistance(this.points.get(i), centroid);
            if (d < minSq[i]) {
                minSq[i] = d;
            }
        }
    }

    // Returns index i with probability weights[i] / sum(weights); uniform if all weights are 0.
    private int sampleProportional(double[] weights) {
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }
        if (!(total > 0.0)) {
            // All points coincide with existing centroids: no distance information left.
            return random.nextInt(weights.length);
        }
        double target = random.nextDouble() * total;
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] <= 0.0) {
                continue; // already a centroid (or duplicate of one): never re-pick
            }
            lastPositive = i;
            cumulative += weights[i];
            if (target < cumulative) {
                return i;
            }
        }
        return lastPositive; // guards against floating-point round-off in the running sum
    }

    // Squared Euclidean distance between two points of equal dimension.
    private double squaredDistance(Point p1, Point p2) {
        double sum = 0.0;
        for (int i = 0; i < p1.getLength(); i++) {
            double diff = p1.getValue(i) - p2.getValue(i);
            sum += diff * diff;
        }
        return sum;
    }

    // Euclidean distance between two points of equal dimension.
    private double calculateDistance(Point p1, Point p2) {
        return Math.sqrt(squaredDistance(p1, p2));
    }

    // Dissimilarity: 1 - (fraction of positions where both coordinates are equal and > 0).
    private double accDist(Point p1, Point p2) {
        double sum = 0.0;
        int n = p1.getLength();
        for (int i = 0; i < n; i++) {
            if (p1.getValue(i) == p2.getValue(i) && p1.getValue(i) > 0.0) {
                sum += 1;
            }
        }
        return 1 - sum / n;
    }

    // Empties every cluster's member list (centroids are kept).
    private void clearPoints() {
        for (Cluster cluster : this.clusters) {
            cluster.clearPoints();
        }
    }

    // Assigns every point to its nearest centroid; rebuilds labels and acc_dist.
    private void assignPoints() {
        this.acc_dist = 0.0;
        this.labels = new ArrayList<>(this.points.size());
        for (Point point : this.points) {
            double minDistance = Double.MAX_VALUE;
            Cluster closestCluster = null;
            for (Cluster cluster : this.clusters) {
                double distance = calculateDistance(point, cluster.getCentroid());
                if (distance < minDistance) {
                    minDistance = distance;
                    closestCluster = cluster;
                }
            }
            if (closestCluster == null) {
                throw new IllegalStateException("No usable centroid for point " + point);
            }
            // Add to the cluster object itself: ids of preset clusters need not equal indices.
            closestCluster.addPoint(point);
            this.acc_dist += minDistance;
            this.labels.add(closestCluster.getId());
        }
    }

    // Moves each non-empty cluster's centroid to the mean of its members; empty clusters keep
    // their centroid. Returns the summed Euclidean shift of all centroids.
    private double updateCentroids() {
        double shift = 0.0;
        for (Cluster cluster : this.clusters) {
            if (cluster.getPoints().isEmpty()) {
                continue;
            }
            Point oldCentroid = cluster.getCentroid();
            double[] mean = new double[oldCentroid.getLength()];
            for (Point point : cluster.getPoints()) {
                for (int i = 0; i < point.getLength(); i++) {
                    mean[i] += point.getValue(i);
                }
            }
            int size = cluster.getPoints().size();
            for (int i = 0; i < mean.length; i++) {
                mean[i] /= size;
            }
            Point newCentroid = new Point(mean);
            shift += calculateDistance(oldCentroid, newCentroid);
            cluster.setCentroid(newCentroid);
        }
        return shift;
    }

    /** Wraps each raw vector (by reference, not copied) in a {@link Point}. */
    public List<Point> dataTransform(List<double[]> data) {
        List<Point> rslt = new ArrayList<>(data.size());
        for (double[] item : data) {
            rslt.add(new Point(item));
        }
        return rslt;
    }

    // Returns the common dimension of all rows, rejecting empty, null or ragged input.
    private static int requireUniformDimension(List<double[]> data) {
        if (data == null || data.isEmpty()) {
            throw new IllegalArgumentException("points must contain at least one vector");
        }
        int dim = -1;
        for (int i = 0; i < data.size(); i++) {
            double[] row = data.get(i);
            if (row == null) {
                throw new IllegalArgumentException("points[" + i + "] is null");
            }
            if (dim < 0) {
                dim = row.length;
            } else if (row.length != dim) {
                throw new IllegalArgumentException("points[" + i + "] has " + row.length
                        + " dimensions, expected " + dim);
            }
        }
        return dim;
    }

    /**
     * Runs K-Means on {@code points}.
     *
     * <p>Automatically seeded models are re-seeded on every call; preset clusters start from
     * their current centroids. After the call, {@link #labels}, each cluster's points and
     * {@link #acc_dist} are consistent with the final centroids.
     *
     * @throws IllegalArgumentException if {@code points} is empty, ragged or contains null,
     *         if {@code k} is outside [1, points.size()], or if the seeding type is unknown
     * @throws IllegalStateException if a preset cluster lacks a centroid of matching dimension
     */
    public void fit(List<double[]> points) {
        int dim = requireUniformDimension(points);
        this.points = dataTransform(points);

        if (!presetClusters) {
            this.clusters.clear();
        }
        initClusters(dim);
        for (Cluster cluster : this.clusters) {
            System.out.println("Cluster " + cluster.getId() + ":");
            System.out.println("Centroid: " + cluster.getCentroid());
            System.out.println("size: " + cluster.getPoints().size());
            System.out.println("Points: " + cluster.getPoints());
            System.out.println("---------初始化质心-------------");
        }

        for (int i = 0; i < maxIterations; i++) {
            clearPoints();   // drop the previous assignment
            assignPoints();  // assign each point to its nearest centroid
            double shift = updateCentroids(); // move centroids to the cluster means
            if (shift <= this.tol) {
                break;
            }
        }

        // The loop assigns before it moves the centroids; re-assign once so the reported
        // labels, members and acc_dist describe the centroids that are actually returned.
        clearPoints();
        assignPoints();
    }

    /** Prints every cluster, the accumulated distance and the labels to standard output. */
    public void printResults() {
        for (Cluster cluster : this.clusters) {
            System.out.println("Cluster " + cluster.getId() + ":");
            System.out.println("Centroid: " + cluster.getCentroid());
            System.out.println("size: " + cluster.getPoints().size());
            System.out.println("Points: " + cluster.getPoints());
            System.out.println("----------------------");
        }
        System.out.println("accDist: " + this.acc_dist);
        System.out.println("labels: " + this.labels);
    }

    /** Small demo on seven 2-D points. */
    public static void main(String[] args) {
        List<double[]> points = new ArrayList<>();
        points.add(new double[]{1.0, 1.0});
        points.add(new double[]{1.5, 2.0});
        points.add(new double[]{3.0, 4.0});
        points.add(new double[]{5.0, 7.0});
        points.add(new double[]{3.5, 5.0});
        points.add(new double[]{4.5, 5.0});
        points.add(new double[]{6.5, 5.5});

        int k = 4;
        // Seeding types: "random", "k-means++", "min-max".
        accKMeans kmeans = new accKMeans(k, 300, 0.001, "k-means++");
        kmeans.setSeed(42L); // reproducible demo output
        kmeans.fit(points);
        kmeans.printResults();
    }
}
posted @ 2025-08-18 22:43  ARYOUOK  阅读(9)  评论(0)    收藏  举报