java - AdaBoost算法

由于对AdaBoost算法的弱分类器不是很了解,没明白算法描述里的“在权值分布的训练集上,取阈值使得分类误差率最小,然后就得到基本分类器”这句话,也不太明白怎样根据权值分布得到阈值。提供的代码是直接给出了弱分类器的阈值,不知道这样做是否正确,有问题请提出,谢谢,一起学习。由于被弱分类器搞得郁闷,所以代码中没有添加注释,但是步骤1、2、3(对应方法first、second、third)是根据李航的算法描述1、2、3进行编写的。

  1 import java.util.ArrayList;
  2 import java.util.List;
  3 
  4 public class AdaBoost {
  5     public static void main(String[] args){
  6         TestPoint[] testpoint = new TestPoint[10];
  7         testpoint[0] = new TestPoint(0,1);
  8         testpoint[1] = new TestPoint(1,1);
  9         testpoint[2] = new TestPoint(2,1);
 10         testpoint[3] = new TestPoint(3,-1);
 11         testpoint[4] = new TestPoint(4,-1);
 12         testpoint[5] = new TestPoint(5,-1);
 13         testpoint[6] = new TestPoint(6,1);
 14         testpoint[7] = new TestPoint(7,1);
 15         testpoint[8] = new TestPoint(8,1);
 16         testpoint[9] = new TestPoint(9,-1);
 17         
 18         List<List<Integer>> G = new ArrayList<List<Integer>>();
 19         double[] v = {2.5,8.5,5.5};
 20         double[] D = new double[testpoint.length];
 21         List<Double> A = new ArrayList<Double>();
 22         
 23         D = first(D,testpoint);
 24         for(int i=0;i<v.length;i++){
 25             second(testpoint,D,v[i],G,A,i);
 26         }
 27         third(A,G);
 28     }
 29 
 30     private static void third(List<Double> a, List<List<Integer>> g) {
 31         System.out.print("所得函数:    sign[");
 32         for(int i=0;i<a.size();i++){
 33             System.out.print(a.get(i) + " * " + "g[" + i + "]");
 34             if(i<a.size()-1){
 35                 System.out.print(" + ");
 36             }
 37         }
 38         System.out.println("]");
 39     }
 40 
 41     private static List<List<Integer>> second(TestPoint[] testpoint, double[] D, double v, List<List<Integer>> G, List<Double> A,int index) {
 42         double Z = 0;
 43         double error = 0.0;
 44         double a = 0;
 45         
 46         int[] GTemp = new int[testpoint.length];
 47         List<Integer> LTemp = new ArrayList<Integer>();
 48         
 49         for(int i=0;i<testpoint.length;i++){
 50             if(v != 5.5){    
 51                 if(testpoint[i].getX() < v){
 52                     GTemp[i] = 1;
 53                 }
 54                 else{
 55                     GTemp[i] = -1;
 56                 }
 57             }
 58             else{
 59                 if(testpoint[i].getX() < v){
 60                     GTemp[i] = -1;
 61                 }
 62                 else{
 63                     GTemp[i] = 1;
 64                 }
 65             }
 66         }
 67 
 68         for(int i=0;i<GTemp.length;i++){
 69             LTemp.add(GTemp[i]);
 70         }
 71         G.add(LTemp);
 72         
 73         for(int i=0;i<testpoint.length;i++){
 74             if(testpoint[i].getY() != GTemp[i]){
 75                 error += D[i] * 1;
 76             }
 77         }
 78         
 79         System.out.println("         错误率e" + (index + 1) + ":" +error);
 80         
 81         a = 0.5 * Math.log((1-error)/error);
 82         A.add(a);
 83         
 84         for(int i=0;i<testpoint.length;i++){
 85             Z += D[i] * Math.exp((-a) * testpoint[i].getY() * GTemp[i]);
 86         }
 87         
 88         System.out.print("权值分布D" + (index +1) + ":" ); 
 89         for(int i=0;i<testpoint.length;i++){
 90             D[i] = (D[i]/Z) * Math.exp((-a) * testpoint[i].getY() * GTemp[i]);
 91             System.out.print(D[i] + "  ");
 92         }
 93         System.out.println();
 94         
 95         return G;
 96     }
 97     
 98     private static double[] first(double[] D, TestPoint[] testpoint) {
 99         for(int i=0;i<testpoint.length;i++){
100             D[i] = 1.0/testpoint.length;
101         }
102         return D;
103     }
104 }

 

 

 1 //训练数据点,x为数据,y为类别
 2 public class TestPoint {
 3     private double x;
 4     private double y;
 5 
 6     public TestPoint(double i,double y){
 7         this.x = i;
 8         this.y = y;
 9     }
10 
11     public double getX() {
12         return x;
13     }
14 
15     public double getY() {
16         return y;
17     }
18     public void setX(int x) {
19         this.x = x;
20     }
21     
22     public void setY(int y) {
23         this.y = y;
24     }
25     
26     public String toString(double x ,double y){
27         return x + " " + y;
28     }
29 }

运行结果:

错误率e1:0.30000000000000004
权值分布D1:0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.16666666666666663 0.16666666666666663 0.16666666666666663 0.07142857142857142
错误率e2:0.21428571428571427
权值分布D2:0.04545454545454546 0.04545454545454546 0.04545454545454546 0.1666666666666667 0.1666666666666667 0.1666666666666667 0.10606060606060605 0.10606060606060605 0.10606060606060605 0.04545454545454546
错误率e3:0.18181818181818185
权值分布D3:0.12499999999999997 0.12499999999999997 0.12499999999999997 0.10185185185185185 0.10185185185185185 0.10185185185185185 0.0648148148148148 0.0648148148148148 0.0648148148148148 0.12499999999999997
所得函数: sign[0.4236489301936017 * g[0] + 0.6496414920651304 * g[1] + 0.752038698388137 * g[2]]

posted on 2013-09-24 09:49  Ja °  阅读(495)  评论(0)  编辑  收藏  举报

导航