关于Additive Ensembles of Regression Trees模型的快速打分预测
一.论文《QuickScorer:a Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees》是为了解决LTR模型的预测问题,如果LTR中的LambdaMart在生成模型时产生的树数和叶结点过多,在对样本打分预测时会遍历每棵树,这样在线上使用时效率较慢,这篇文章主要就是利用了bitvector方法加速打分预测。代码我找了很久没找到开源的,后来无意中在Solr ltr中看到被改动过了的源码,不过这个源码集成在solr中,这里暂时贴出来,后期再剥离出,集成到ranklib中,以便使用。
二.图片解说
1. Ensemble trees原始打分过程
像gbdt,lambdamart,xgboost或lightgbm等这样的集成树模型在打分预测阶段,比如来了一个样本,这个样本是vector形式输入到每一棵树中,然后在每棵树中像if else这样的过程走到或映射到每棵树的一个节点中,这个节点就是每棵树的打分,然后将每棵树的打分乘上学习率(shrinkage)加和就是此样本的预测分。
2.论文中提到的打分过程
A.为回归树中的每个分枝打上true和false标签
比如图中样本X=[0.2,1.1,0.2],在回归树的branch中判断X[0],X[1],X[2]的true和false,比如图中根结点X[1]<=1.0,但样本X[1]=1.1,所以是false(走左边是true,右边是false),这样将所有branch打上true和false标签(可以直接打上false标志,不用考虑true),后面需要用到所有的false branch。
B.为每个branch分配一个bitvector
这个bitvector中的"0"对应该branch判定为true时才会走到的叶结点(即左子树的叶结点):当该branch判定为false时,这些"0"位对应的叶结点被排除出候选,仍为"1"的位才是剩下的候选叶结点。比如"001111"表示若该branch为false,6个叶结点中最左边两个被排除,其余四个仍是候选;"110011"表示该branch为false时,中间两个叶结点(其左子树的叶结点)被排除。
C.打分阶段
此阶段是最后的打分预测阶段,根据前几个图的过程,将所有branch为false的bitvector按位与操作,就会得出样本落在哪个叶结点上。比如图中的结果是"001101",最左边为1的便是最终的叶结点的编号,每个回归树都会这样操作得到预测值,乘上学习率(shrinkage)然后加和就会得到一个样本的预测值。
三.代码
1 import org.apache.lucene.index.LeafReaderContext; 2 import org.apache.lucene.search.Explanation; 3 import org.apache.solr.ltr.feature.Feature; 4 import org.apache.solr.ltr.model.LTRScoringModel; 5 import org.apache.solr.ltr.model.ModelException; 6 import org.apache.solr.ltr.norm.Normalizer; 7 import org.apache.solr.util.SolrPluginUtils; 8 9 import java.util.*; 10 11 public class MultipleAdditiveTreesModel extends LTRScoringModel { 12 13 // 特征名:索引(从0开始) 14 private final HashMap<String, Integer> fname2index = new HashMap(); 15 private List<RegressionTree> trees; 16 17 private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) { 18 MultipleAdditiveTreesModel.RegressionTree rt = new MultipleAdditiveTreesModel.RegressionTree(); 19 if(map != null) { 20 SolrPluginUtils.invokeSetters(rt, map.entrySet()); 21 } 22 23 return rt; 24 } 25 26 private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) { 27 MultipleAdditiveTreesModel.RegressionTreeNode rtn = new MultipleAdditiveTreesModel.RegressionTreeNode(); 28 if(map != null) { 29 SolrPluginUtils.invokeSetters(rtn, map.entrySet()); 30 } 31 32 return rtn; 33 } 34 35 public void setTrees(Object trees) { 36 this.trees = new ArrayList(); 37 Iterator var2 = ((List)trees).iterator(); 38 39 while(var2.hasNext()) { 40 Object o = var2.next(); 41 MultipleAdditiveTreesModel.RegressionTree rt = this.createRegressionTree((Map)o); 42 this.trees.add(rt); 43 } 44 } 45 46 public void setTrees(List<RegressionTree> trees) { 47 this.trees = trees; 48 } 49 50 public List<RegressionTree> getTrees() { 51 return this.trees; 52 } 53 54 public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) { 55 super(name, features, norms, featureStoreName, allFeatures, params); 56 57 for(int i = 0; i < features.size(); ++i) { 58 String key = 
((Feature)features.get(i)).getName(); 59 this.fname2index.put(key, Integer.valueOf(i));//特征名:索引 60 } 61 62 } 63 64 public void validate() throws ModelException { 65 super.validate(); 66 if(this.trees == null) { 67 throw new ModelException("no trees declared for model " + this.name); 68 } else { 69 Iterator var1 = this.trees.iterator(); 70 71 while(var1.hasNext()) { 72 MultipleAdditiveTreesModel.RegressionTree tree = (MultipleAdditiveTreesModel.RegressionTree)var1.next(); 73 tree.validate(); 74 } 75 76 } 77 } 78 79 public float score(float[] modelFeatureValuesNormalized) { 80 float score = 0.0F; 81 82 MultipleAdditiveTreesModel.RegressionTree t; 83 for(Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) { 84 t = (MultipleAdditiveTreesModel.RegressionTree)var3.next(); 85 } 86 87 return score; 88 } 89 90 public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) { 91 float[] fv = new float[featureExplanations.size()]; 92 int index = 0; 93 94 for(Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) { 95 Explanation featureExplain = (Explanation)details.next(); 96 fv[index] = featureExplain.getValue(); 97 } 98 99 ArrayList var12 = new ArrayList(); 100 index = 0; 101 102 for(Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) { 103 MultipleAdditiveTreesModel.RegressionTree t = (MultipleAdditiveTreesModel.RegressionTree)var13.next(); 104 float score = t.score(fv); 105 Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]); 106 var12.add(p); 107 } 108 109 return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12); 110 } 111 112 public String toString() { 113 StringBuilder sb = new StringBuilder(this.getClass().getSimpleName()); 114 sb.append("(name=").append(this.getName()); 115 sb.append(",trees=["); 116 117 for(int ii = 0; ii < 
this.trees.size(); ++ii) { 118 if(ii > 0) { 119 sb.append(','); 120 } 121 122 sb.append(this.trees.get(ii)); 123 } 124 125 sb.append("])"); 126 return sb.toString(); 127 } 128 129 public class RegressionTree { 130 private Float weight; 131 private MultipleAdditiveTreesModel.RegressionTreeNode root; 132 133 public void setWeight(float weight) { 134 this.weight = new Float(weight); 135 } 136 137 public void setWeight(String weight) { 138 this.weight = new Float(weight); 139 } 140 141 public float getWeight() { 142 return this.weight; 143 } 144 145 public void setRoot(Object root) { 146 this.root = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)root); 147 } 148 149 public RegressionTreeNode getRoot() { 150 return this.root; 151 } 152 153 public float score(float[] featureVector) { 154 return this.weight.floatValue() * this.root.score(featureVector); 155 } 156 157 public String explain(float[] featureVector) { 158 return this.root.explain(featureVector); 159 } 160 161 public String toString() { 162 StringBuilder sb = new StringBuilder(); 163 sb.append("(weight=").append(this.weight); 164 sb.append(",root=").append(this.root); 165 sb.append(")"); 166 return sb.toString(); 167 } 168 169 public RegressionTree() { 170 } 171 172 public void validate() throws ModelException { 173 if(this.weight == null) { 174 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a weight"); 175 } else if(this.root == null) { 176 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a tree"); 177 } else { 178 this.root.validate(); 179 } 180 } 181 } 182 183 public class RegressionTreeNode { 184 private static final float NODE_SPLIT_SLACK = 1.0E-6F; 185 private float value = 0.0F; 186 private String feature; 187 private int featureIndex = -1; 188 private Float threshold; 189 private MultipleAdditiveTreesModel.RegressionTreeNode left; 190 private MultipleAdditiveTreesModel.RegressionTreeNode right; 191 192 public void setValue(float value) 
{ 193 this.value = value; 194 } 195 196 public void setValue(String value) { 197 this.value = Float.parseFloat(value); 198 } 199 200 public void setFeature(String feature) { 201 this.feature = feature; 202 Integer idx = (Integer)MultipleAdditiveTreesModel.this.fname2index.get(this.feature); 203 this.featureIndex = idx == null?-1:idx.intValue(); 204 } 205 206 public int getFeatureIndex() { 207 return this.featureIndex; 208 } 209 210 public void setThreshold(float threshold) { 211 this.threshold = Float.valueOf(threshold + 1.0E-6F); 212 } 213 214 public void setThreshold(String threshold) { 215 this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F); 216 } 217 218 public float getThreshold() { 219 return this.threshold; 220 } 221 222 public void setLeft(Object left) { 223 this.left = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)left); 224 } 225 226 public RegressionTreeNode getLeft() { 227 return this.left; 228 } 229 230 public void setRight(Object right) { 231 this.right = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)right); 232 } 233 234 public RegressionTreeNode getRight() { 235 return this.right; 236 } 237 238 public boolean isLeaf() { 239 return this.feature == null; 240 } 241 242 public float score(float[] featureVector) { 243 return this.isLeaf()?this.value:(this.featureIndex >= 0 && this.featureIndex < featureVector.length?(featureVector[this.featureIndex] <= this.threshold.floatValue()?this.left.score(featureVector):this.right.score(featureVector)):0.0F); 244 } 245 246 public String explain(float[] featureVector) { 247 if(this.isLeaf()) { 248 return "val: " + this.value; 249 } else if(this.featureIndex >= 0 && this.featureIndex < featureVector.length) { 250 String rval; 251 if(featureVector[this.featureIndex] <= this.threshold.floatValue()) { 252 rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | "; 253 return rval + this.left.explain(featureVector); 
254 } else { 255 rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | "; 256 return rval + this.right.explain(featureVector); 257 } 258 } else { 259 return "\'" + this.feature + "\' does not exist in FV, Return Zero"; 260 } 261 } 262 263 public String toString() { 264 StringBuilder sb = new StringBuilder(); 265 if(this.isLeaf()) { 266 sb.append(this.value); 267 } else { 268 sb.append("(feature=").append(this.feature); 269 sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F); 270 sb.append(",left=").append(this.left); 271 sb.append(",right=").append(this.right); 272 sb.append(')'); 273 } 274 275 return sb.toString(); 276 } 277 278 public RegressionTreeNode() { 279 } 280 281 public void validate() throws ModelException { 282 if(this.isLeaf()) { 283 if(this.left != null || this.right != null) { 284 throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right); 285 } 286 } else if(null == this.threshold) { 287 throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold"); 288 } else if(null == this.left) { 289 throw new ModelException("MultipleAdditiveTreesModel tree node is missing left"); 290 } else { 291 this.left.validate(); 292 if(null == this.right) { 293 throw new ModelException("MultipleAdditiveTreesModel tree node is missing right"); 294 } else { 295 this.right.validate(); 296 } 297 } 298 } 299 } 300 301 }
import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;

import java.util.*;

/**
 * QuickScorer-style evaluation of an additive regression-tree ensemble
 * (Lucchese et al., "QuickScorer: a Fast Algorithm to Rank Documents with
 * Additive Ensembles of Regression Trees", SIGIR 2015).
 *
 * Instead of traversing each tree root-to-leaf, every branch condition is
 * flattened into per-feature, threshold-sorted arrays. At scoring time the
 * bitvectors of all FALSE conditions are AND-ed per tree; the index of the
 * first set bit of the result identifies the exit leaf.
 */
public class QuickScorerTreesModel extends MultipleAdditiveTreesModel {

    // All-ones 64-bit mask: the initial "every leaf is still a candidate" state.
    private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;

    // 64bits De Bruijn Sequence
    // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
    private static final long HASH_BITS = 0x022fdd63cc95386dL;
    private static final int[] hashTable = new int[64];

    static {
        // Precompute the De Bruijn lookup table used by findIndexOfRightMostBit.
        long hash = HASH_BITS;
        for (int i = 0; i < 64; ++i) {
            hashTable[(int) (hash >>> 58)] = i;
            hash <<= 1;
        }
    }

    /**
     * Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
     *
     * @param bits target bits (64bits)
     * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
     */
    private static int findIndexOfRightMostBit(long bits) {
        return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
    }

    /**
     * The number of trees of this model.
     */
    private int treeNum;

    /**
     * Weights of each tree.
     */
    private float[] weights;

    /**
     * List of all leaves of this model.
     * We use tree instead of value to manage wide (i.e., more than 64 leaves) trees.
     */
    private RegressionTreeNode[] leaves;

    /**
     * Offsets of each leaf block correspond to each tree.
     */
    private int[] leafOffsets;

    /**
     * The number of conditions of this model.
     */
    private int condNum;

    /**
     * Thresholds of each condition.
     * These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
     */
    private float[] thresholds;

    /**
     * Corresponding featureIndex of each condition.
     */
    private int[] featureIndexes;

    /**
     * Offsets of each condition block correspond to each feature.
     */
    private int[] condOffsets;

    /**
     * Forward bitvectors of each condition which correspond to original additive trees.
     */
    private long[] forwardBitVectors;

    /**
     * Backward bitvectors of each condition which correspond to inverted additive trees.
     */
    private long[] backwardBitVectors;

    /**
     * Mappings from thresholds index to tree indexes.
     */
    private int[] treeIds;

    /**
     * Bitvectors of each tree for calculating the score.
     * We reuse bitvectors instance in each thread to prevent from re-allocating arrays.
     */
    private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null;

    /**
     * Boolean statistical tendency of this model.
     * If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
     */
    private volatile float falseRatio = 0.5f;

    /**
     * The decay factor for updating falseRatio in each evaluation step.
     * This factor is used like "{@code ratio = preRatio * decay + ratio * (1 - decay)}".
     */
    private float falseRatioDecay = 0.99f;

    /**
     * Comparable node cost for selecting leaf candidates.
     */
    private static class NodeCost implements Comparable<NodeCost> {
        private final int id;
        private final int cost;
        private final int depth;
        private final int left;
        private final int right;

        private NodeCost(int id, int cost, int depth, int left, int right) {
            this.id = id;
            this.cost = cost;
            this.depth = depth;
            this.left = left;
            this.right = right;
        }

        public int getId() {
            return id;
        }

        public int getLeft() {
            return left;
        }

        public int getRight() {
            return right;
        }

        /**
         * Sorts by cost and depth.
         * We prefer cheaper cost and deeper one.
         */
        @Override
        public int compareTo(NodeCost n) {
            if (cost != n.cost) {
                return Integer.compare(cost, n.cost);
            } else if (depth != n.depth) {
                return Integer.compare(n.depth, depth); // reverse order: deeper first
            } else {
                return Integer.compare(id, n.id);
            }
        }
    }

    /**
     * Comparable condition for constructing and sorting bitvectors.
     */
    private static class Condition implements Comparable<Condition> {
        private final int featureIndex;
        private final float threshold;
        private final int treeId;
        private final long forwardBitvector;
        private final long backwardBitvector;

        private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.treeId = treeId;
            this.forwardBitvector = forwardBitvector;
            this.backwardBitvector = backwardBitvector;
        }

        int getFeatureIndex() {
            return featureIndex;
        }

        float getThreshold() {
            return threshold;
        }

        int getTreeId() {
            return treeId;
        }

        long getForwardBitvector() {
            return forwardBitvector;
        }

        long getBackwardBitvector() {
            return backwardBitvector;
        }

        /*
         * Sort by featureIndex and threshold with ascent order.
         */
        @Override
        public int compareTo(Condition c) {
            if (featureIndex != c.featureIndex) {
                return Integer.compare(featureIndex, c.featureIndex);
            } else {
                return Float.compare(threshold, c.threshold);
            }
        }
    }

    /**
     * Base class for traversing node with depth first order.
     */
    private abstract static class Visitor {
        // Post-order counter: a node's id is assigned after its children are visited.
        private int nodeId = 0;

        int getNodeId() {
            return nodeId;
        }

        void visit(RegressionTree tree) {
            nodeId = 0;
            visit(tree.getRoot(), 0);
        }

        private void visit(RegressionTreeNode node, int depth) {
            if (node.isLeaf()) {
                doVisitLeaf(node, depth);
            } else {
                // visit children first
                visit(node.getLeft(), depth + 1);
                visit(node.getRight(), depth + 1);

                doVisitBranch(node, depth);
            }
            ++nodeId;
        }

        protected abstract void doVisitLeaf(RegressionTreeNode node, int depth);

        protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
    }

    /**
     * {@link Visitor} implementation for calculating the cost of each node.
     */
    private static class NodeCostVisitor extends Visitor {

        private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
        private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>();

        PriorityQueue<NodeCost> getNodeCostQueue() {
            return nodeCostQueue;
        }

        @Override
        protected void doVisitLeaf(RegressionTreeNode node, int depth) {
            nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
            idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
        }

        @Override
        protected void doVisitBranch(RegressionTreeNode node, int depth) {
            // calculate the cost of this node from children costs
            final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
            final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
            final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue());

            nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
            idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
        }
    }

    /**
     * {@link Visitor} implementation for extracting leaves and bitvectors.
     */
    private static class QuickScorerVisitor extends Visitor {

        private final int treeId;
        private final int leafNum;
        private final Set<Integer> leafIdSet;
        private final Set<Integer> skipIdSet;

        private final Stack<Long> bitsStack = new Stack<>();
        private final List<RegressionTreeNode> leafList = new ArrayList<>();
        private final List<Condition> conditionList = new ArrayList<>();

        private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
            this.treeId = treeId;
            this.leafNum = leafNum;
            this.leafIdSet = leafIdSet;
            this.skipIdSet = skipIdSet;
        }

        List<RegressionTreeNode> getLeafList() {
            return leafList;
        }

        List<Condition> getConditionList() {
            return conditionList;
        }

        // Mirrors the low `leafNum` bits of `bits` (used to build backward bitvectors).
        private long reverseBits(long bits) {
            long revBits = 0L;
            long mask = (1L << (leafNum - 1));
            for (int i = 0; i < leafNum; ++i) {
                if ((bits & mask) != 0L) revBits |= (1L << i);
                mask >>>= 1;
            }
            return revBits;
        }

        @Override
        protected void doVisitLeaf(RegressionTreeNode node, int depth) {
            if (skipIdSet.contains(getNodeId())) return;

            bitsStack.add(1L << leafList.size()); // we use rightmost bit for detecting leaf
            leafList.add(node);
        }

        @Override
        protected void doVisitBranch(RegressionTreeNode node, int depth) {
            if (skipIdSet.contains(getNodeId())) return;

            if (leafIdSet.contains(getNodeId())) {
                // an endpoint of QuickScorer
                doVisitLeaf(node, depth);
                return;
            }

            final long rightBits = bitsStack.pop(); // bits of false branch
            final long leftBits = bitsStack.pop(); // bits of true branch
            /*
             * NOTE:
             * forwardBitvector = ~leftBits
             * backwardBitvector = ~(reverse(rightBits))
             */
            conditionList.add(
                    new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));

            bitsStack.add(leftBits | rightBits);
        }
    }

    public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
                                 List<Feature> allFeatures, Map<String, Object> params) {
        super(name, features, norms, featureStoreName, allFeatures, params);
    }

    /**
     * Set falseRatioDecay parameter of this model.
     *
     * @param falseRatioDecay decay parameter for updating falseRatio
     */
    public void setFalseRatioDecay(float falseRatioDecay) {
        this.falseRatioDecay = falseRatioDecay;
    }

    /**
     * @see #setFalseRatioDecay(float)
     */
    public void setFalseRatioDecay(String falseRatioDecay) {
        this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void validate() throws ModelException {
        // validate trees before initializing QuickScorer
        super.validate();

        // initialize QuickScorer with validated trees
        init(getTrees());
    }

    /**
     * Initializes the quick scorer with the given trees: flattens every tree
     * into leaf/condition arrays sorted by feature and threshold.
     *
     * @param trees base additive trees model
     */
    private void init(List<RegressionTree> trees) {
        this.treeNum = trees.size();
        this.weights = new float[trees.size()];
        this.leafOffsets = new int[trees.size() + 1];
        this.leafOffsets[0] = 0;

        // re-create tree bitvectors
        if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
        this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
            @Override
            protected long[] initialValue() {
                return new long[treeNum];
            }
        };

        int treeId = 0;
        List<RegressionTreeNode> leafList = new ArrayList<>();
        List<Condition> conditionList = new ArrayList<>();
        for (RegressionTree tree : trees) {
            // select up to 64 leaves from given tree
            QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree);

            // extract leaves and conditions with selected leaf candidates
            visitor.visit(tree);
            leafList.addAll(visitor.getLeafList());
            conditionList.addAll(visitor.getConditionList());

            // update weight, offset and treeId
            this.weights[treeId] = tree.getWeight();
            this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
            ++treeId;
        }

        // remap list to array for performance reason
        this.leaves = leafList.toArray(new RegressionTreeNode[0]);

        // sort conditions by ascent order of featureIndex and threshold
        Collections.sort(conditionList);

        // remap information of conditions
        int idx = 0;
        int preFeatureIndex = -1;
        this.condNum = conditionList.size();
        this.thresholds = new float[conditionList.size()];
        this.forwardBitVectors = new long[conditionList.size()];
        this.backwardBitVectors = new long[conditionList.size()];
        this.treeIds = new int[conditionList.size()];
        List<Integer> featureIndexList = new ArrayList<>();
        List<Integer> condOffsetList = new ArrayList<>();
        for (Condition condition : conditionList) {
            this.thresholds[idx] = condition.threshold;
            this.forwardBitVectors[idx] = condition.getForwardBitvector();
            this.backwardBitVectors[idx] = condition.getBackwardBitvector();
            this.treeIds[idx] = condition.getTreeId();

            // a new feature index starts a new condition block
            if (preFeatureIndex != condition.getFeatureIndex()) {
                featureIndexList.add(condition.getFeatureIndex());
                condOffsetList.add(idx);
                preFeatureIndex = condition.getFeatureIndex();
            }

            ++idx;
        }
        condOffsetList.add(conditionList.size()); // guard

        this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
        this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
    }

    /**
     * Checks costs of all nodes and select leaves up to 64.
     *
     * <p>NOTE:
     * We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
     * However, this modification caused performance degradation in our experiments, and we decided to use this form.
     *
     * @param treeId index of given regression tree
     * @param tree target regression tree
     * @return QuickScorerVisitor with proper id sets
     */
    private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
        // calculate costs of all nodes
        NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
        nodeCostVisitor.visit(tree);

        // poll zero cost nodes (i.e., real leaves)
        Set<Integer> leafIdSet = new HashSet<>();
        Set<Integer> skipIdSet = new HashSet<>();
        while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
            if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
            NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
            leafIdSet.add(nodeCost.id);
        }

        // merge leaves until the number of leaves reaches 64
        while (leafIdSet.size() > 64) {
            final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
            assert nodeCost.left >= 0 && nodeCost.right >= 0;

            // update leaves
            leafIdSet.remove(nodeCost.left);
            leafIdSet.remove(nodeCost.right);
            leafIdSet.add(nodeCost.id);

            // register previous leaves to skip ids
            skipIdSet.add(nodeCost.left);
            skipIdSet.add(nodeCost.right);
        }

        return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public float score(float[] modelFeatureValuesNormalized) {
        assert threadLocalTreeBitvectors != null;
        long[] treeBitvectors = threadLocalTreeBitvectors.get();
        Arrays.fill(treeBitvectors, MAX_BITS);

        int falseNum = 0;
        float score = 0.0f;
        if (falseRatio <= 0.5) {
            // use forward bitvectors: scan thresholds ascending and AND-in every FALSE condition
            for (int i = 0; i < condOffsets.length - 1; ++i) {
                final int featureIndex = featureIndexes[i];
                for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
                    // thresholds are sorted: once a condition is true, the rest of the block is true
                    if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
                    treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
                    ++falseNum;
                }
            }

            for (int i = 0; i < leafOffsets.length - 1; ++i) {
                // the first set bit marks the exit leaf of tree i
                final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
                score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
            }
        } else {
            // use backward bitvectors: scan thresholds descending and AND-in every TRUE condition
            falseNum = condNum;
            for (int i = 0; i < condOffsets.length - 1; ++i) {
                final int featureIndex = featureIndexes[i];
                for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
                    if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
                    treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
                    --falseNum;
                }
            }

            for (int i = 0; i < leafOffsets.length - 1; ++i) {
                // backward bitvectors are mirrored, so index leaves from the end of the block
                final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
                score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
            }
        }

        // update false ratio (exponential moving average)
        falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
        return score;
    }

}
1 import org.apache.lucene.search.IndexSearcher; 2 import org.apache.lucene.search.Query; 3 import org.apache.solr.ltr.feature.Feature; 4 import org.apache.solr.ltr.feature.FeatureException; 5 import org.apache.solr.ltr.norm.IdentityNormalizer; 6 import org.apache.solr.ltr.norm.Normalizer; 7 import org.apache.solr.request.SolrQueryRequest; 8 import org.junit.Ignore; 9 import org.junit.Test; 10 11 import java.io.IOException; 12 import java.util.ArrayList; 13 import java.util.HashMap; 14 import java.util.LinkedHashMap; 15 import java.util.List; 16 import java.util.Map; 17 import java.util.Random; 18 19 import static org.hamcrest.CoreMatchers.is; 20 import static org.junit.Assert.assertThat; 21 22 public class TestQuickScorerTreesModelBenchmark { 23 24 /** 25 * 产生特征 26 * @param featureNum 特征个数 27 * @return 28 */ 29 private List<Feature> createDummyFeatures(int featureNum) { 30 List<Feature> features = new ArrayList<>(); 31 for (int i = 0; i < featureNum; ++i) { 32 features.add(new Feature("fv_" + i, null) { 33 @Override 34 protected void validate() throws FeatureException { } 35 36 @Override 37 public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request, 38 Query originalQuery, Map<String, String[]> efi) throws IOException { 39 return null; 40 } 41 42 @Override 43 public LinkedHashMap<String, Object> paramsToMap() { 44 return null; 45 } 46 }); 47 } 48 return features; 49 } 50 51 private List<Normalizer> createDummyNormalizer(int featureNum) { 52 List<Normalizer> normalizers = new ArrayList<>(); 53 for (int i = 0; i < featureNum; ++i) { 54 normalizers.add(new IdentityNormalizer()); 55 } 56 return normalizers; 57 } 58 59 /** 60 * 创建单棵树 61 * 递归调用自己 62 * @param leafNum 叶子个数 63 * @param features 特征 64 * @param rand 产生随机数 65 * @return 66 */ 67 private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) { 68 Map<String, Object> node = new HashMap<>(); 69 if (leafNum == 1) { 70 // leaf 71 
node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 72 return node; 73 } 74 75 // branch 76 node.put("feature", features.get(rand.nextInt(features.size())).getName()); 77 node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 78 node.put("left", createRandomTree(leafNum / 2, features, rand)); 79 node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand)); 80 return node; 81 } 82 83 /** 84 * 这里随机创建多棵树作为model测试 85 * @param treeNum 树的个数 86 * @param leafNum 叶子个数 87 * @param features 特征 88 * @param rand 产生随机数 89 * @return 90 */ 91 private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features, 92 Random rand) { 93 List<Object> trees = new ArrayList<>(); 94 for (int i = 0; i < treeNum; ++i) { 95 Map<String, Object> tree = new HashMap<>(); 96 tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 设置每棵树的学习率 97 tree.put("root", createRandomTree(leafNum, features, rand)); 98 trees.add(tree); 99 } 100 return trees; 101 } 102 103 /** 104 * 对比两个打分模型的分值是否一致 105 * @param featureNum 特征个数 106 * @param treeNum 树个数 107 * @param leafNum 叶子个数 108 * @param loopNum 样本个数 109 * @throws Exception 110 */ 111 private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception { 112 Random rand = new Random(0); 113 114 List<Feature> features = createDummyFeatures(featureNum); //产生特征 115 List<Normalizer> norms = createDummyNormalizer(featureNum); //标准化 116 117 for (int i = 0; i < loopNum; ++i) { 118 List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand); 119 120 MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms, 121 "dummy", features, null); 122 matModel.setTrees(trees); 123 matModel.validate(); 124 125 QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features, 126 null); 127 
qstModel.setTrees(trees);//设置提供的树模型 128 qstModel.validate();//对提供的树结构进行验证 129 130 float[] featureValues = new float[featureNum]; 131 for (int j = 0; j < 100; ++j) { 132 for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5) 133 134 float expected = matModel.score(featureValues); 135 float actual = qstModel.score(featureValues); 136 assertThat(actual, is(expected)); 137 //System.out.println("expected: " + expected + " actual: " + actual); 138 } 139 } 140 } 141 142 /** 143 * 两个模型是否得分一致 144 * 145 * @throws Exception thrown if testcase failed to initialize models 146 */ 147 /*@Test 148 public void testAccuracy() throws Exception { 149 compareScore(25, 200, 32, 100); 150 //compareScore(19, 500, 31, 10000); 151 }*/ 152 153 154 /** 155 * 对比两个打分模型打分的时间消耗 156 * @param featureNum 特征个数 157 * @param treeNum 树个数 158 * @param leafNum 叶子个数 159 * @param loopNum 样本个数 160 * @throws Exception 161 */ 162 private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception { 163 Random rand = new Random(0); 164 165 //随机产生features 166 List<Feature> features = createDummyFeatures(featureNum); 167 //随机产生normalizer 168 List<Normalizer> norms = createDummyNormalizer(featureNum); 169 //随机创建trees 170 List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand); 171 172 //初始化multiple additive trees model 173 MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms, 174 "dummy", features, null); 175 matModel.setTrees(trees); 176 matModel.validate(); 177 178 //初始化quick scorer trees model 179 QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features, 180 null); 181 qstModel.setTrees(trees); 182 qstModel.validate(); 183 184 //随机产生样本, loopNum * featureNum 185 float[][] featureValues = new float[loopNum][featureNum]; 186 for (int i = 0; i < loopNum; ++i) { 187 for (int k = 0; k < featureNum; ++k) 
{ 188 featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0) 189 } 190 } 191 192 long start; 193 /*long matOpNsec = 0; 194 for (int i = 0; i < loopNum; ++i) { 195 start = System.nanoTime(); 196 matModel.score(featureValues[i]); 197 matOpNsec += System.nanoTime() - start; 198 } 199 long qstOpNsec = 0; 200 for (int i = 0; i < loopNum; ++i) { 201 start = System.nanoTime(); 202 qstModel.score(featureValues[i]); 203 qstOpNsec += System.nanoTime() - start; 204 } 205 System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op"); 206 System.out.println("QuickScorerTreesModel : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/ 207 208 long matOpNsec = 0; 209 start = System.currentTimeMillis(); 210 for(int i = 0; i < loopNum; i++) { 211 matModel.score(featureValues[i]); 212 } 213 matOpNsec = System.currentTimeMillis() - start; 214 215 long qstOpNsec = 0; 216 start = System.currentTimeMillis(); 217 for(int i = 0; i < loopNum; i++) { 218 qstModel.score(featureValues[i]); 219 } 220 qstOpNsec = System.currentTimeMillis() - start; 221 222 System.out.println("MultipleAdditiveTreesModel : " + matOpNsec); 223 224 System.out.println("QuickScorerTreesModel : " + qstOpNsec); 225 226 //assertThat(matOpNsec > qstOpNsec, is(true)); 227 } 228 229 /** 230 * 测试性能 231 * @throws Exception thrown if testcase failed to initialize models 232 */ 233 234 @Test 235 public void testPerformance() throws Exception { 236 //features,trees,leafs,samples 237 compareTime(20, 500, 61, 10000); 238 } 239 240 }