Ranklib源码剖析--LambdaMart

Ranklib是一套优秀的Learning to Rank领域的开源实现,其中有实现了MART,RankNet,RankBoost,LambdaMart,Random Forest等模型。其中由微软发布的LambdaMART是IR业内常用的Learning to Rank模型,本文主要介绍Ranklib中的LambdaMART模型的具体实现,用以帮助理解paper中阐述的方法。本文是基于version2.3版本的Ranklib来介绍的。

LambdaMart的基本原理详见之前的博客:http://www.cnblogs.com/bentuwuying/p/6690836.html。要知道LambdaMart是基于MART的,而MART又是由若干棵regression tree组合而成的。所以,我们先来看看Ranklib中是如何实现regression tree的,以及在给定training data with labels的情况下,regression tree是如何拟合的。

1. regression tree

regression tree拟合给定training data的步骤总结概括如下:

RegressionTree
    nodes #限制一棵树的最大叶子节点数
    minLeafSupport #控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂
    root #根节点
    leaves #叶子节点list
    构造函数RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
        对各个类变量进行初始化
    fit #对training data进行拟合regression tree
        新建一个队列queue,用于按队列顺序(即按层遍历的顺序)进行分裂
        初始化一个regression tree的根节点root
        root.split #根节点分裂
            hist.findBestSplit #调用Split对象包含的FeatureHistogram对象的分裂方法(在该节点的已经统计好的特征统计直方图的基础上,寻找最佳分裂点,进行分裂,再计算左右子节点的特征统计直方图,并对左右子节点进行初始化)
                判断deviance,为0则分裂不成功
                根据samplingRate决定usedFeatures(分裂时需要使用的features的索引)
                调用内部的findBestSplit方法
                    在一个节点上,在usedFeatures中,根据该节点的特征统计直方图,来进行分裂时feature和threshold的选择
                    S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight
                    对每个可选的划分点(feature和threshold组合),求最大的S值,对应于均方误差最小,是最优的划分点
                判断划分是否成功,若S=-1,则分裂不成功
                对该节点上的每个训练数据,根据最优分裂点,进行左右子节点的分配
                初始化分裂后左右子节点各自的特征统计直方图
                    construct #一般用作父节点分裂后产生的左子节点的特征统计直方图的构造函数(当使用父节点来构造时,thresholds数组不变,但是sum和count数组需要重新构造)
                    construct #一般用作父节点分裂后产生的右子节点的特征统计直方图的构造函数
                计算本节点和左右子节点的均方误差
                sp.set #调用FeatureHistogram对象所在的Split对象的方法
                    一般在该节点进行分裂完成后,设定分裂时的featureID,threshold,deviance
                    只有非叶子节点才会进行分裂(调用这个方法),所以只有非叶子节点的featureID不为-1,叶子节点由于没有调用这个方法,故featureID=-1
                初始化左子节点(根据分裂到左子节点的训练数据索引数组,左子节点的特征统计直方图,左子节点的均方误差,左子节点的训练数据label之和),并设置到当前节点的左子节点变量上
                初始化右子节点(根据分裂到右子节点的训练数据索引数组,右子节点的特征统计直方图,右子节点的均方误差,右子节点的训练数据label之和),并设置到当前节点的右子节点变量上
        insert #将左右的子节点插入队列,用于下面遍历
            按均方误差从大到小的顺序进行插入队列
        循环:按队列顺序(即按层遍历的顺序)进行分裂,再将每次能够成功分裂的产生的两个子节点插入队列中
        根据根节点root的leaves类方法(迭代遍历),设置regression tree的leaves类变量

 

下面是regression tree拟合过程中涉及到的几个类文件代码,关键部分都有添加了详细的注释。

 

1. FeatureHistogram

  1 package ciir.umass.edu.learning.tree;
  2 import java.util.ArrayList;
  3 import java.util.Arrays;
  4 import java.util.List;
  5 import java.util.Random;
  6 import ciir.umass.edu.learning.DataPoint;
  7 import ciir.umass.edu.utilities.MyThreadPool;
  8 import ciir.umass.edu.utilities.WorkerThread;
  9 /**
 10  * @author vdang
 11  */
 12 //特征直方图类,对RankList对象进行特征的直方图统计,选择每次split时最优的feature和划分点
 13 public class FeatureHistogram {
 14     // 存放分裂时的featureIdx,thresholdIdx,以及评判是否最佳分裂的评分值sumLeft*sumLeft/countLeft + sumRight*sumRight/countRight
 15     class Config {
 16         int featureIdx = -1;
 17         int thresholdIdx = -1;
 18         double S = -1;
 19     }
 20     
 21     //Parameter
 22     public static float samplingRate = 1; //采样率,用于对分裂时使用的feature个数进行采样,不使用所有的feature
 23     
 24     //Variables
 25     public int[] features = null; //feature数组,每个元素是一个feature id(fid)
 26     public float[][] thresholds = null; //二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是阈值,个数为所有训练数据在此feature上的value的去重个数,从小到大排序的不重复值,用于对此节点的训练数据在此feature上分裂时可选的feature value阈值
 27     public double[][] sum = null; //二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是label之和,是所有训练数据中在此feature上的value小于等于相应位置的threshold值(thresholds[i][j])的DataPoint的label之和,sum二维数组大小与thresholds数组相同
 28     public double sumResponse = 0; //所有的训练数据的label之和
 29     public double sqSumResponse = 0; //所有的训练数据的label的平方和
 30     public int[][] count = null; //二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是个数,是所有训练数据中在此feature上的value小于等于相应位置的threshold值(thresholds[i][j])的DataPoint的个数,count二维数组大小与thresholds数组相同
 31     public int[][] sampleToThresholdMap = null; //二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是索引,是对应训练数据samples[i][j]在特定feature上每个训练数据的value对应于其在thresholds数组中相应行的列索引位置
 32     
 33     //whether to re-use its parents @sum and @count instead of cleaning up the parent and re-allocate for the children.
 34     //@sum and @count of any intermediate tree node (except for root) can be re-used.  
 35     private boolean reuseParent = false;
 36     
 37     public FeatureHistogram()
 38     {
 39         
 40     }
 41 
 42     //FeatureHistogram构造函数(1-1),一般用作整棵树/根节点的feature histogram,计算该节点的特征统计直方图
 43     //@samples: 训练数据
 44     //@labels: 训练数据的label
 45     //@sampleSortedIdx: 将样本根据特征排序,方便做树的分列时快速找出最优分列点,sorted list of samples by each feature, need initializing only once,初始化可见LambdaMART.java中的init()
 46     //@features: 训练数据的特征集合
 47     //@thresholds: 创建存放候选阈值(分列点)的表,a table of candidate thresolds for each feature, we will select the best tree split from these candidates later on
 48 ,初始化可见LambdaMART.java中的init(),此二维数组的每一行的最后一列的值是后加的,为Float.MAX_VALUE
 49     public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
 50     {
 51         this.features = features;
 52         this.thresholds = thresholds;
 53         
 54         sumResponse = 0;
 55         sqSumResponse = 0;
 56         
 57         sum = new double[features.length][];
 58         count = new int[features.length][];
 59         sampleToThresholdMap = new int[features.length][];
 60         
 61         //确定是否使用多线程计算
 62         MyThreadPool p = MyThreadPool.getInstance();
 63         if(p.size() == 1)
 64             construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length-1);
 65         else
 66             p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);            
 67     }
 68     //FeatureHistogram构造函数(1-2),被(1-1)调用
 69     protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
 70     {
 71         for(int i=start;i<=end;i++) //对于每个feature
 72         {
 73             int fid = features[i]; // 获取feature id
 74             //get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
 75             int[] idx = sampleSortedIdx[i]; //根据此feature下的value从小到大排序后的训练数据的索引数组
 76             
 77             double sumLeft = 0; //累计此值,用于给sumLabel使用
 78             float[] threshold = thresholds[i];
 79             double[] sumLabel = new double[threshold.length]; //对应前面sum二维数组的一行
 80             int[] c = new int[threshold.length]; //对应前面count二维数组的一行
 81             int[] stMap = new int[samples.length]; //对应前面sampleToThresholdMap二维数组的一行
 82             
 83             int last = -1;
 84             for(int t=0;t<threshold.length;t++) //对于每个可选的split阈值
 85             {
 86                 int j=last+1;
 87                 //find the first sample that exceeds the current threshold
 88                 for(;j<idx.length;j++)
 89                 {
 90                     int k = idx[j]; //获取此DataPoint在samples数组中的索引
 91                     if(samples[k].getFeatureValue(fid) >  threshold[t])
 92                         break;
 93                     sumLeft += labels[k];
 94                     if(i == 0)
 95                     {
 96                         sumResponse += labels[k];
 97                         sqSumResponse += labels[k] * labels[k];
 98                     }
 99                     stMap[k] =  t;
100                 }
101                 last = j-1;    
102                 sumLabel[t] = sumLeft;
103                 c[t] = last+1;
104             }
105             sampleToThresholdMap[i] = stMap;
106             sum[i] = sumLabel;
107             count[i] = c;
108         }
109     }
110     
111     //update(1-1), update the histogram with these training labels (the feature histogram will be used to find the best tree split)
112     protected void update(double[] labels)
113     {
114         sumResponse = 0;
115         sqSumResponse = 0;
116         
117         
118         //确定是否使用多线程计算
119         MyThreadPool p = MyThreadPool.getInstance();
120         if(p.size() == 1)
121             update(labels, 0, features.length-1);
122         else
123             p.execute(new Worker(this, labels), features.length);
124     }
125 
126     //update(1-2),被(1-1)调用
127     protected void update(double[] labels, int start, int end)
128     {
129         for(int f=start;f<=end;f++)
130             Arrays.fill(sum[f], 0);
131         for(int k=0;k<labels.length;k++)
132         {
133             for(int f=start;f<=end;f++)
134             {
135                 int t = sampleToThresholdMap[f][k];
136                 sum[f][t] += labels[k];
137                 if(f == 0)
138                 {
139                     sumResponse += labels[k];
140                     sqSumResponse += labels[k]*labels[k];
141                 }
142                 //count doesn't change, so no need to re-compute
143             }
144         }
145         for(int f=start;f<=end;f++)
146         {            
147             for(int t=1;t<thresholds[f].length;t++)
148                 sum[f][t] += sum[f][t-1];
149         }
150     }
151     
152     //FeatureHistogram构造函数(2-1),一般用作父节点分裂后产生的左子节点的特征统计直方图的构造函数
153     //当使用父节点来构造时,thresholds数组不变,但是sum和count数组需要重新构造
154     //@soi: 使用的训练数据的索引位置
155     public void construct(FeatureHistogram parent, int[] soi, double[] labels)
156     {
157         this.features = parent.features;
158         this.thresholds = parent.thresholds;
159         sumResponse = 0;
160         sqSumResponse = 0;
161         sum = new double[features.length][];
162         count = new int[features.length][];
163         sampleToThresholdMap = parent.sampleToThresholdMap;
164         
165         
166         //确定是否使用多线程计算
167         MyThreadPool p = MyThreadPool.getInstance();
168         if(p.size() == 1)
169             construct(parent, soi, labels, 0, features.length-1);
170         else
171             p.execute(new Worker(this, parent, soi, labels), features.length);    
172     }
173 
174     //FeatureHistogram构造函数(2-2),被(2-1)调用
175     protected void construct(FeatureHistogram parent, int[] soi, double[] labels, int start, int end)
176     {
177         //init
178         for(int i=start;i<=end;i++)
179         {            
180             float[] threshold = thresholds[i];
181             sum[i] = new double[threshold.length];
182             count[i] = new int[threshold.length];
183             Arrays.fill(sum[i], 0);
184             Arrays.fill(count[i], 0);
185         }
186         
187         //update
188         for(int i=0;i<soi.length;i++)
189         {
190             int k = soi[i];
191             for(int f=start;f<=end;f++)
192             {
193                 int t = sampleToThresholdMap[f][k];
194                 sum[f][t] += labels[k];
195                 count[f][t] ++;
196                 if(f == 0)
197                 {
198                     sumResponse += labels[k];
199                     sqSumResponse += labels[k]*labels[k];
200                 }
201             }
202         }
203         
204         for(int f=start;f<=end;f++)
205         {            
206             for(int t=1;t<thresholds[f].length;t++)
207             {
208                 sum[f][t] += sum[f][t-1];
209                 count[f][t] += count[f][t-1];
210             }
211         }
212     }
213     
214     //FeatureHistogram构造函数(3-1),一般用作父节点分裂后产生的右子节点的特征统计直方图的构造函数
215     public void construct(FeatureHistogram parent, FeatureHistogram leftSibling, boolean reuseParent)
216     {
217         this.reuseParent = reuseParent;
218         this.features = parent.features;
219         this.thresholds = parent.thresholds;
220         sumResponse = parent.sumResponse - leftSibling.sumResponse;
221         sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;
222         
223         if(reuseParent)
224         {
225             sum = parent.sum;
226             count = parent.count;
227         }
228         else
229         {
230             sum = new double[features.length][];
231             count = new int[features.length][];
232         }
233         sampleToThresholdMap = parent.sampleToThresholdMap;
234 
235         //确定是否使用多线程计算
236         MyThreadPool p = MyThreadPool.getInstance();
237         if(p.size() == 1)
238             construct(parent, leftSibling, 0, features.length-1);
239         else
240             p.execute(new Worker(this, parent, leftSibling), features.length);
241     }
242 
243     //FeatureHistogram构造函数(3-2),被(3-1)调用
244     protected void construct(FeatureHistogram parent, FeatureHistogram leftSibling, int start, int end)
245     {
246         for(int f=start;f<=end;f++)
247         {
248             float[] threshold = thresholds[f];
249             if(!reuseParent)
250             {
251                 sum[f] = new double[threshold.length];
252                 count[f] = new int[threshold.length];
253             }
254             for(int t=0;t<threshold.length;t++)
255             {
256                 sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
257                 count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
258             }
259         }
260     }
261     
262     //findBestSplit函数(1-2),被(1-1)调用。在一个节点上,在usedFeatures中,根据该节点的特征统计直方图,来进行分裂时feature和threshold的选择
263     protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
264     {
265         Config cfg = new Config();
266         int totalCount = count[start][count[start].length-1];
267         for(int f=start;f<=end;f++)
268         {
269             int i = usedFeatures[f];
270             float[] threshold = thresholds[i];
271             
272             for(int t=0;t<threshold.length;t++)
273             {
274                 int countLeft = count[i][t];
275                 int countRight = totalCount - countLeft;
276                 if(countLeft < minLeafSupport || countRight < minLeafSupport)
277                     continue;
278                 
279                 double sumLeft = sum[i][t];
280                 double sumRight = sumResponse - sumLeft;
281                 
282                 double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
283                 //求最大的S值,对应于均方误差最小,是最优的划分点
284                 if(cfg.S < S)
285                 {
286                     cfg.S = S;
287                     cfg.featureIdx = i;
288                     cfg.thresholdIdx = t;
289                 }
290             }
291         }        
292         return cfg;
293     }
294     
295     //findBestSplit函数(1-1),在该节点的已经统计好的特征统计直方图的基础上,寻找最佳分裂点,进行分裂,再计算左右子节点的特征统计直方图,并对左右子节点进行初始化
296     public boolean findBestSplit(Split sp, double[] labels, int minLeafSupport)
297     {
298         if(sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0)//equals 0
299             return false;//no need to split
300         
301         int[] usedFeatures = null;//index of the features to be used for tree splitting
302         if(samplingRate < 1)//need to do sub sampling (feature sampling)
303         {
304             int size = (int)(samplingRate * features.length);
305             usedFeatures = new int[size];
306             //put all features into a pool
307             List<Integer> fpool = new ArrayList<Integer>();
308             for(int i=0;i<features.length;i++)
309                 fpool.add(i);
310             //do sampling, without replacement
311             Random r = new Random();
312             for(int i=0;i<size;i++)
313             {
314                 int sel = r.nextInt(fpool.size());
315                 usedFeatures[i] = fpool.get(sel);
316                 fpool.remove(sel);
317             }
318         }
319         else//no sub-sampling, all features will be used
320         {
321             usedFeatures = new int[features.length];
322             for(int i=0;i<features.length;i++)
323                 usedFeatures[i] = i;
324         }
325         
326         //find the best split
327         Config best = new Config();
328         //确定是否使用多线程
329         MyThreadPool p = MyThreadPool.getInstance();
330         if(p.size() == 1)
331             best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length-1);
332         else
333         {
334             WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
335             for(int i=0;i<workers.length;i++)
336             {
337                 Worker wk = (Worker)workers[i];
338                 if(best.S < wk.cfg.S)
339                     best = wk.cfg;
340             }        
341         }
342         
343         if(best.S == -1)//unsplitable, for some reason...
344             return false;
345         
346         //if(minS >= sp.getDeviance())
347             //return null;
348         
349         double[] sumLabel = sum[best.featureIdx];
350         int[] sampleCount = count[best.featureIdx];
351         
352         double s = sumLabel[sumLabel.length-1];
353         int c = sampleCount[sumLabel.length-1];
354         
355         double sumLeft = sumLabel[best.thresholdIdx];
356         int countLeft = sampleCount[best.thresholdIdx];
357         
358         double sumRight = s - sumLeft;
359         int countRight = c - countLeft;
360         
361         int[] left = new int[countLeft];
362         int[] right = new int[countRight];
363         int l = 0;
364         int r = 0;
365         int k = 0;
366         int[] idx = sp.getSamples();
367         //对该节点上的每个训练数据,根据最优分裂点,进行左右子节点的分配
368         for(int j=0;j<idx.length;j++)
369         {
370             k = idx[j];
371             if(sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx)//go to the left
372                 left[l++] = k;
373             else//go to the right
374                 right[r++] = k;
375         }
376         
377         //初始化分裂后左右子节点各自的特征统计直方图
378         FeatureHistogram lh = new FeatureHistogram();
379         lh.construct(sp.hist, left, labels); //初始化左子节点的特征统计直方图
380         FeatureHistogram rh = new FeatureHistogram();
381         rh.construct(sp.hist, lh, !sp.isRoot()); //初始化右子节点的特征统计直方图
382         double var = sqSumResponse - sumResponse * sumResponse / idx.length; //计算本节点的均方误差
383         double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length; //计算左子节点的均方误差
384         double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length; //计算右子节点的均方误差
385         
386         sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
387         sp.setLeft(new Split(left, lh, varLeft, sumLeft));
388         sp.setRight(new Split(right, rh, varRight, sumRight));
389         
390         sp.clearSamples(); //清理本节点所属的sortedSampleIDs,samples,hist等数据
391         
392         return true;
393     }    
394     class Worker extends WorkerThread {
395         FeatureHistogram fh = null;
396         int type = -1;
397         
398         //find best split (type == 0)
399         int[] usedFeatures = null;
400         int minLeafSup = -1;
401         Config cfg = null;
402         
403         //update (type = 1)
404         double[] labels = null;
405         
406         //construct (type = 2)
407         FeatureHistogram parent = null;
408         int[] soi = null;
409         
410         //construct (type = 3)
411         FeatureHistogram leftSibling = null;
412         
413         //construct (type = 4)
414         DataPoint[] samples;
415         int[][] sampleSortedIdx;
416         float[][] thresholds;
417         
418         public Worker()
419         {
420         }
421         public Worker(FeatureHistogram fh, int[] usedFeatures, int minLeafSup)
422         {
423             type = 0;
424             this.fh = fh;
425             this.usedFeatures = usedFeatures;
426             this.minLeafSup = minLeafSup;
427         }
428         public Worker(FeatureHistogram fh, double[] labels)
429         {
430             type = 1;
431             this.fh = fh;
432             this.labels = labels;
433         }
434         public Worker(FeatureHistogram fh, FeatureHistogram parent, int[] soi, double[] labels)
435         {
436             type = 2;
437             this.fh = fh;
438             this.parent = parent;
439             this.soi = soi;
440             this.labels = labels;
441         }
442         public Worker(FeatureHistogram fh, FeatureHistogram parent, FeatureHistogram leftSibling)
443         {
444             type = 3;
445             this.fh = fh;
446             this.parent = parent;
447             this.leftSibling = leftSibling;
448         }
449         public Worker(FeatureHistogram fh, DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds)
450         {
451             type = 4;
452             this.fh = fh;
453             this.samples = samples;
454             this.labels = labels;
455             this.sampleSortedIdx = sampleSortedIdx;
456             this.thresholds = thresholds;            
457         }
458         public void run()
459         {
460             if(type == 0)
461                 cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
462             else if(type == 1)
463                 fh.update(labels, start, end);
464             else if(type == 2)
465                 fh.construct(parent, soi, labels, start, end);
466             else if(type == 3)
467                 fh.construct(parent, leftSibling, start, end);
468             else if(type == 4)
469                 fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
470         }        
471         public WorkerThread clone()
472         {
473             Worker wk = new Worker();
474             wk.fh = fh;
475             wk.type = type;
476             
477             //find best split (type == 0)
478             wk.usedFeatures = usedFeatures;
479             wk.minLeafSup = minLeafSup;
480             //wk.cfg = cfg;
481             
482             //update (type = 1)
483             wk.labels = labels;
484             
485             //construct (type = 2)
486             wk.parent = parent;
487             wk.soi = soi;
488             
489             //construct (type = 3)
490             wk.leftSibling = leftSibling;
491             
492             //construct (type = 1)
493             wk.samples = samples;
494             wk.sampleSortedIdx = sampleSortedIdx;
495             wk.thresholds = thresholds;            
496             
497             return wk;
498         }
499     }
500 }

 

2. Split

  1 package ciir.umass.edu.learning.tree;
  2 import java.util.ArrayList;
  3 import java.util.List;
  4 import ciir.umass.edu.learning.DataPoint;
  5 /**
  6  * 
  7  * @author vdang
  8  *
  9  */
 10 //Tree node,节点类,用于:
 11 // 1)训练时候的分裂判断(利用FeatureHistogram类);
 12 // 2)存储该节点的分裂规则(featureID,threshold)以及该节点的输出(avgLabel,deviance等)
 13 public class Split {
 14     //Key attributes of a split (tree node)
 15     //存储该节点的分裂规则(featureID,threshold)以及该节点的输出(avgLabel,deviance等)
 16     private int featureID = -1;
 17     private float threshold = 0F;
 18     private double avgLabel = 0.0F;
 19     
 20     //Intermediate variables (ONLY used during learning)
 21     //*DO NOT* attempt to access them once the training is done
 22     private boolean isRoot = false;
 23     private double sumLabel = 0.0;
 24     private double sqSumLabel = 0.0;
 25     private Split left = null;
 26     private Split right = null;
 27     private double deviance = 0F;//mean squared error "S"
 28     private int[][] sortedSampleIDs = null;
 29     public int[] samples = null;//训练时候,该节点上的训练数据集的索引
 30     public FeatureHistogram hist = null;//训练时候,该节点上的训练数据集的特征统计直方图
 31     
 32     public Split()
 33     {
 34         
 35     }
 36     public Split(int featureID, float threshold, double deviance)
 37     {
 38         this.featureID = featureID;
 39         this.threshold = threshold;
 40         this.deviance = deviance;
 41     }
 42     public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel)
 43     {
 44         this.sortedSampleIDs = sortedSampleIDs;
 45         this.deviance = deviance;
 46         this.sumLabel = sumLabel;
 47         this.sqSumLabel = sqSumLabel;
 48         avgLabel = sumLabel/sortedSampleIDs[0].length;
 49     }
 50     public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel)
 51     {
 52         this.samples = samples;
 53         this.hist = hist;
 54         this.deviance = deviance;
 55         this.sumLabel = sumLabel;
 56         avgLabel = sumLabel/samples.length;
 57     }
 58     
 59     //一般在该节点进行分裂完成后,设定分裂时的featureID,threshold,deviance。
 60     //只有非叶子节点才会进行分裂(调用这个方法),所以只有非叶子节点的featureID不为-1,叶子节点由于没有调用这个方法,故featureID=-1
 61     public void set(int featureID, float threshold, double deviance)
 62     {
 63         this.featureID = featureID;
 64         this.threshold = threshold;
 65         this.deviance = deviance;
 66     }
 67     public void setLeft(Split s)
 68     {
 69         left = s;
 70     }
 71     public void setRight(Split s)
 72     {
 73         right = s;
 74     }
 75     public void setOutput(float output)
 76     {
 77         avgLabel = output;
 78     }
 79     
 80     public Split getLeft()
 81     {
 82         return left;
 83     }
 84     public Split getRight()
 85     {
 86         return right;
 87     }
 88     public double getDeviance()
 89     {
 90         return deviance;
 91     }
 92     public double getOutput()
 93     {
 94         return avgLabel;
 95     }
 96     
 97     //得到此节点(一般是根节点)下的所有叶子节点的list
 98     //采用了递归的方法,碰到叶子节点(featureID=-1)则加入到list中,否则递归地调用leaves(list),
 99     public List<Split> leaves()
100     {
101         List<Split> list = new ArrayList<Split>();
102         leaves(list);
103         return list;        
104     }
105     private void leaves(List<Split> leaves)
106     {
107         if(featureID == -1)
108             leaves.add(this);
109         else
110         {
111             left.leaves(leaves);
112             right.leaves(leaves);
113         }
114     }
115     
116     //得到一个DataPoint在此节点(一般是根节点)下的最终落入(每层都按照分裂规则进入下一层)的叶子节点的输出值(avgLabel值)
117     public double eval(DataPoint dp)
118     {
119         Split n = this;
120         while(n.featureID != -1)
121         {
122             if(dp.getFeatureValue(n.featureID) <= n.threshold)
123                 n = n.left;
124             else
125                 n = n.right;
126         }
127         return n.avgLabel;
128     }
129     
130     public String toString()
131     {
132         return toString("");
133     }
134     public String toString(String indent)
135     {
136         String strOutput = indent + "<split>" + "\n";
137         strOutput += getString(indent + "\t");
138         strOutput += indent + "</split>" + "\n";
139         return strOutput;
140     }
141     public String getString(String indent)
142     {
143         String strOutput = "";
144         if(featureID == -1)
145         {
146             strOutput += indent + "<output> " + avgLabel + " </output>" + "\n";
147         }
148         else
149         {
150             strOutput += indent + "<feature> " + featureID + " </feature>" + "\n";
151             strOutput += indent + "<threshold> " + threshold + " </threshold>" + "\n";
152             strOutput += indent + "<split pos=\"left\">" + "\n";
153             strOutput += left.getString(indent + "\t");
154             strOutput += indent + "</split>" + "\n";
155             strOutput += indent + "<split pos=\"right\">" + "\n";
156             strOutput += right.getString(indent + "\t");
157             strOutput += indent + "</split>" + "\n";
158         }
159         return strOutput;
160     }
161     //Internal functions(ONLY used during learning)
162     //*DO NOT* attempt to call them once the training is done
163     //*重要*,训练时候,在该节点上进行分裂,调用了该节点的特征统计直方图对象的方法findBestSplit
164     public boolean split(double[] trainingLabels, int minLeafSupport)
165     {
166         return hist.findBestSplit(this, trainingLabels, minLeafSupport);
167     }
168     public int[] getSamples()
169     {
170         if(sortedSampleIDs != null)
171             return sortedSampleIDs[0];
172         return samples;
173     }
174     public int[][] getSampleSortedIndex()
175     {
176         return sortedSampleIDs;
177     }
178     public double getSumLabel()
179     {
180         return sumLabel;
181     }
182     public double getSqSumLabel()
183     {
184         return sqSumLabel;
185     }
186     public void clearSamples()
187     {
188         sortedSampleIDs = null;
189         samples = null;
190         hist = null;
191     }
192     public void setRoot(boolean isRoot)
193     {
194         this.isRoot = isRoot;
195     }
196     public boolean isRoot()
197     {
198         return isRoot;
199     }
200 }

 

3. RegressionTree

  1 package ciir.umass.edu.learning.tree;
  2 import java.util.ArrayList;
  3 import java.util.List;
  4 import ciir.umass.edu.learning.DataPoint;
  5 /**
  6  * @author vdang
  7  */
  8 //回归树类
  9 public class RegressionTree {
 10     
 11     //Parameters
 12     protected int nodes = 10;//-1 for unlimited number of nodes (the size of the tree will then be controlled *ONLY* by minLeafSupport)
 13     protected int minLeafSupport = 1; //控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂
 14     
 15     //Member variables and functions 
 16     protected Split root = null; //根节点
 17     protected List<Split> leaves = null; //叶子节点list
 18     
 19     protected DataPoint[] trainingSamples = null;
 20     protected double[] trainingLabels = null;
 21     protected int[] features = null;
 22     protected float[][] thresholds = null; //二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是阈值,个数为所有训练数据在此feature上的value的去重个数,从小到大排序的不重复值,用于对此节点的训练数据在此feature上分裂时可选的feature value阈值
 23     protected int[] index = null;
 24     protected FeatureHistogram hist = null;
 25     
 26     public RegressionTree(Split root)
 27     {
 28         this.root = root;
 29         leaves = root.leaves();
 30     }
 31     public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
 32     {
 33         this.nodes = nLeaves;
 34         this.trainingSamples = trainingSamples;
 35         this.trainingLabels = labels;
 36         this.hist = hist;
 37         this.minLeafSupport = minLeafSupport;
 38         index = new int[trainingSamples.length];
 39         for(int i=0;i<trainingSamples.length;i++)
 40             index[i] = i;
 41     }
 42     
 43     /**
 44      * Fit the tree from the specified training data
 45      */
 46     public void fit()
 47     {
 48         List<Split> queue = new ArrayList<Split>(); //用于按队列顺序(即按层遍历的顺序)进行分裂
 49         root = new Split(index, hist, Float.MAX_VALUE, 0); //回归树的根节点
 50         root.setRoot(true);
 51         root.split(trainingLabels, minLeafSupport); //根节点分裂1次,下面多了2个子节点
 52         insert(queue, root.getLeft()); //将左子节点插入队列,用于下面遍历
 53         insert(queue, root.getRight()); //将右子节点插入队列,用于下面遍历
 54         //循环:按队列顺序(即按层遍历的顺序)进行分裂,再将每次能够成功分裂的产生的两个子节点插入队列中
 55         int taken = 0;
 56         while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0)
 57         {
 58             Split leaf = queue.get(0);
 59             queue.remove(0);
 60             
 61             if(leaf.getSamples().length < 2 * minLeafSupport)
 62             {
 63                 taken++;
 64                 continue;
 65             }
 66             
 67             if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)==0; or after-split variance is higher than before) 对每个遍历到的节点,进行1次分裂,下面多了2个子节点
 68                 taken++;
 69             else
 70             {
 71                 insert(queue, leaf.getLeft()); //将左子节点插入队列,用于下面遍历
 72                 insert(queue, leaf.getRight()); //将右子节点插入队列,用于下面遍历
 73             }            
 74         }
 75         leaves = root.leaves();
 76     }
 77     
 78     /**
 79      * Get the tree output for the input sample
 80      * @param dp
 81      * @return
 82      */
 83     public double eval(DataPoint dp)
 84     {
 85         return root.eval(dp);
 86     }
 87     /**
 88      * Retrieve all leave nodes in the tree
 89      * @return
 90      */
 91     public List<Split> leaves()
 92     {
 93         return leaves;
 94     }
 95     /**
 96      * Clear samples associated with each leaves (when they are no longer necessary) in order to save memory
 97      */
 98     public void clearSamples()
 99     {
100         trainingSamples = null;
101         trainingLabels = null;
102         features = null;
103         thresholds = null;
104         index = null;
105         hist = null;
106         for(int i=0;i<leaves.size();i++)
107             leaves.get(i).clearSamples();
108     }
109     
110     /**
111      * Generate the string representation of the tree
112      */
113     public String toString()
114     {
115         if(root != null)
116             return root.toString();
117         return "";
118     }
119     public String toString(String indent)
120     {
121         if(root != null)
122             return root.toString(indent);
123         return "";
124     }
125     
126     public double variance()
127     {
128         double var = 0;
129         for(int i=0;i<leaves.size();i++)
130             var += leaves.get(i).getDeviance();
131         return var;
132     }
133     protected void insert(List<Split> ls, Split s)
134     {
135         int i=0;
136         while(i < ls.size())
137         {
138             if(ls.get(i).getDeviance() > s.getDeviance()) //按均方误差从大到小的顺序进行插入队列
139                 i++;
140             else
141                 break;
142         }
143         ls.add(i, s);
144     }
145 }

 

2. LambdaMart

LambdaMart模型训练过程总结概括如下:

 1 LambdaMart
 2     init
 3         初始化训练数据:martSamples,modelScores,pseudoResponses,weights
 4         将样本根据特征排序,方便做树的分裂时快速找出最优分裂点:sortedIdx
 5         初始化二维数组:thresholds(第一维是feature,下标是相应的features的下标,不是feature id;第二维是阈值,个数为所有训练数据在此feature上的value的去重个数,从小到大排序的不重复值,用于对此节点的训练数据在此feature上分裂时可选的feature value阈值)
 6         hist.construct #根据训练数据以及thresholds二维数组,初始化一个FeatureHistogram对象,用于构造整体数据的特征统计直方图,用于在根节点上进行分裂
 7             初始化:
 8                 sum #二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是label之和,是所有训练数据中在此feature上的value小于等于相应位置的threshold值(thresholds[i][j])的DataPoint的label之和,sum二维数组大小与thresholds数组相同
 9                 count #二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是个数,是所有训练数据中在此feature上的value小于等于相应位置的threshold值(thresholds[i][j])的DataPoint的个数,count二维数组大小与thresholds数组相同
10                 sampleToThresholdMap #二维数组,第一维是feature,下标是相应的features的下标,不是feature id;第二维是索引,是对应训练数据samples[i][j]在特定feature上每个训练数据的value对应于其在thresholds数组中相应行的列索引位置
11                 sumResponse #所有的训练数据的label之和
12                 sqSumResponse #所有的训练数据的label的平方和
13     learn
14         初始化一个Ensemble对象ensemble
15         开始Gradient Boosting过程,即依次构造若干棵regression tree:
16             computePseudoResponses #计算本轮迭代中,每个instance需要拟合的pseudo responses值(即梯度值,lambda)
17                 根据LambdaMart的梯度计算公式进行计算
18             hist.update #根据本轮迭代中计算得到的pseudo responses值(即梯度值,lambda),更新特征统计直方图,因为只改变了training data中每个instance的label,而其他值(如features)并未改变
19             初始化一棵regression tree(根据训练数据和特征统计直方图)
20             rt.fit #用regression tree对训练数据+本轮迭代中的pseudo responses值(即梯度值,lambda)进行拟合
21             将本轮迭代拟合产生的regression tree加入到ensembel对象中
22             updateTreeOutput #更新本轮迭代中拟合数据的regression tree的各个叶子节点的输出
23             计算本轮迭代后(新regression tree已经加入到集成模型中),training data中各个instance的预测分:modelScores
24             computeModelScoreOnTraining #计算本轮迭代后,最新模型对于training data总体的排序评价分(例如NDCG)
25             计算本轮迭代后(新regression tree已经加入到集成模型中),validation data中各个instance的预测分:modelScoresOnValidation
26             computeModelScoreOnValidation #计算本轮迭代后,最新模型对于validation data总体的排序评价分(例如NDCG)
27             更新在validation data上的历次各个模型的最优排序评价分:bestScoreOnValidationData,以及最优模型编号:bestModelOnValidation
28             如果在连续若干轮迭代中,模型在validation data上的排序评价分都没有提高,则终止迭代
29         回滚到在验证集上的最优模型
30         计算最优模型在training data和validation data上的排序评价分

 

下面是LambdaMart训练过程的代码,关键部分都有添加了详细的注释。

 

1. LambdaMART

  1 package ciir.umass.edu.learning.tree;
  2 import ciir.umass.edu.learning.DataPoint;
  3 import ciir.umass.edu.learning.RankList;
  4 import ciir.umass.edu.learning.Ranker;
  5 import ciir.umass.edu.metric.MetricScorer;
  6 import ciir.umass.edu.utilities.MergeSorter;
  7 import ciir.umass.edu.utilities.MyThreadPool;
  8 import ciir.umass.edu.utilities.RankLibError;
  9 import ciir.umass.edu.utilities.SimpleMath;
 10 import java.io.BufferedReader;
 11 import java.io.StringReader;
 12 import java.util.ArrayList;
 13 import java.util.Arrays;
 14 import java.util.List;
 15 /**
 16  * @author vdang
 17  *
 18  *  This class implements LambdaMART.
 19  *  Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures. 
 20  *  Journal of Information Retrieval, 2007.
 21  */
 22 public class LambdaMART extends Ranker {
 23     //Parameters
 24     public static int nTrees = 1000;//the number of trees
 25     public static float learningRate = 0.1F;//or shrinkage
 26     public static int nThreshold = 256;
 27     public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away. 
 28     public static int nTreeLeaves = 10;
 29     public static int minLeafSupport = 1;
 30     
 31     //for debugging
 32     public static int gcCycle = 100;
 33     
 34     //Local variables
 35     protected float[][] thresholds = null;
 36     protected Ensemble ensemble = null;
 37     protected double[] modelScores = null;//on training data
 38     
 39     protected double[][] modelScoresOnValidation = null;
 40     protected int bestModelOnValidation = Integer.MAX_VALUE-2;
 41     
 42     //Training instances prepared for MART
 43     protected DataPoint[] martSamples = null;//Need initializing only once
 44     protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once 
 45     protected FeatureHistogram hist = null;
 46     protected double[] pseudoResponses = null;//different for each iteration
 47     protected double[] weights = null;//different for each iteration
 48     
 49     public LambdaMART()
 50     {        
 51     }
 52     public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
 53     {
 54         super(samples, features, scorer);
 55     }
 56     
 57     public void init()
 58     {
 59         PRINT("Initializing... ");        
 60         //initialize samples for MART
 61         int dpCount = 0;
 62         for(int i=0;i<samples.size();i++)
 63         {
 64             RankList rl = samples.get(i);
 65             dpCount += rl.size();
 66         }
 67         int current = 0;
 68         martSamples = new DataPoint[dpCount];
 69         modelScores = new double[dpCount];
 70         pseudoResponses = new double[dpCount];
 71         weights = new double[dpCount];
 72         for(int i=0;i<samples.size();i++)
 73         {
 74             RankList rl = samples.get(i);
 75             for(int j=0;j<rl.size();j++)
 76             {
 77                 martSamples[current+j] = rl.get(j);
 78                 modelScores[current+j] = 0.0F;
 79                 pseudoResponses[current+j] = 0.0F;
 80                 weights[current+j] = 0;
 81             }
 82             current += rl.size();
 83         }            
 84         
 85         //sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on.
 86         // 将样本根据特征排序,方便做树的分裂时快速找出最优分裂点
 87         sortedIdx = new int[features.length][];
 88         MyThreadPool p = MyThreadPool.getInstance();
 89         if(p.size() == 1)//single-thread
 90             sortSamplesByFeature(0, features.length-1);
 91         else//multi-thread
 92         {
 93             int[] partition = p.partition(features.length);
 94             for(int i=0;i<partition.length-1;i++)
 95                 p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
 96             p.await();
 97         }
 98         
 99         //Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates        // 创建存放候选阈值(分裂点)的表
100         thresholds = new float[features.length][];
101         for(int f=0;f<features.length;f++)
102         {
103             //For this feature, keep track of the list of unique values and the max/min 
104             List<Float> values = new ArrayList<Float>();
105             float fmax = Float.NEGATIVE_INFINITY;
106             float fmin = Float.MAX_VALUE;
107             for(int i=0;i<martSamples.length;i++)
108             {
109                 int k = sortedIdx[f][i];//get samples sorted with respect to this feature
110                 float fv = martSamples[k].getFeatureValue(features[f]);
111                 values.add(fv);
112                 if(fmax < fv)
113                     fmax = fv;
114                 if(fmin > fv)
115                     fmin = fv;
116                 //skip all samples with the same feature value
117                 int j=i+1;
118                 while(j < martSamples.length)
119                 {
120                     if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
121                         break;
122                     j++;
123                 }
124                 i = j-1;//[i, j] gives the range of samples with the same feature value
125             }
126             
127             if(values.size() <= nThreshold || nThreshold == -1)
128             {
129                 thresholds[f] = new float[values.size()+1];
130                 for(int i=0;i<values.size();i++)
131                     thresholds[f][i] = values.get(i);
132                 thresholds[f][values.size()] = Float.MAX_VALUE;
133             }
134             else
135             {
136                 float step = (Math.abs(fmax - fmin))/nThreshold;
137                 thresholds[f] = new float[nThreshold+1];
138                 thresholds[f][0] = fmin;
139                 for(int j=1;j<nThreshold;j++)
140                     thresholds[f][j] = thresholds[f][j-1] + step;
141                 thresholds[f][nThreshold] = Float.MAX_VALUE;
142             }
143         }
144         
145         if(validationSamples != null)
146         {
147             modelScoresOnValidation = new double[validationSamples.size()][];
148             for(int i=0;i<validationSamples.size();i++)
149             {
150                 modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
151                 Arrays.fill(modelScoresOnValidation[i], 0);
152             }
153         }
154         
155         //compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
156         // 计算特征直方图,加速寻找分裂点
157         hist = new FeatureHistogram();
158         hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
159         //we no longer need the sorted indexes of samples
160         sortedIdx = null;
161         
162         System.gc();
163         PRINTLN("[Done]");
164     }
165     public void learn()
166     {
167         ensemble = new Ensemble();
168         
169         PRINTLN("---------------------------------");
170         PRINTLN("Training starts...");
171         PRINTLN("---------------------------------");
172         PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
173         PRINTLN("---------------------------------");        
174         
175         //Start the gradient boosting process
176         for(int m=0; m<nTrees; m++)
177         {
178             PRINT(new int[]{7}, new String[]{(m+1)+""});
179             
180             //Compute lambdas (which act as the "pseudo responses")
181             //Create training instances for MART:
182             //  - Each document is a training sample
183             //    - The lambda for this document serves as its training label
184             // 计算lambdas (pseudo responses)
185             computePseudoResponses();
186             
187             //update the histogram with these training labels (the feature histogram will be used to find the best tree split)
188             // 根据新的label更新特征直方图
189             hist.update(pseudoResponses);
190         
191             //Fit a regression tree        
192             // 回归决策树    
193             RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
194             rt.fit();
195             
196             //Add this tree to the ensemble (our model)
197             // 将新生成的树加入模型
198             ensemble.add(rt, learningRate);
199             //update the outputs of the tree (with gamma computed using the Newton-Raphson method) 
200             // 更新树的输出
201             updateTreeOutput(rt);
202             
203             //Update the model's outputs on all training samples
204             // 更新所有训练样本的模型输出
205             List<Split> leaves = rt.leaves();
206             for(int i=0;i<leaves.size();i++)
207             {
208                 Split s = leaves.get(i);
209                 int[] idx = s.getSamples();
210                 for(int j=0;j<idx.length;j++)
211                     modelScores[idx[j]] += learningRate * s.getOutput();
212             }
213             //clear references to data that is no longer used
214             rt.clearSamples();
215             
216             //beg the garbage collector to work...
217             if(m % gcCycle == 0)
218                 System.gc();//this call is expensive. We shouldn't do it too often.
219             //Evaluate the current model
220             // 评价模型
221             scoreOnTrainingData = computeModelScoreOnTraining();
222             //**** NOTE ****
223             //The above function to evaluate the current model on the training data is equivalent to a single call:
224             //
225             //        scoreOnTrainingData = scorer.score(rank(samples);
226             //
227             //However, this function is more efficient since it uses the cached outputs of the model (as opposed to re-evaluating the model 
228             //on the entire training set).
229             
230             PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});            
231             
232             //Evaluate the current model on the validation data (if available)
233             if(validationSamples != null)
234             {
235                 //Update the model's scores on all validation samples
236                 for(int i=0;i<modelScoresOnValidation.length;i++)
237                     for(int j=0;j<modelScoresOnValidation[i].length;j++)
238                         modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));
239                 
240                 //again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached models' outputs
241                 double score = computeModelScoreOnValidation();
242                 
243                 PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
244                 if(score > bestScoreOnValidationData)
245                 {
246                     bestScoreOnValidationData = score;
247                     bestModelOnValidation = ensemble.treeCount()-1;
248                 }
249             }
250             
251             PRINTLN("");
252             
253             //Should we stop early?
254             // 检验是否提前结束
255             if(m - bestModelOnValidation > nRoundToStopEarly)
256                 break;
257         }
258         
259         //Rollback to the best model observed on the validation data
260         // 回滚到在验证集上的最优模型
261         while(ensemble.treeCount() > bestModelOnValidation+1)
262             ensemble.remove(ensemble.treeCount()-1);
263         
264         //Finishing up
265         scoreOnTrainingData = scorer.score(rank(samples));
266         PRINTLN("---------------------------------");
267         PRINTLN("Finished sucessfully.");
268         PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
269         if(validationSamples != null)
270         {
271             bestScoreOnValidationData = scorer.score(rank(validationSamples));
272             PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
273         }
274         PRINTLN("---------------------------------");
275     }
276     public double eval(DataPoint dp)
277     {
278         return ensemble.eval(dp);
279     }    
280     public Ranker createNew()
281     {
282         return new LambdaMART();
283     }
284     public String toString()
285     {
286         return ensemble.toString();
287     }
288     public String model()
289     {
290         String output = "## " + name() + "\n";
291         output += "## No. of trees = " + nTrees + "\n";
292         output += "## No. of leaves = " + nTreeLeaves + "\n";
293         output += "## No. of threshold candidates = " + nThreshold + "\n";
294         output += "## Learning rate = " + learningRate + "\n";
295         output += "## Stop early = " + nRoundToStopEarly + "\n";
296         output += "\n";
297         output += toString();
298         return output;
299     }
300         @Override
301     public void loadFromString(String fullText)
302     {
303         try {
304             String content = "";
305             //String model = "";
306                         StringBuffer model = new StringBuffer ();
307             BufferedReader in = new BufferedReader(new StringReader(fullText));
308             while((content = in.readLine()) != null)
309             {
310                 content = content.trim();
311                 if(content.length() == 0)
312                     continue;
313                 if(content.indexOf("##")==0)
314                     continue;
315                 //actual model component
316                 //model += content;
317                                 model.append (content);
318             }
319             in.close();
320             //load the ensemble
321             ensemble = new Ensemble(model.toString());
322             features = ensemble.getFeatures();
323         }
324         catch(Exception ex)
325         {
326             throw RankLibError.create("Error in LambdaMART::load(): ", ex);
327         }
328     }
329     public void printParameters()
330     {
331         PRINTLN("No. of trees: " + nTrees);
332         PRINTLN("No. of leaves: " + nTreeLeaves);
333         PRINTLN("No. of threshold candidates: " + nThreshold);
334         PRINTLN("Min leaf support: " + minLeafSupport);
335         PRINTLN("Learning rate: " + learningRate);
336         PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");        
337     }    
338     public String name()
339     {
340         return "LambdaMART";
341     }
342     public Ensemble getEnsemble()
343     {
344         return ensemble;
345     }
346     
347     protected void computePseudoResponses()
348     {
349         Arrays.fill(pseudoResponses, 0F);
350         Arrays.fill(weights, 0);
351         MyThreadPool p = MyThreadPool.getInstance();
352         if(p.size() == 1)//single-thread
353             computePseudoResponses(0, samples.size()-1, 0);
354         else //multi-threading
355         {
356             List<LambdaComputationWorker> workers = new ArrayList<LambdaMART.LambdaComputationWorker>();
357             //divide the entire dataset into chunks of equal size for each worker thread
358             int[] partition = p.partition(samples.size());
359             int current = 0;
360             for(int i=0;i<partition.length-1;i++)
361             {
362                 //execute the worker
363                 LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current); 
364                 workers.add(wk);//keep it so we can get back results from it later on
365                 p.execute(wk);
366                 
367                 if(i < partition.length-2)
368                     for(int j=partition[i]; j<=partition[i+1]-1;j++)
369                         current += samples.get(j).size();
370             }
371             
372             //wait for all workers to complete before we move on to the next stage
373             p.await();
374         }
375     }
376     protected void computePseudoResponses(int start, int end, int current)
377     {
378         int cutoff = scorer.getK();
379         //compute the lambda for each document (a.k.a "pseudo response")
380         for(int i=start;i<=end;i++)
381         {
382             RankList orig = samples.get(i);            
383             int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
384             RankList rl = new RankList(orig, idx, current);
385             double[][] changes = scorer.swapChange(rl);
386             //NOTE: j, k are indices in the sorted (by modelScore) list, not the original
387             // ==> need to map back with idx[j] and idx[k] 
388             for(int j=0;j<rl.size();j++)
389             {
390                 DataPoint p1 = rl.get(j);
391                 int mj = idx[j];
392                 for(int k=0;k<rl.size();k++)
393                 {
394                     if(j > cutoff && k > cutoff)//swaping these pair won't result in any change in target measures since they're below the cut-off point
395                         break;
396                     DataPoint p2 = rl.get(k);
397                     int mk = idx[k];
398                     if(p1.getLabel() > p2.getLabel())
399                     {
400                         double deltaNDCG = Math.abs(changes[j][k]);
401                         if(deltaNDCG > 0)
402                         {
403                             double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
404                             double lambda = rho * deltaNDCG;
405                             pseudoResponses[mj] += lambda;
406                             pseudoResponses[mk] -= lambda;
407                             double delta = rho * (1.0 - rho) * deltaNDCG;
408                             weights[mj] += delta;
409                             weights[mk] += delta;
410                         }
411                     }
412                 }
413             }
414             current += orig.size();
415         }
416     }
417     protected void updateTreeOutput(RegressionTree rt)
418     {
419         List<Split> leaves = rt.leaves();
420         for(int i=0;i<leaves.size();i++)
421         {
422             float s1 = 0F;
423             float s2 = 0F;
424             Split s = leaves.get(i);
425             int[] idx = s.getSamples();
426             for(int j=0;j<idx.length;j++)
427             {
428                 int k = idx[j];
429                 s1 += pseudoResponses[k];
430                 s2 += weights[k];
431             }
432             if(s2 == 0)
433                 s.setOutput(0);
434             else
435                 s.setOutput(s1/s2);
436         }
437     }
438     protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
439     {
440         double[] score = new double[samples.length];
441         for(int i=0;i<samples.length;i++)
442             score[i] = samples[i].getFeatureValue(fid);
443         int[] idx = MergeSorter.sort(score, true); 
444         return idx;
445     }
446     /**
447      * This function is equivalent to the inherited function rank(...), but it uses the cached model's outputs instead of computing them from scratch.
448      * @param rankListIndex
449      * @param current
450      * @return
451      */
452     protected RankList rank(int rankListIndex, int current)
453     {
454         RankList orig = samples.get(rankListIndex);    
455         double[] scores = new double[orig.size()];
456         for(int i=0;i<scores.length;i++)
457             scores[i] = modelScores[current+i];
458         int[] idx = MergeSorter.sort(scores, false);
459         return new RankList(orig, idx);
460     }
461     protected float computeModelScoreOnTraining() 
462     {
463         /*float s = 0;
464         int current = 0;    
465         MyThreadPool p = MyThreadPool.getInstance();
466         if(p.size() == 1)//single-thread
467             s = computeModelScoreOnTraining(0, samples.size()-1, current);
468         else
469         {
470             List<Worker> workers = new ArrayList<Worker>();
471             //divide the entire dataset into chunks of equal size for each worker thread
472             int[] partition = p.partition(samples.size());
473             for(int i=0;i<partition.length-1;i++)
474             {
475                 //execute the worker
476                 Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
477                 workers.add(wk);//keep it so we can get back results from it later on
478                 p.execute(wk);
479                 
480                 if(i < partition.length-2)
481                     for(int j=partition[i]; j<=partition[i+1]-1;j++)
482                         current += samples.get(j).size();
483             }        
484             //wait for all workers to complete before we move on to the next stage
485             p.await();
486             for(int i=0;i<workers.size();i++)
487                 s += workers.get(i).score;
488         }*/
489         float s = computeModelScoreOnTraining(0, samples.size()-1, 0);
490         s = s / samples.size();
491         return s;
492     }
493     protected float computeModelScoreOnTraining(int start, int end, int current) 
494     {
495         float s = 0;
496         int c = current;
497         for(int i=start;i<=end;i++)
498         {
499             s += scorer.score(rank(i, c));
500             c += samples.get(i).size();
501         }
502         return s;
503     }
504     protected float computeModelScoreOnValidation() 
505     {
506         /*float score = 0;
507         MyThreadPool p = MyThreadPool.getInstance();
508         if(p.size() == 1)//single-thread
509             score = computeModelScoreOnValidation(0, validationSamples.size()-1);
510         else
511         {
512             List<Worker> workers = new ArrayList<Worker>();
513             //divide the entire dataset into chunks of equal size for each worker thread
514             int[] partition = p.partition(validationSamples.size());
515             for(int i=0;i<partition.length-1;i++)
516             {
517                 //execute the worker
518                 Worker wk = new Worker(this, partition[i], partition[i+1]-1);
519                 workers.add(wk);//keep it so we can get back results from it later on
520                 p.execute(wk);
521             }        
522             //wait for all workers to complete before we move on to the next stage
523             p.await();
524             for(int i=0;i<workers.size();i++)
525                 score += workers.get(i).score;
526         }*/
527         float score = computeModelScoreOnValidation(0, validationSamples.size()-1);
528         return score/validationSamples.size();
529     }
530     protected float computeModelScoreOnValidation(int start, int end) 
531     {
532         float score = 0;
533         for(int i=start;i<=end;i++)
534         {
535             int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
536             score += scorer.score(new RankList(validationSamples.get(i), idx));
537         }
538         return score;
539     }
540     
541     protected void sortSamplesByFeature(int fStart, int fEnd)
542     {
543         for(int i=fStart;i<=fEnd; i++)
544             sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
545     }
546     //For multi-threading processing
547     class SortWorker implements Runnable {
548         LambdaMART ranker = null;
549         int start = -1;
550         int end = -1;
551         SortWorker(LambdaMART ranker, int start, int end)
552         {
553             this.ranker = ranker;
554             this.start = start;
555             this.end = end;
556         }        
557         public void run()
558         {
559             ranker.sortSamplesByFeature(start, end);
560         }
561     }
562     class LambdaComputationWorker implements Runnable {
563         LambdaMART ranker = null;
564         int rlStart = -1;
565         int rlEnd = -1;
566         int martStart = -1;
567         LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
568         {
569             this.ranker = ranker;
570             this.rlStart = rlStart;
571             this.rlEnd = rlEnd;
572             this.martStart = martStart;
573         }        
574         public void run()
575         {
576             ranker.computePseudoResponses(rlStart, rlEnd, martStart);
577         }
578     }
579     class Worker implements Runnable {
580         LambdaMART ranker = null;
581         int rlStart = -1;
582         int rlEnd = -1;
583         int martStart = -1;
584         int type = -1;
585         
586         //compute score on validation
587         float score = 0;
588         
589         Worker(LambdaMART ranker, int rlStart, int rlEnd)
590         {
591             type = 3;
592             this.ranker = ranker;
593             this.rlStart = rlStart;
594             this.rlEnd = rlEnd;
595         }
596         Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
597         {
598             type = 4;
599             this.ranker = ranker;
600             this.rlStart = rlStart;
601             this.rlEnd = rlEnd;
602             this.martStart = martStart;
603         }
604         public void run()
605         {
606             if(type == 4)
607                 score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
608             else if(type == 3)
609                 score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
610         }
611     }
612 }

 

版权声明:

   本文由笨兔勿应所有,发布于http://www.cnblogs.com/bentuwuying。如果转载,请注明出处,在未经作者同意下将本文用于商业用途,将追究其法律责任。

 

posted @ 2017-04-14 09:16  笨兔勿应  阅读(6211)  评论(2编辑  收藏  举报