Java Apache Commons Math 无监督谱聚类算法
package org.example; import org.apache.commons.math3.linear.*; import org.apache.commons.math3.ml.clustering.*; import org.apache.commons.math3.ml.distance.*; import org.apache.commons.math3.random.JDKRandomGenerator; import java.util.ArrayList; import java.util.List; import static org.apache.commons.math3.util.FastMath.*; public class spetCluster { private int k; // 聚类数量 private double sigma; // 高斯核参数 public spetCluster(int k, double sigma) { this.k = k; this.sigma = sigma; } // 计算相似度矩阵(高斯核) private RealMatrix computeSimilarityMatrix(double[][] data) { int n = data.length; RealMatrix similarity = MatrixUtils.createRealMatrix(n, n); EuclideanDistance distance = new EuclideanDistance(); for (int i = 0; i < n; i++) { for (int j = i; j < n; j++) { double dist = distance.compute(data[i], data[j]); double sim = exp(-dist*dist / (2*sigma*sigma)); similarity.setEntry(i, j, sim); similarity.setEntry(j, i, sim); } } return similarity; } // 计算归一化拉普拉斯矩阵 private RealMatrix computeLaplacian(RealMatrix similarity) { int n = similarity.getRowDimension(); RealMatrix degree = MatrixUtils.createRealMatrix(n, n); // 计算度矩阵 for (int i = 0; i < n; i++) { double sum = 0; for (int j = 0; j < n; j++) { sum += similarity.getEntry(i, j); } degree.setEntry(i, i, sum); } // 计算归一化拉普拉斯矩阵 L = I - D^(-1/2) W D^(-1/2) RealMatrix sqrtDegree = MatrixUtils.createRealMatrix(n, n); for (int i = 0; i < n; i++) { sqrtDegree.setEntry(i, i, 1.0/sqrt(degree.getEntry(i, i))); } return MatrixUtils.createRealIdentityMatrix(n) .subtract(sqrtDegree.multiply(similarity).multiply(sqrtDegree)); } // 执行谱聚类 public List<CentroidCluster<DoublePoint>> cluster(double[][] data) { // 1. 计算相似度矩阵 (data.length,data.length) RealMatrix similarity = computeSimilarityMatrix(data); //System.out.println(similarity.toString()); // 2. 计算拉普拉斯矩阵 (data.length,data.length) RealMatrix laplacian = computeLaplacian(similarity); //System.out.println(laplacian.toString()); // 3. 
计算前k个特征向量 EigenDecomposition eigen = new EigenDecomposition(laplacian); RealMatrix eigenvectors = MatrixUtils.createRealMatrix(data.length, k); for (int i = 0; i < k; i++) { double[] v = eigen.getEigenvector(i).toArray(); for (int j = 0; j < v.length; j++) { eigenvectors.setEntry(j, i, abs(v[j])); } } // 4. 对特征向量行向量进行K-means聚类 List<DoublePoint> points = new ArrayList<>(); for (int i = 0; i < data.length; i++) { double[] vec = eigenvectors.getRow(i); points.add(new DoublePoint(vec)); } /*EuclideanDistance(欧氏距离)3 ManhattanDistance(曼哈顿距离)3 ChebyshevDistance(切比雪夫距离)3 */ KMeansPlusPlusClusterer<DoublePoint> kmeans = new KMeansPlusPlusClusterer<>(k, 1000,new ManhattanDistance(),new JDKRandomGenerator()// 随机数生成器 ); points.forEach(item -> System.out.println(item)); List<CentroidCluster<DoublePoint>> kmmodel = kmeans.cluster(points); return kmmodel; } }
package org.example; import org.apache.commons.math3.ml.clustering.CentroidCluster; import org.apache.commons.math3.ml.clustering.Clusterable; import org.apache.commons.math3.ml.clustering.DoublePoint; import java.util.List; import java.util.Random; public class SpectralClusteringExample { public static void main(String[] args) { // 创建测试数据(二维环形数据) // double[][] data = new double[1000][120]; // Random rand = new Random(); // for (int i = 0; i < 100; i++) { // double angle = 2 * Math.PI * i / 100; // data[i][0] = Math.cos(angle) + rand.nextGaussian() * 0.05; // data[i][1] = Math.sin(angle) + rand.nextGaussian() * 0.05; // } // for (int i = 100; i < 200; i++) { // double angle = 2 * Math.PI * (i-100) / 100; // data[i][0] = 2 * Math.cos(angle) + rand.nextGaussian() * 0.05; // data[i][1] = 2 * Math.sin(angle) + rand.nextGaussian() * 0.05; // } double[][] data = { {1.1, 1.2, 1.3}, {4.4, 4.5, 4.6}, {7.7, 7.8, 7.9}, {1.0, 1.0, 1.0}, {4.41, 4.55, 4.68}, {7.701, 7.801, 7.99} }; // 执行谱聚类 spetCluster sc = new spetCluster(3, 1); List<CentroidCluster<DoublePoint>> clusters = sc.cluster(data); // 输出聚类结果 for (int i = 0; i < clusters.size(); i++) { System.out.println("Cluster " + i + ": " + clusters.get(i).getPoints().size() + " points"); clusters.get(i).getPoints().forEach(item->System.out.println(item)); System.out.println("数据中心点:"+clusters.get(i).getCenter().toString()); } } }
自动化学习。