[Java]数据分析--聚类

距离度量

  • 需求:计算两点间的欧几里得距离、曼哈顿距离、切比雪夫距离、堪培拉距离
  • 实现:利用commons.math3库相应函数
 1 import org.apache.commons.math3.ml.distance.*;
 2 
 3 public class TestMetrics {
 4     public static void main(String[] args) {
 5         double[] x = {1, 3}, y = {5, 6};
 6         
 7         EuclideanDistance eD = new EuclideanDistance();
 8         System.out.printf("Euclidean distance = %.2f%n", eD.compute(x,y));
 9         
10         ManhattanDistance mD = new ManhattanDistance();
11         System.out.printf("Manhattan distance = %.2f%n", mD.compute(x,y));
12         
13         ChebyshevDistance cD = new ChebyshevDistance();
14         System.out.printf("Chebyshev distance = %.2f%n", cD.compute(x,y));
15         
16         CanberraDistance caD = new CanberraDistance();
17         System.out.printf("Canberra distance =  %.2f%n", caD.compute(x,y));
18     }
19 }
View Code

Euclidean distance = 5.00
Manhattan distance = 7.00
Chebyshev distance = 4.00
Canberra distance = 1.00

层次聚类

  • 需求:将13个样本点分为3类
  • 实现:将m个点划分为k类。先令每个点各自成为一类,然后找出质心距离最近的两个类,将它们合并为一个新类;重复m-k次(每次合并减少一类)。

HierarchicalClustering.java

 1 import java.util.HashSet;
 2 
 3 public class HierarchicalClustering {
 4     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
 5         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
 6     private static final int M = DATA.length;  // number of points
 7     private static final int K = 3;            // number of clusters
 8 
 9     public static void main(String[] args) {
10         HashSet<Cluster> clusters = load(DATA);
11         for (int i = 0; i < M - K; i++) {
12             System.out.printf("%n%2d clusters:%n", M-i-1);
13             coalesce(clusters);
14             System.out.println(clusters);
15         }
16     }
17     
18     private static HashSet<Cluster> load(double[][] data) {
19         HashSet<Cluster> clusters = new HashSet();
20         for (double[] datum : DATA) {
21             clusters.add(new Cluster(datum[0], datum[1]));
22         }
23         return clusters;
24     } 
25     
26     private static void coalesce(HashSet<Cluster> clusters) {
27         Cluster cluster1=null, cluster2=null;
28         double minDist = Double.POSITIVE_INFINITY;
29         for (Cluster c1 : clusters) {
30             for (Cluster c2 : clusters) {
31                 if (!c1.equals(c2) && Cluster.distance(c1, c2) < minDist) {
32                     cluster1 = c1;
33                     cluster2 = c2;
34                     minDist = Cluster.distance(c1, c2);
35                 }
36             }
37         }
38         clusters.remove(cluster1);
39         clusters.remove(cluster2);
40         clusters.add(Cluster.union(cluster1, cluster2));
41     }
42 }
View Code

Point.java

 /**
 * An immutable point in the plane.
 *
 * <p>Equality compares the exact bit patterns of the coordinates (via
 * {@link Double#doubleToLongBits}). As a result {@code NaN} equals {@code NaN} and
 * {@code 0.0} does not equal {@code -0.0}. {@link #hashCode} is consistent with this.
 */
public class Point {
    private final double x, y;

    /** Creates the point (x, y). */
    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    /** Returns the x coordinate. */
    public double getX() {
        return x;
    }

    /** Returns the y coordinate. */
    public double getY() {
        return y;
    }

    @Override
    public int hashCode() {
        // Double.hashCode(d) yields the same value as Double.valueOf(d).hashCode()
        // without boxing, and avoids the deprecated new Double(d) constructor.
        return Double.hashCode(x) + 79 * Double.hashCode(y);
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (!(object instanceof Point)) {
            return false;  // also covers null
        }
        Point that = (Point) object;
        return bits(that.x) == bits(this.x) && bits(that.y) == bits(this.y);
    }

    /** Returns the canonical bit pattern of d; all NaN values map to the same pattern. */
    private static long bits(double d) {
        return Double.doubleToLongBits(d);
    }

    /** Formats the point as "(x,y)" with two decimal places. */
    @Override
    public String toString() {
        return String.format("(%.2f,%.2f)", x, y);
    }
}
View Code

Cluster.java

 1 import java.util.HashSet;
 2 
 3 public class Cluster {
 4     private final HashSet<Point> points;
 5     private Point centroid;
 6 
 7     public Cluster(HashSet points, Point centroid) {
 8         this.points = points;
 9         this.centroid = centroid;
10     }
11     
12     public Cluster(Point point) {
13         this.points = new HashSet();
14         this.points.add(point);
15         this.centroid = point;
16     }
17 
18     public Cluster(double x, double y) {
19         this(new Point(x,y));
20     }
21 
22     public Point getCentroid() {
23         return centroid;
24     }
25 
26     public void add(Point point) {
27         points.add(point);
28         recomputeCentroid();
29     }
30 
31     public void recomputeCentroid() {
32         double xSum=0.0, ySum=0.0;
33         for (Point point : points) {
34             xSum += point.getX();
35             ySum += point.getY();
36         }
37         centroid = new Point(xSum/points.size(), ySum/points.size());
38     }
39     
40     public static double distance(Cluster c1, Cluster c2) {
41         double dx = c1.centroid.getX() - c2.centroid.getX();
42         double dy = c1.centroid.getY() - c2.centroid.getY();
43         return Math.sqrt(dx*dx + dy*dy);
44     }
45     
46     public static Cluster union(Cluster c1, Cluster c2) {
47         Cluster cluster = new Cluster(c1.points, c1.centroid);
48         cluster.points.addAll(c2.points);
49         cluster.recomputeCentroid();
50         return cluster;
51     }
52 
53     @Override
54     public int hashCode() {
55         return points.hashCode();
56     }
57 
58     @Override
59     public boolean equals(Object object) {
60         if (object == null) {
61             return false;
62         } else if (object == this) {
63             return true;
64         } else if (!(object instanceof Cluster)) {
65             return false;
66         }
67         final Cluster that = (Cluster)object;
68         return that.points.equals(this.points);
69     }
70 
71     @Override
72     public String toString() {
73         return String.format("%n{%s,%s}", centroid, points);
74     }
75 }
View Code

结果-->

  1 12 clusters:
  2 [
  3 {(1.00,1.00),[(1.00,1.00)]}, 
  4 {(1.00,3.00),[(1.00,3.00)]}, 
  5 {(2.00,6.00),[(2.00,6.00)]}, 
  6 {(3.00,2.00),[(3.00,2.00)]}, 
  7 {(4.00,3.00),[(4.00,3.00)]}, 
  8 {(6.00,4.00),[(6.00,4.00)]}, 
  9 {(7.00,1.00),[(7.00,1.00)]}, 
 10 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 11 {(6.00,3.00),[(6.00,3.00)]}, 
 12 {(3.00,4.00),[(3.00,4.00)]}, 
 13 {(1.00,5.00),[(1.00,5.00)]}, 
 14 {(5.00,6.00),[(5.00,6.00)]}]
 15 
 16 11 clusters:
 17 [
 18 {(1.00,1.00),[(1.00,1.00)]}, 
 19 {(1.00,3.00),[(1.00,3.00)]}, 
 20 {(2.00,6.00),[(2.00,6.00)]}, 
 21 {(3.00,2.00),[(3.00,2.00)]}, 
 22 {(4.00,3.00),[(4.00,3.00)]}, 
 23 {(7.00,1.00),[(7.00,1.00)]}, 
 24 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 25 {(3.00,4.00),[(3.00,4.00)]}, 
 26 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
 27 {(1.00,5.00),[(1.00,5.00)]}, 
 28 {(5.00,6.00),[(5.00,6.00)]}]
 29 
 30 10 clusters:
 31 [
 32 {(1.00,1.00),[(1.00,1.00)]}, 
 33 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 34 {(1.00,3.00),[(1.00,3.00)]}, 
 35 {(3.00,2.00),[(3.00,2.00)]}, 
 36 {(4.00,3.00),[(4.00,3.00)]}, 
 37 {(7.00,1.00),[(7.00,1.00)]}, 
 38 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 39 {(3.00,4.00),[(3.00,4.00)]}, 
 40 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
 41 {(5.00,6.00),[(5.00,6.00)]}]
 42 
 43  9 clusters:
 44 [
 45 {(1.00,1.00),[(1.00,1.00)]}, 
 46 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 47 {(1.00,3.00),[(1.00,3.00)]}, 
 48 {(7.00,1.00),[(7.00,1.00)]}, 
 49 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 50 {(3.50,2.50),[(3.00,2.00), (4.00,3.00)]}, 
 51 {(3.00,4.00),[(3.00,4.00)]}, 
 52 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
 53 {(5.00,6.00),[(5.00,6.00)]}]
 54 
 55  8 clusters:
 56 [
 57 {(1.00,1.00),[(1.00,1.00)]}, 
 58 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 59 {(1.00,3.00),[(1.00,3.00)]}, 
 60 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
 61 {(7.00,1.00),[(7.00,1.00)]}, 
 62 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 63 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
 64 {(5.00,6.00),[(5.00,6.00)]}]
 65 
 66  7 clusters:
 67 [
 68 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 69 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
 70 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
 71 {(7.00,1.00),[(7.00,1.00)]}, 
 72 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
 73 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
 74 {(5.00,6.00),[(5.00,6.00)]}]
 75 
 76  6 clusters:
 77 [
 78 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 79 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
 80 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
 81 {(6.33,5.67),[(7.00,6.00), (7.00,5.00), (5.00,6.00)]}, 
 82 {(7.00,1.00),[(7.00,1.00)]}, 
 83 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}]
 84 
 85  5 clusters:
 86 [
 87 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
 88 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 89 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
 90 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
 91 {(7.00,1.00),[(7.00,1.00)]}]
 92 
 93  4 clusters:
 94 [
 95 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
 96 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
 97 {(7.00,1.00),[(7.00,1.00)]}, 
 98 {(2.40,2.60),[(1.00,1.00), (3.00,2.00), (4.00,3.00), (3.00,4.00), (1.00,3.00)]}]
 99 
100  3 clusters:
101 [
102 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
103 {(7.00,1.00),[(7.00,1.00)]}, 
104 {(2.14,3.43),[(1.00,1.00), (2.00,6.00), (3.00,2.00), (4.00,3.00), (3.00,4.00), (1.00,3.00), (1.00,5.00)]}]
View Code

weka实现

 1 import java.util.ArrayList;
 2 import weka.clusterers.HierarchicalClusterer;
 3 import static weka.clusterers.HierarchicalClusterer.TAGS_LINK_TYPE;
 4 import weka.core.Attribute;
 5 import weka.core.Instance;
 6 import weka.core.Instances;
 7 import weka.core.SelectedTag;
 8 import weka.core.SparseInstance;
 9 
10 public class WekaHierarchicalClustering {
11     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
12         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
13     private static final int M = DATA.length;  // number of points
14     private static final int K = 3;            // number of clusters
15 
16     public static void main(String[] args) {
17         Instances dataset = load(DATA);
18         HierarchicalClusterer hc = new HierarchicalClusterer();
19         hc.setLinkType(new SelectedTag(4, TAGS_LINK_TYPE));  // CENTROID
20         hc.setNumClusters(3);
21         try {
22             hc.buildClusterer(dataset);
23             for (Instance instance : dataset) {
24                 System.out.printf("(%.0f,%.0f): %s%n", 
25                         instance.value(0), instance.value(1), 
26                         hc.clusterInstance(instance));
27             }
28         } catch (Exception e) {
29             System.err.println(e);
30         }
31     }
32     
33     private static Instances load(double[][] data) {
34         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
35         attributes.add(new Attribute("X"));
36         attributes.add(new Attribute("Y"));
37         Instances dataset = new Instances("Dataset", attributes, M);
38         for (double[] datum : data) {
39             Instance instance = new SparseInstance(2);
40             instance.setValue(0, datum[0]);
41             instance.setValue(1, datum[1]);
42             dataset.add(instance);
43         }
44         return dataset;
45     }
46 }
View Code

结果-->

(1,1): 0
(1,3): 0
(1,5): 0
(2,6): 0
(3,2): 0
(3,4): 0
(4,3): 0
(5,6): 1
(6,3): 1
(6,4): 1
(7,1): 2
(7,5): 1
(7,6): 1
View Code

weka画图

 1 import java.awt.BorderLayout;
 2 import java.awt.Container;
 3 import java.util.ArrayList;
 4 import javax.swing.JFrame;
 5 import weka.clusterers.HierarchicalClusterer;
 6 import static weka.clusterers.HierarchicalClusterer.TAGS_LINK_TYPE;
 7 import weka.core.Attribute;
 8 import weka.core.Instance;
 9 import weka.core.Instances;
10 import weka.core.SelectedTag;
11 import weka.core.SparseInstance;
12 import weka.gui.hierarchyvisualizer.HierarchyVisualizer;
13 
14 public class WekaHierarchicalClustering2 {
15     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
16         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
17     private static final int M = DATA.length;  // number of points
18     private static final int K = 3;            // number of clusters
19 
20     public static void main(String[] args) {
21         Instances dataset = load(DATA);
22         HierarchicalClusterer hc = new HierarchicalClusterer();
23         hc.setLinkType(new SelectedTag(4, TAGS_LINK_TYPE));  // CENTROID
24         hc.setNumClusters(1);
25         try {
26             hc.buildClusterer(dataset);
27             for (Instance instance : dataset) {
28                 System.out.printf("(%.0f,%.0f): %s%n", 
29                         instance.value(0), instance.value(1), 
30                         hc.clusterInstance(instance));
31             }
32             displayDendrogram(hc.graph());
33         } catch (Exception e) {
34             System.err.println(e);
35         }
36     }
37     
38     private static Instances load(double[][] data) {
39         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
40         attributes.add(new Attribute("X"));
41         attributes.add(new Attribute("Y"));
42         Instances dataset = new Instances("Dataset", attributes, M);
43         for (double[] datum : data) {
44             Instance instance = new SparseInstance(2);
45             instance.setValue(0, datum[0]);
46             instance.setValue(1, datum[1]);
47             dataset.add(instance);
48         }
49         return dataset;
50     }
51     
52     public static void displayDendrogram(String graph) {
53         JFrame frame = new JFrame("Dendrogram");
54         frame.setSize(500, 400);
55         frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
56         Container pane = frame.getContentPane();
57         pane.setLayout(new BorderLayout());
58         pane.add(new HierarchyVisualizer(graph));
59         frame.setVisible(true);
60     }
61 }
View Code

 

K-均值聚类

  • 需求:同上
  • 实现:先随机选取一个点,再依次选取距已选点集最远的点,共得到k个初始点,各自创建一个聚类;其余每个点加入质心最近的聚类,并重新计算该聚类的质心。注意:这里只做一轮分配,没有像标准K-均值那样反复迭代直至收敛。

KMeans.java(普通实现)

 1 import java.util.HashSet;
 2 import java.util.Random;
 3 import java.util.Set;
 4 
 5 public class KMeans {
 6     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2},
 7             {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
 8     private static final int M = DATA.length;
 9     private static final int K = 3;
10     private static HashSet<Point> points;
11     private static HashSet<Cluster> clusters = new HashSet();
12     private static Random RANDOM = new Random();
13 
14     public static void main(String[] args){
15         points = load(DATA);
16 
17         int i0 = RANDOM.nextInt(M);
18         Point p = new Point(DATA[i0][0],DATA[i0][1]);
19         points.remove(p);
20 
21         HashSet<Point> initSet = new HashSet();
22         initSet.add(p);
23 
24         for(int i = 1; i < K; i ++){
25             p = farthestFrom(initSet);
26             initSet.add(p);
27             points.remove(p);
28         }
29 
30         for(Point point:initSet){
31             Cluster cluster = new Cluster(point);
32             clusters.add(cluster);
33         }
34 
35         for(Point point:points){
36             Cluster cluster = closestTo(point);
37             cluster.add(point);
38             cluster.recomputeCentroid();
39         }
40         System.out.println(clusters);
41     }
42 
43     private static HashSet<Point> load(double[][] data) {
44         HashSet<Point> points = new HashSet();
45         for (double[] datum : DATA) {
46             points.add(new Point(datum[0], datum[1]));
47         }
48         return points;
49     }
50 
51     // return the cluster whose centroid is closet to the specified point
52     private static Cluster closestTo(Point point){
53         double minDist = Double.POSITIVE_INFINITY;
54         Cluster c = null;
55         for(Cluster cluster:clusters){
56             double d = distance2(cluster.getCentroid(),point);
57             if(d < minDist){
58                 minDist = d;
59                 c = cluster;
60             }
61         }
62         return c;
63     }
64 
65     // return the point that is farthest from the specified set
66     private static Point farthestFrom(Set<Point> set){
67         Point p = null;
68         double maxDist = 0.0;
69         for(Point point:points){
70             if(set.contains(point)){
71                 continue;
72             }
73             double d = dist(point,set);
74             if(d > maxDist){
75                 p = point;
76                 maxDist = d;
77             }
78         }
79         return p;
80     }
81 
82     // return the distance from p to the nearest point in the set
83     public static double dist(Point p, Set<Point> set){
84         double minDist = Double.POSITIVE_INFINITY;
85         for(Point point:set){
86             double d = distance2(p,point);
87             minDist = (d < minDist ? d : minDist);
88         }
89         return minDist;
90     }
91 
92     public static double distance2(Point p, Point q){
93         double dx = p.getX() - q.getX();
94         double dy = p.getY() - q.getY();
95         return dx*dx + dy*dy;
96     }
97 }
View Code

[{(2.40,2.60),[(1.00,1.00), (1.00,3.00), (3.00,2.00), (4.00,3.00), (3.00,4.00)]},
{(6.33,4.17),[(6.00,3.00), (7.00,6.00), (6.00,4.00), (7.00,5.00), (7.00,1.00), (5.00,6.00)]},
{(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}]

KMeans.java(Weka 实现)

 1 import java.util.ArrayList;
 2 import weka.clusterers.SimpleKMeans;
 3 import weka.core.Attribute;
 4 import weka.core.Instance;
 5 import weka.core.Instances;
 6 import weka.core.SparseInstance;
 7 
 8 public class KMeans {
 9     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
10         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
11     private static final int M = DATA.length;  // number of points
12     private static final int K = 3;            // number of clusters
13 
14     public static void main(String[] args) {
15         Instances dataset = load(DATA);
16         SimpleKMeans skm = new SimpleKMeans();
17         System.out.printf("%d clusters:%n", K);
18         try {
19             skm.setNumClusters(K);
20             skm.buildClusterer(dataset);
21             for (Instance instance : dataset) {
22                 System.out.printf("(%.0f,%.0f): %s%n", 
23                         instance.value(0), instance.value(1), 
24                         skm.clusterInstance(instance));
25             }
26         } catch (Exception e) {
27             System.err.println(e);
28         }
29     }
30     
31     private static Instances load(double[][] data) {
32         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
33         attributes.add(new Attribute("X"));
34         attributes.add(new Attribute("Y"));
35         Instances dataset = new Instances("Dataset", attributes, M);
36         for (double[] datum : data) {
37             Instance instance = new SparseInstance(2);
38             instance.setValue(0, datum[0]);
39             instance.setValue(1, datum[1]);
40             dataset.add(instance);
41         }
42         return dataset;
43     }
44 }
View Code

结果-->

(1,1): 1
(1,3): 1
(1,5): 0
(2,6): 0
(3,2): 1
(3,4): 0
(4,3): 0
(5,6): 0
(6,3): 2
(6,4): 2
(7,1): 2
(7,5): 2
(7,6): 2
View Code

KMeansPlusPlus.java(Apache Common Math 实现)

 1 import java.util.ArrayList;
 2 import java.util.List;
 3 import org.apache.commons.math3.ml.clustering.CentroidCluster;
 4 import org.apache.commons.math3.ml.clustering.DoublePoint;
 5 import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
 6 import org.apache.commons.math3.ml.distance.EuclideanDistance;
 7 
 8 public class KMeansPlusPlus {
 9     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
10         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
11     private static final int M = DATA.length;  // number of points
12     private static final int K = 3;  // number of clusters
13     private static final int MAX = 100;  // maximum number of iterations
14     private static final EuclideanDistance ED = new EuclideanDistance();
15     
16     public static void main(String[] args) {
17         List<DoublePoint> points = load(DATA);
18         KMeansPlusPlusClusterer<DoublePoint> clusterer;
19         clusterer = new KMeansPlusPlusClusterer(K, MAX, ED);
20         List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
21         
22         for (CentroidCluster<DoublePoint> cluster : clusters) {
23             System.out.println(cluster.getPoints());
24         }
25     }
26     
27     private static List<DoublePoint> load(double[][] data) {
28         List<DoublePoint> points = new ArrayList(M);
29         for (double[] pair : data) {
30             points.add(new DoublePoint(pair));            
31         }
32         return points;
33     } 
34 }
View Code

[[5.0, 6.0], [6.0, 3.0], [6.0, 4.0], [7.0, 5.0], [7.0, 6.0]]
[[1.0, 1.0], [1.0, 3.0], [1.0, 5.0], [2.0, 6.0], [3.0, 2.0], [3.0, 4.0], [4.0, 3.0]]
[[7.0, 1.0]]

仿射传播聚类

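  • 需求:对5个样本点(1,2)、(2,3)、(4,1)、(4,4)、(5,3)进行聚类(注意:与上文的13个点不同)
  • 实现:
    - 以负的欧氏距离平方作为相似度s(i,k),并以所有相似度的平均值作为每个点的自相似度(偏好值);
    - 交替更新吸引度r(i,k)和归属度a(i,k)(阻尼系数0.5,共迭代10次);
    - 最后每个点i取使a(i,k)+r(i,k)最大的k作为其代表点(exemplar)。
  • 特点:不同于K-均值,聚类个数k无需事先确定,而是由算法根据偏好值等参数自动得出。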
 /**
 * Affinity propagation clustering of a small fixed set of points.
 *
 * <p>How the computation works:
 * <ul>
 *   <li>Similarities are negative squared Euclidean distances.</li>
 *   <li>Each point's preference (its self-similarity) is the mean of all off-diagonal
 *       similarities.</li>
 *   <li>Responsibilities {@code r} and availabilities {@code a} are updated with damping
 *       for a fixed number of iterations.</li>
 *   <li>Point {@code i}'s exemplar is the {@code k} that maximizes
 *       {@code a[i][k] + r[i][k]}.</li>
 * </ul>
 * Unlike k-means, the number of clusters is not specified in advance.
 */
public class AffinityPropagation {
    private static final double[][] x = {{1,2}, {2,3}, {4,1}, {4,4}, {5,3}};
    private static final int n = x.length;                 // number of points
    private static final double[][] s = new double[n][n];  // similarities
    private static final double[][] r = new double[n][n];  // responsibilities
    private static final double[][] a = new double[n][n];  // availabilities
    private static final int ITERATIONS = 10;
    private static final double DAMPER = 0.5;  // weight kept from the previous value per update

    public static void main(String[] args) {
        initSimilarities();
        for (int i = 0; i < ITERATIONS; i++) {
            updateResponsibilities();
            updateAvailabilities();
        }
        printResults();
    }

    /**
     * Fills {@code s} with pairwise similarities. Each diagonal entry (the preference) is
     * set to the mean off-diagonal similarity. With fewer than two points there are no
     * pairs, so the preference is 0 instead of 0/0 = NaN.
     */
    private static void initSimilarities() {
        double sum = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                sum += s[i][j] = s[j][i] = negSqEuclidDist(x[i], x[j]);
            }
        }
        // sum covers the (n*n - n)/2 pairs with j < i
        double average = n > 1 ? 2*sum/(n*n - n) : 0.0;
        for (int i = 0; i < n; i++) {
            s[i][i] = average;
        }
    }

    /**
     * Sets the damped responsibility r[i][k] from the value
     * s[i][k] - max over j != k of (a[i][j] + s[i][j]).
     */
    private static void updateResponsibilities() {
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                double oldValue = r[i][k];
                double max = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < n; j++) {
                    if (j != k) {
                        max = Math.max(max, a[i][j] + s[i][j]);
                    }
                }
                double newValue = s[i][k] - max;
                r[i][k] = DAMPER*oldValue + (1 - DAMPER)*newValue;
            }
        }
    }

    /**
     * Sets the damped availability a[i][k] from one of two values:
     * <ul>
     *   <li>when i == k: the sum of positive r[j][k] over j != k;</li>
     *   <li>otherwise: min(0, r[k][k] + the sum of positive r[j][k] over j not in {i, k}).</li>
     * </ul>
     */
    private static void updateAvailabilities() {
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                double oldValue = a[i][k];
                double newValue = (k == i)
                        ? sumOfPos(k, k)
                        : Math.min(0, r[k][k] + sumOfPos(i, k));
                a[i][k] = DAMPER*oldValue + (1 - DAMPER)*newValue;
            }
        }
    }

    /**
     * Returns the negative square of the Euclidean distance between two points of any
     * (equal) dimension.
     *
     * @throws IllegalArgumentException if the points have different dimensions
     */
    private static double negSqEuclidDist(double[] p, double[] q) {
        if (p.length != q.length) {
            throw new IllegalArgumentException(
                    "dimension mismatch: " + p.length + " vs " + q.length);
        }
        double sumSq = 0;
        for (int d = 0; d < p.length; d++) {
            double diff = p[d] - q[d];
            sumSq += diff*diff;
        }
        return -sumSq;
    }

    /** Returns the sum of the positive r[j][k], excluding r[i][k] and r[k][k]. */
    private static double sumOfPos(int i, int k) {
        double sum = 0;
        for (int j = 0; j < n; j++) {
            if (j != i && j != k) {
                sum += Math.max(0, r[j][k]);
            }
        }
        return sum;
    }

    /** Prints, for every point i, the k maximizing a[i][k] + r[i][k] (lowest k on ties). */
    private static void printResults() {
        for (int i = 0; i < n; i++) {
            double max = a[i][0] + r[i][0];
            int k = 0;
            for (int j = 1; j < n; j++) {
                double arij = a[i][j] + r[i][j];
                if (arij > max) {
                    max = arij;
                    k = j;
                }
            }
            System.out.printf("point %d has exemplar point %d%n", i, k);
        }
    }
}
View Code

point 0 has exemplar point 1
point 1 has exemplar point 1
point 2 has exemplar point 4
point 3 has exemplar point 4
point 4 has exemplar point 4

参考

https://blog.csdn.net/xzfreewind/article/details/73770327

posted @ 2021-04-24 23:33  cxc1357  阅读(641)  评论(0编辑  收藏  举报