第一个：计算 NMI（归一化互信息）的代码：

package clusters;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* DATE: 16-6-18 TIME: 上午10:00
*/

/**
 * Computes the Normalized Mutual Information (NMI) between a clustering and the true class labels.
 *
 * <p>Data format (see {@link #loadData}): one cluster per line; each line is a whitespace-separated
 * list of integer class labels, one per member of that cluster.
 *
 * <p>Reference: http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
 */
public class NormalizedMutualInformation {
    /** Path of the data file read by {@link #loadData}. */
    public static String path = "/home/fhqplzj/IdeaProjects/Vein/src/main/resources/nmi_data";

    /**
     * Reads the clusters stored at {@link #path} and appends them to {@code lists}.
     *
     * <p>Each non-blank line becomes one cluster (a list of class labels). Leading and trailing
     * whitespace is ignored and blank lines are skipped.
     *
     * @param lists destination list; one inner list is appended per cluster
     * @throws UncheckedIOException if the file cannot be opened or read
     * @throws NumberFormatException if a token is not an integer
     */
    public static void loadData(List<List<Integer>> lists) {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    // A blank line (e.g. a trailing newline) is not a cluster; parsing "" would throw.
                    continue;
                }
                List<Integer> labels = new ArrayList<>();
                for (String token : trimmed.split("\\s+")) {
                    labels.add(Integer.parseInt(token));
                }
                lists.add(labels);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read cluster data from " + path, e);
        }
    }

    /**
     * Computes NMI(Omega, C) = I(Omega; C) / ((H(Omega) + H(C)) / 2), using base-2 logarithms.
     *
     * @param lists the clustering; each inner list is one cluster holding the class label of each member
     * @return the NMI; 1.0 when all points lie in a single cluster and share a single class (both
     *     entropies are zero); {@link Double#NaN} when there are no points at all
     */
    public static double computeNmi(List<List<Integer>> lists) {
        int total = 0;
        Map<Integer, Integer> classSizes = new HashMap<>();
        for (List<Integer> cluster : lists) {
            total += cluster.size();
            for (Integer label : cluster) {
                classSizes.merge(label, 1, Integer::sum);
            }
        }
        if (total == 0) {
            return Double.NaN;
        }

        double clusterEntropy = 0;
        for (List<Integer> cluster : lists) {
            clusterEntropy -= plogp(cluster.size(), total);
        }
        double classEntropy = 0;
        for (int size : classSizes.values()) {
            classEntropy -= plogp(size, total);
        }

        double mutualInformation = 0;
        Map<Integer, Integer> overlap = new HashMap<>();
        for (List<Integer> cluster : lists) {
            int wk = cluster.size();
            overlap.clear();
            for (Integer label : cluster) {
                overlap.merge(label, 1, Integer::sum);
            }
            for (Map.Entry<Integer, Integer> entry : overlap.entrySet()) {
                int cj = classSizes.get(entry.getKey());
                int value = entry.getValue();
                // Double arithmetic: the int product wk * cj overflows once both exceed ~46341.
                mutualInformation +=
                        (double) value / total * log2((double) total * value / ((double) wk * cj));
            }
        }

        double denominator = clusterEntropy + classEntropy;
        if (denominator == 0) {
            // One cluster and one class: the clustering trivially matches the labels.
            return 1.0;
        }
        return 2 * mutualInformation / denominator;
    }

    /** Returns p * log2(p) for p = count / total, treating 0 * log 0 as 0 (empty clusters add nothing). */
    private static double plogp(int count, int total) {
        if (count == 0) {
            return 0.0;
        }
        double p = (double) count / total;
        return p * log2(p);
    }

    /** Base-2 logarithm. */
    private static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    /** Loads the data from {@link #path} and prints its NMI with two decimals. */
    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        loadData(lists);
        double nmi = computeNmi(lists);
        System.out.println(String.format("nmi = %.2f", nmi));
    }
}

//////////////////////////////////////////////

第二个：一些工具类（组合数与成对计数）：

package clusters;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Pair-counting helpers for clustering evaluation (Rand index, precision, recall, F-measure).
 *
 * <p>DATE: 16-6-18 TIME: 11:07 AM
 */
public class ClusterUtils {
    /**
     * Returns the binomial coefficient C(n, k).
     *
     * @return C(n, k), or 0 when {@code k < 0} or {@code k > n}
     * @throws ArithmeticException if the result does not fit in an {@code int}
     */
    public static int combination(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        int steps = Math.min(k, n - k);
        long result = 1;
        for (int i = 0; i < steps; i++) {
            // result == C(n, i) here; C(n, i) * (n - i) == C(n, i + 1) * (i + 1), so the division is exact.
            result = Math.multiplyExact(result, n - i) / (i + 1);
        }
        return Math.toIntExact(result);
    }

    /**
     * Counts the pairs of points placed in the same cluster (TP + FP).
     *
     * @param clusters the size of each cluster
     */
    public static int computeTPAndFP(int[] clusters) {
        int result = 0;
        for (int cluster : clusters) {
            result = Math.addExact(result, combination(cluster, 2));
        }
        return result;
    }

    /**
     * Counts false positives: pairs placed in the same cluster whose members belong to different classes.
     *
     * <p>Note: this method previously returned the number of same-cluster, same-class pairs (the TP count)
     * despite its name.
     *
     * @param mapList one map per cluster, from class label to the number of that cluster's members with
     *     that label
     */
    public static int computeFP(List<Map<Integer, Integer>> mapList) {
        int fp = 0;
        for (Map<Integer, Integer> classCounts : mapList) {
            int clusterSize = 0;
            int sameClassPairs = 0;
            for (int count : classCounts.values()) {
                clusterSize += count;
                sameClassPairs += combination(count, 2);
            }
            fp = Math.addExact(fp, combination(clusterSize, 2) - sameClassPairs);
        }
        return fp;
    }

    /**
     * Counts the pairs of same-class points that were split across different clusters, for one class.
     *
     * @param list how many members of the class fall into each cluster
     * @return the sum over i &lt; j of list[i] * list[j]
     * @throws ArithmeticException if the result does not fit in an {@code int}
     */
    public static int computeOneClass(List<Integer> list) {
        long pairs = 0;
        long seen = 0; // members of this class in the clusters visited so far
        for (int count : list) {
            // Every member in this cluster pairs with every earlier member sitting in a different cluster.
            pairs += seen * count;
            seen += count;
        }
        return Math.toIntExact(pairs);
    }

    /**
     * Counts false negatives: same-class pairs placed in different clusters.
     *
     * @param lists for each class, how many of its members fall into each cluster
     */
    public static int computeFN(List<List<Integer>> lists) {
        int result = 0;
        for (List<Integer> list : lists) {
            result = Math.addExact(result, computeOneClass(list));
        }
        return result;
    }

    /**
     * Computes the F-measure F_beta = (beta^2 + 1) * P * R / (beta^2 * P + R).
     *
     * @return F_beta, or 0 when the denominator is 0 (P == R == 0) instead of NaN
     */
    public static double computeFValue(double P, double R, double beta) {
        double betaSquared = beta * beta;
        double denominator = betaSquared * P + R;
        if (denominator == 0) {
            return 0.0;
        }
        return (betaSquared + 1) * P * R / denominator;
    }

    /** Small demo of {@link #computeOneClass}. */
    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(1, 4, 0);
        System.out.println("computeOneClass(list) = " + computeOneClass(list));
    }
}

第三个：计算 RI、P、R、F 以及 Purity，并顺便调用 NMI 一起打印输出。如 Stanford 文章所述，beta 分别取 1 和 5，计算 F1 和 F5。

package clusters;

import java.util.*;

/**
 * Loads the clustering via {@code NormalizedMutualInformation.loadData} and prints its purity, NMI,
 * Rand index (RI), precision (P), recall (R) and the F-measure for beta = 1 and beta = 5.
 *
 * <p>DATE: 16-6-18 TIME: 11:05 AM
 */
public class RandIndex {
    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        NormalizedMutualInformation.loadData(lists);

        int[] clusters = new int[lists.size()];
        int n = 0;
        for (int i = 0; i < clusters.length; i++) {
            clusters[i] = lists.get(i).size();
            n += clusters[i];
        }
        List<Map<Integer, Integer>> mapList = countClassesPerCluster(lists);

        // Pairs in the same cluster (TP + FP), and among them the pairs that also share a class (TP).
        // TP is counted directly here so the result does not depend on how ClusterUtils.computeFP
        // labels its count.
        int tpAndFp = ClusterUtils.computeTPAndFP(clusters);
        int tp = countSameClassPairs(mapList);
        int fp = tpAndFp - tp;
        int fn = ClusterUtils.computeFN(classCountsAcrossClusters(mapList));
        int tn = ClusterUtils.combination(n, 2) - tpAndFp - fn;
        double ri = 1.0 * (tp + tn) / (tp + fp + fn + tn);

        // Purity: fraction of points that belong to the majority class of their cluster.
        int totalMax = 0;
        for (Map<Integer, Integer> classCounts : mapList) {
            // An empty cluster has no majority class and contributes 0.
            totalMax += classCounts.values().stream().reduce(Math::max).orElse(0);
        }
        double purity = 1.0 * totalMax / n;
        System.out.println(String.format("purity = %.2f", purity));

        // Normalized Mutual Information (computed and printed by the NMI class itself).
        NormalizedMutualInformation.main(null);
        System.out.println(String.format("RI = %.2f", ri));

        // Precision, recall and the F-measure for beta = 1 (F1) and beta = 5 (F5).
        double precision = 1.0 * tp / (tp + fp);
        double recall = 1.0 * tp / (tp + fn);
        double beta = 1;
        System.out.println(String.format("P = %.2f", precision));
        System.out.printf("R = %.3f\n", recall);
        System.out.println(String.format("beta = 1, F = %.2f", ClusterUtils.computeFValue(precision, recall, beta)));
        beta = 5;
        System.out.println(String.format("beta = 5, F = %.3f", ClusterUtils.computeFValue(precision, recall, beta)));
    }

    /** Builds one class-label histogram per cluster (label -> number of that cluster's members with it). */
    private static List<Map<Integer, Integer>> countClassesPerCluster(List<List<Integer>> lists) {
        List<Map<Integer, Integer>> mapList = new ArrayList<>();
        for (List<Integer> cluster : lists) {
            Map<Integer, Integer> classCounts = new HashMap<>();
            for (Integer label : cluster) {
                classCounts.merge(label, 1, Integer::sum);
            }
            mapList.add(classCounts);
        }
        return mapList;
    }

    /** Counts the pairs of points that share both a cluster and a class (true positives). */
    private static int countSameClassPairs(List<Map<Integer, Integer>> mapList) {
        int pairs = 0;
        for (Map<Integer, Integer> classCounts : mapList) {
            for (int count : classCounts.values()) {
                pairs += ClusterUtils.combination(count, 2);
            }
        }
        return pairs;
    }

    /** For every class, collects how many of its members fall into each cluster that contains it. */
    private static List<List<Integer>> classCountsAcrossClusters(List<Map<Integer, Integer>> mapList) {
        Map<Integer, List<Integer>> byClass = new LinkedHashMap<>();
        for (Map<Integer, Integer> classCounts : mapList) {
            for (Map.Entry<Integer, Integer> entry : classCounts.entrySet()) {
                byClass.computeIfAbsent(entry.getKey(), label -> new ArrayList<>()).add(entry.getValue());
            }
        }
        return new ArrayList<>(byClass.values());
    }
}

 

输入数据就是 Stanford 文中的 3 个类簇（每行是一个簇，每个数字是该簇中一个样本的真实类别标签）：

1 1 1 1 1 2
1 2 2 2 2 3
1 1 3 3 3

本文转载自 http://blog.csdn.net/asd991936157/article/details/51705958，仅供学习。

 

posted on 2017-12-11 10:40  ALT_LB  阅读(1706)  评论(0)  编辑  收藏  举报