多维空间分割树--KD树

from http://blog.csdn.net/androidlushangderen/article/details/44985259

算法介绍

KD树的全称为k-Dimension Tree的简称,是一种分割K维空间的数据结构,主要应用于关键信息的搜索。为什么说是K维的呢,因为这时候的空间不仅仅是2维度的,他可能是3维,4维度的或者是更多。我们举个例子,如果是二维的空间,对于其中的空间进行分割的就是一条条的分割线,比如说下面这个样子。


如果是3维的呢,那么分割的媒介就是一个平面了,下面是3维空间的分割


这就稍稍有点抽象了,如果是3维以上,我们把这样的分割媒介可以统统叫做超平面 。那么KD树算法有什么特别之处呢,还有他与K-NN算法之间又有什么关系呢,这将是下面所将要描述的。

KNN

KNN就是K最近邻算法,他是一个分类算法,因为算法简单,分类效果也还不错,也被许多人使用着,算法的原理就是选出与给定数据最近的k个数据,然后根据k个数据中占比最多的分类作为测试数据的最终分类。图示如下:


算法固然简单,但是其中通过逐个去比较的办法求得最近的k个数据点,效率太低,时间复杂度会随着训练数据数量的增多而线性增长。于是就需要一种更加高效快速的办法来找到所给查询点的最近邻,而KD树就是其中的一种行之有效的办法。但是不管是KNN算法还是KD树算法,他们都属于相似性查询中的K近邻查询的范畴。在相似性查询算法中还有一类查询是范围查询,就是给定距离阈值和查询点,dbscan算法可以说是一种范围查询,基于给定点进行局部密度范围的搜索。想要了解KNN算法或者是Dbscan算法的可以点击我的K-最近邻算法Dbscan基于密度的聚类算法

KD-Tree

在KNN算法中,针对查询点数据的查找采用的是线性扫描的方法,说白了就是暴力比较,KD树在这方面用了二分划分的思想,将数据进行逐层空间上的划分,大大的提高了查询的速度,可以理解为一个变形的二分搜索时间,只不过这个适用到了多维空间的层次上。下面是二维空间的情况下,数据的划分结果:


现在看到的图在逻辑上的意思就是一棵完整的二叉树,虚线上的点是叶子节点。

KD树的算法原理

KD树的算法的实现原理并不是那么好理解,主要分为树的构建和基于KD树进行最近邻的查询2个过程,后者比前者更加复杂。当然,要想实现最近点的查询,首先我们得先理解KD树的构建过程。下面是KD树节点的定义,摘自百度百科:

 

域名
数据类型
描述
Node-data
数据矢量
数据集中某个数据点,是n维矢量(这里也就是k维)
Range
空间矢量
该节点所代表的空间范围
split
整数
垂直于分割超平面的方向轴序号
Left
k-d树
由位于该节点分割超平面左子空间内所有数据点所构成的k-d树
Right
k-d树
由位于该节点分割超平面右子空间内所有数据点所构成的k-d树
parent
k-d树
父节点

 

变量还是有点多的,节点中有孩子节点和父亲节点,所以必然会用到递归。KD树的构建算法过程如下(这里假设构建的是2维KD树,简单易懂,后续同上):

1、首先将数据节点坐标中的X坐标和Y坐标进行方差计算,选出其中方差大的,作为分割线的方向,就是接下来将要创建点的split值。

2、将上面的数据点按照分割方向的维度进行排序,选出其中的中位数的点作为数据矢量,就是要分割的分割点。

3、同时进行空间矢量的再次划分,要在父亲节点的空间范围内再进行子分割,就是Range变量,不理解的话,可以阅读我的代码加以理解。

4、对剩余的节点进行左侧空间和右侧空间的分割,进行左孩子和右孩子节点的分割。

5、分割的终点是最终只剩下1个数据点或一侧没有数据点的情况。

在这里举个例子,给定6个数据点:

(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)

对这6个数据点进行最终的KD树的构建效果图如下,左边是实际分割效果,右边是所构成的KD树:

       

x,y代表的是当前节点的分割方向。读者可以进行手动计算并验证,本人不再加以描述。

KD树构建完毕,之后就是对于给定查询点数据,进行此空间数据的最近数据点,大致过程如下:

1、从根节点开始,从上往下,根据分割方向,在对应维度的坐标点上,进行树的顺序查找,比如给定(3,1),首先来到(7,2),因为根节点的划分方向为X,因此只比较X坐标的划分,因为3<7,所以往左边走,后续的节点同样的道理,最终到达叶子节点为止。

2、当然以这种方式找到的点并不一定是最近的,也许在父节点的另外一个空间内存在更近的点呢,或者说另外一种情况,当前的叶子节点的父亲节点比叶子节点离查询点更近呢,这也是有可能的。

3、所以这个过程会有回溯的步骤,回溯到父节点时候,需要做2点,第一要和父节点比,谁里查询点更近,如果父节点更近,则更改当前找到的最近点,第二以查询点为圆心,当前查询点与最近点的距离为半径画个圆,判断是否与父节点的分割线是否相交,如果相交,则说明有存在父节点另外的孩子空间存在于查询距离更短的点,然后进行父节点空间的又一次深度优先遍历。在局部的遍历查找完毕,在于当前的最近点做比较,比较完之后,继续往上回溯。

下面给出基于上面例子的2个测试例子,查询点为(2.1,3.1)和(2,4.5),前者的例子用于理解一般过程,后面的测试点真正诠释了递归,回溯的过程。先看下(2.1,3.1)的情况:


因为没有碰到任何的父节点分割边界,所以就一直回溯到根节点,最近的节点就是叶子节点(2,3).下面(2,4.5)是需要重点理解的例子,中间出现了一次回溯,和一次再搜索:


在第一次回溯的时候,发现与y=4碰撞到了,进行了又一次的搜寻,结果发现存在更近的点,因此结果变化了,具体的过程可以详细查看百度百科-kd树对这个例子的描述。

算法的代码实现

许多资料都是只有理论,没有实践,本人基于上面的测试例子,自己写了一个,效果还行,基本上实现了上述的过程,不过貌似Range这个变量没有表现出用途来,可以我一番设计,例子完全是上面的例子,输入数据就不放出来了,就是给定的6个坐标点。

坐标点类Point.java:

 

[java] view plain copy
 
  1. package DataMining_KDTree;  
  2.   
  3. /** 
  4.  * 坐标点类 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Point{  
  10.     // 坐标点横坐标  
  11.     Double x;  
  12.     // 坐标点纵坐标  
  13.     Double y;  
  14.   
  15.     public Point(double x, double y){  
  16.         this.x = x;  
  17.         this.y = y;  
  18.     }  
  19.       
  20.     public Point(String x, String y) {  
  21.         this.x = (Double.parseDouble(x));  
  22.         this.y = (Double.parseDouble(y));  
  23.     }  
  24.   
  25.     /** 
  26.      * 计算当前点与制定点之间的欧式距离 
  27.      *  
  28.      * @param p 
  29.      *            待计算聚类的p点 
  30.      * @return 
  31.      */  
  32.     public double ouDistance(Point p) {  
  33.         double distance = 0;  
  34.   
  35.         distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)  
  36.                 * (this.y - p.y);  
  37.         distance = Math.sqrt(distance);  
  38.   
  39.         return distance;  
  40.     }  
  41.   
  42.     /** 
  43.      * 判断2个坐标点是否为用个坐标点 
  44.      *  
  45.      * @param p 
  46.      *            待比较坐标点 
  47.      * @return 
  48.      */  
  49.     public boolean isTheSame(Point p) {  
  50.         boolean isSamed = false;  
  51.   
  52.         if (this.x == p.x && this.y == p.y) {  
  53.             isSamed = true;  
  54.         }  
  55.   
  56.         return isSamed;  
  57.     }  
  58. }  

空间矢量类Range.java:

 

 

[java] view plain copy
 
  1. package DataMining_KDTree;  
  2.   
  3. /** 
  4.  * 空间矢量,表示所代表的空间范围 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Range {  
  10.     // 边界左边界  
  11.     double left;  
  12.     // 边界右边界  
  13.     double right;  
  14.     // 边界上边界  
  15.     double top;  
  16.     // 边界下边界  
  17.     double bottom;  
  18.   
  19.     public Range() {  
  20.         this.left = -Integer.MAX_VALUE;  
  21.         this.right = Integer.MAX_VALUE;  
  22.         this.top = Integer.MAX_VALUE;  
  23.         this.bottom = -Integer.MAX_VALUE;  
  24.     }  
  25.   
  26.     public Range(int left, int right, int top, int bottom) {  
  27.         this.left = left;  
  28.         this.right = right;  
  29.         this.top = top;  
  30.         this.bottom = bottom;  
  31.     }  
  32.   
  33.     /** 
  34.      * 空间矢量进行并操作 
  35.      *  
  36.      * @param range 
  37.      * @return 
  38.      */  
  39.     public Range crossOperation(Range r) {  
  40.         Range range = new Range();  
  41.   
  42.         // 取靠近右侧的左边界  
  43.         if (r.left > this.left) {  
  44.             range.left = r.left;  
  45.         } else {  
  46.             range.left = this.left;  
  47.         }  
  48.   
  49.         // 取靠近左侧的右边界  
  50.         if (r.right < this.right) {  
  51.             range.right = r.right;  
  52.         } else {  
  53.             range.right = this.right;  
  54.         }  
  55.   
  56.         // 取靠近下侧的上边界  
  57.         if (r.top < this.top) {  
  58.             range.top = r.top;  
  59.         } else {  
  60.             range.top = this.top;  
  61.         }  
  62.   
  63.         // 取靠近上侧的下边界  
  64.         if (r.bottom > this.bottom) {  
  65.             range.bottom = r.bottom;  
  66.         } else {  
  67.             range.bottom = this.bottom;  
  68.         }  
  69.   
  70.         return range;  
  71.     }  
  72.   
  73.     /** 
  74.      * 根据坐标点分割方向确定左侧空间矢量 
  75.      *  
  76.      * @param p 
  77.      *            数据矢量 
  78.      * @param dir 
  79.      *            分割方向 
  80.      * @return 
  81.      */  
  82.     public static Range initLeftRange(Point p, int dir) {  
  83.         Range range = new Range();  
  84.   
  85.         if (dir == KDTreeTool.DIRECTION_X) {  
  86.             range.right = p.x;  
  87.         } else {  
  88.             range.bottom = p.y;  
  89.         }  
  90.   
  91.         return range;  
  92.     }  
  93.   
  94.     /** 
  95.      * 根据坐标点分割方向确定右侧空间矢量 
  96.      *  
  97.      * @param p 
  98.      *            数据矢量 
  99.      * @param dir 
  100.      *            分割方向 
  101.      * @return 
  102.      */  
  103.     public static Range initRightRange(Point p, int dir) {  
  104.         Range range = new Range();  
  105.   
  106.         if (dir == KDTreeTool.DIRECTION_X) {  
  107.             range.left = p.x;  
  108.         } else {  
  109.             range.top = p.y;  
  110.         }  
  111.   
  112.         return range;  
  113.     }  
  114. }  

KD树节点类TreeNode.java:

 

 

[java] view plain copy
 
  1. package DataMining_KDTree;  
  2.   
  3. /** 
  4.  * KD树节点 
  5.  * @author lyq 
  6.  * 
  7.  */  
  8. public class TreeNode {  
  9.     //数据矢量  
  10.     Point nodeData;  
  11.     //分割平面的分割线  
  12.     int spilt;  
  13.     //空间矢量,该节点所表示的空间范围  
  14.     Range range;  
  15.     //父节点  
  16.     TreeNode parentNode;  
  17.     //位于分割超平面左侧的孩子节点  
  18.     TreeNode leftNode;  
  19.     //位于分割超平面右侧的孩子节点  
  20.     TreeNode rightNode;  
  21.     //节点是否被访问过,用于回溯时使用  
  22.     boolean isVisited;  
  23.       
  24.     public TreeNode(){  
  25.         this.isVisited = false;  
  26.     }  
  27. }  

算法封装类KDTreeTool.java:

 

 

[java] view plain copy
 
  1. package DataMining_KDTree;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.Collections;  
  9. import java.util.Comparator;  
  10. import java.util.Stack;  
  11.   
  12. /** 
  13.  * KD树-k维空间关键数据检索算法工具类 
  14.  *  
  15.  * @author lyq 
  16.  *  
  17.  */  
  18. public class KDTreeTool {  
  19.     // 空间平面的方向  
  20.     public static final int DIRECTION_X = 0;  
  21.     public static final int DIRECTION_Y = 1;  
  22.   
  23.     // 输入的测试数据坐标点文件  
  24.     private String filePath;  
  25.     // 原始所有数据点数据  
  26.     private ArrayList<Point> totalDatas;  
  27.     // KD树根节点  
  28.     private TreeNode rootNode;  
  29.   
  30.     public KDTreeTool(String filePath) {  
  31.         this.filePath = filePath;  
  32.   
  33.         readDataFile();  
  34.     }  
  35.   
  36.     /** 
  37.      * 从文件中读取数据 
  38.      */  
  39.     private void readDataFile() {  
  40.         File file = new File(filePath);  
  41.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  42.   
  43.         try {  
  44.             BufferedReader in = new BufferedReader(new FileReader(file));  
  45.             String str;  
  46.             String[] tempArray;  
  47.             while ((str = in.readLine()) != null) {  
  48.                 tempArray = str.split(" ");  
  49.                 dataArray.add(tempArray);  
  50.             }  
  51.             in.close();  
  52.         } catch (IOException e) {  
  53.             e.getStackTrace();  
  54.         }  
  55.   
  56.         Point p;  
  57.         totalDatas = new ArrayList<>();  
  58.         for (String[] array : dataArray) {  
  59.             p = new Point(array[0], array[1]);  
  60.             totalDatas.add(p);  
  61.         }  
  62.     }  
  63.   
  64.     /** 
  65.      * 创建KD树 
  66.      *  
  67.      * @return 
  68.      */  
  69.     public TreeNode createKDTree() {  
  70.         ArrayList<Point> copyDatas;  
  71.   
  72.         rootNode = new TreeNode();  
  73.         // 根据节点开始时所表示的空间时无限大的  
  74.         rootNode.range = new Range();  
  75.         copyDatas = (ArrayList<Point>) totalDatas.clone();  
  76.         recusiveConstructNode(rootNode, copyDatas);  
  77.   
  78.         return rootNode;  
  79.     }  
  80.   
  81.     /** 
  82.      * 递归进行KD树的构造 
  83.      *  
  84.      * @param node 
  85.      *            当前正在构造的节点 
  86.      * @param datas 
  87.      *            该节点对应的正在处理的数据 
  88.      * @return 
  89.      */  
  90.     private void recusiveConstructNode(TreeNode node, ArrayList<Point> datas) {  
  91.         int direction = 0;  
  92.         ArrayList<Point> leftSideDatas;  
  93.         ArrayList<Point> rightSideDatas;  
  94.         Point p;  
  95.         TreeNode leftNode;  
  96.         TreeNode rightNode;  
  97.         Range range;  
  98.         Range range2;  
  99.   
  100.         // 如果划分的数据点集合只有1个数据,则不再划分  
  101.         if (datas.size() == 1) {  
  102.             node.nodeData = datas.get(0);  
  103.             return;  
  104.         }  
  105.   
  106.         // 首先在当前的数据点集合中进行分割方向的选择  
  107.         direction = selectSplitDrc(datas);  
  108.         // 根据方向取出中位数点作为数据矢量  
  109.         p = getMiddlePoint(datas, direction);  
  110.   
  111.         node.spilt = direction;  
  112.         node.nodeData = p;  
  113.   
  114.         leftSideDatas = getLeftSideDatas(datas, p, direction);  
  115.         datas.removeAll(leftSideDatas);  
  116.         // 还要去掉自身  
  117.         datas.remove(p);  
  118.         rightSideDatas = datas;  
  119.   
  120.         if (leftSideDatas.size() > 0) {  
  121.             leftNode = new TreeNode();  
  122.             leftNode.parentNode = node;  
  123.             range2 = Range.initLeftRange(p, direction);  
  124.             // 获取父节点的空间矢量,进行交集运算做范围拆分  
  125.             range = node.range.crossOperation(range2);  
  126.             leftNode.range = range;  
  127.   
  128.             node.leftNode = leftNode;  
  129.             recusiveConstructNode(leftNode, leftSideDatas);  
  130.         }  
  131.   
  132.         if (rightSideDatas.size() > 0) {  
  133.             rightNode = new TreeNode();  
  134.             rightNode.parentNode = node;  
  135.             range2 = Range.initRightRange(p, direction);  
  136.             // 获取父节点的空间矢量,进行交集运算做范围拆分  
  137.             range = node.range.crossOperation(range2);  
  138.             rightNode.range = range;  
  139.   
  140.             node.rightNode = rightNode;  
  141.             recusiveConstructNode(rightNode, rightSideDatas);  
  142.         }  
  143.     }  
  144.   
  145.     /** 
  146.      * 搜索出给定数据点的最近点 
  147.      *  
  148.      * @param p 
  149.      *            待比较坐标点 
  150.      */  
  151.     public Point searchNearestData(Point p) {  
  152.         // 节点距离给定数据点的距离  
  153.         TreeNode nearestNode = null;  
  154.         // 用栈记录遍历过的节点  
  155.         Stack<TreeNode> stackNodes;  
  156.   
  157.         stackNodes = new Stack<>();  
  158.         findedNearestLeafNode(p, rootNode, stackNodes);  
  159.   
  160.         // 取出叶子节点,作为当前找到的最近节点  
  161.         nearestNode = stackNodes.pop();  
  162.         nearestNode = dfsSearchNodes(stackNodes, p, nearestNode);  
  163.   
  164.         return nearestNode.nodeData;  
  165.     }  
  166.   
  167.     /** 
  168.      * 深度优先的方式进行最近点的查找 
  169.      *  
  170.      * @param stack 
  171.      *            KD树节点栈 
  172.      * @param desPoint 
  173.      *            给定的数据点 
  174.      * @param nearestNode 
  175.      *            当前找到的最近节点 
  176.      * @return 
  177.      */  
  178.     private TreeNode dfsSearchNodes(Stack<TreeNode> stack, Point desPoint,  
  179.             TreeNode nearestNode) {  
  180.         // 是否碰到父节点边界  
  181.         boolean isCollision;  
  182.         double minDis;  
  183.         double dis;  
  184.         TreeNode parentNode;  
  185.   
  186.         // 如果栈内节点已经全部弹出,则遍历结束  
  187.         if (stack.isEmpty()) {  
  188.             return nearestNode;  
  189.         }  
  190.   
  191.         // 获取父节点  
  192.         parentNode = stack.pop();  
  193.   
  194.         minDis = desPoint.ouDistance(nearestNode.nodeData);  
  195.         dis = desPoint.ouDistance(parentNode.nodeData);  
  196.   
  197.         // 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新  
  198.         if (dis < minDis) {  
  199.             minDis = dis;  
  200.             nearestNode = parentNode;  
  201.         }  
  202.   
  203.         // 默认没有碰撞到  
  204.         isCollision = false;  
  205.         // 判断是否触碰到了父节点的空间分割线  
  206.         if (parentNode.spilt == DIRECTION_X) {  
  207.             if (parentNode.nodeData.x > desPoint.x - minDis  
  208.                     && parentNode.nodeData.x < desPoint.x + minDis) {  
  209.                 isCollision = true;  
  210.             }  
  211.         } else {  
  212.             if (parentNode.nodeData.y > desPoint.y - minDis  
  213.                     && parentNode.nodeData.y < desPoint.y + minDis) {  
  214.                 isCollision = true;  
  215.             }  
  216.         }  
  217.   
  218.         // 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历  
  219.         if (isCollision  
  220.                 && (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) {  
  221.             TreeNode newNode;  
  222.             // 新建当前的小局部节点栈  
  223.             Stack<TreeNode> otherStack = new Stack<>();  
  224.             // 从parentNode的树以下继续寻找  
  225.             findedNearestLeafNode(desPoint, parentNode, otherStack);  
  226.             newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop());  
  227.   
  228.             dis = newNode.nodeData.ouDistance(desPoint);  
  229.             if (dis < minDis) {  
  230.                 nearestNode = newNode;  
  231.             }  
  232.         }  
  233.   
  234.         // 继续往上回溯  
  235.         nearestNode = dfsSearchNodes(stack, desPoint, nearestNode);  
  236.   
  237.         return nearestNode;  
  238.     }  
  239.   
  240.     /** 
  241.      * 找到与所给定节点的最近的叶子节点 
  242.      *  
  243.      * @param p 
  244.      *            待比较节点 
  245.      * @param node 
  246.      *            当前搜索到的节点 
  247.      * @param stack 
  248.      *            遍历过的节点栈 
  249.      */  
  250.     private void findedNearestLeafNode(Point p, TreeNode node,  
  251.             Stack<TreeNode> stack) {  
  252.         // 分割方向  
  253.         int splitDic;  
  254.   
  255.         // 将遍历过的节点加入栈中  
  256.         stack.push(node);  
  257.         // 标记为访问过  
  258.         node.isVisited = true;  
  259.         // 如果此节点没有左右孩子节点说明已经是叶子节点了  
  260.         if (node.leftNode == null && node.rightNode == null) {  
  261.             return;  
  262.         }  
  263.   
  264.         splitDic = node.spilt;  
  265.         // 选择一个符合分割范围的节点继续递归搜寻  
  266.         if ((splitDic == DIRECTION_X && p.x < node.nodeData.x)  
  267.                 || (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) {  
  268.             if (!node.leftNode.isVisited) {  
  269.                 findedNearestLeafNode(p, node.leftNode, stack);  
  270.             } else {  
  271.                 // 如果左孩子节点已经访问过,则访问另一边  
  272.                 findedNearestLeafNode(p, node.rightNode, stack);  
  273.             }  
  274.         } else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x)  
  275.                 || (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) {  
  276.             if (!node.rightNode.isVisited) {  
  277.                 findedNearestLeafNode(p, node.rightNode, stack);  
  278.             } else {  
  279.                 // 如果右孩子节点已经访问过,则访问另一边  
  280.                 findedNearestLeafNode(p, node.leftNode, stack);  
  281.             }  
  282.         }  
  283.     }  
  284.   
  285.     /** 
  286.      * 根据给定的数据点通过计算反差选择的分割点 
  287.      *  
  288.      * @param datas 
  289.      *            部分的集合点集合 
  290.      * @return 
  291.      */  
  292.     private int selectSplitDrc(ArrayList<Point> datas) {  
  293.         int direction = 0;  
  294.         double avgX = 0;  
  295.         double avgY = 0;  
  296.         double varianceX = 0;  
  297.         double varianceY = 0;  
  298.   
  299.         for (Point p : datas) {  
  300.             avgX += p.x;  
  301.             avgY += p.y;  
  302.         }  
  303.   
  304.         avgX /= datas.size();  
  305.         avgY /= datas.size();  
  306.   
  307.         for (Point p : datas) {  
  308.             varianceX += (p.x - avgX) * (p.x - avgX);  
  309.             varianceY += (p.y - avgY) * (p.y - avgY);  
  310.         }  
  311.   
  312.         // 求最后的方差  
  313.         varianceX /= datas.size();  
  314.         varianceY /= datas.size();  
  315.   
  316.         // 通过比较方差的大小决定分割方向,选择波动较大的进行划分  
  317.         direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y;  
  318.   
  319.         return direction;  
  320.     }  
  321.   
  322.     /** 
  323.      * 根据坐标点方位进行排序,选出中间点的坐标数据 
  324.      *  
  325.      * @param datas 
  326.      *            数据点集合 
  327.      * @param dir 
  328.      *            排序的坐标方向 
  329.      */  
  330.     private Point getMiddlePoint(ArrayList<Point> datas, int dir) {  
  331.         int index = 0;  
  332.         Point middlePoint;  
  333.   
  334.         index = datas.size() / 2;  
  335.         if (dir == DIRECTION_X) {  
  336.             Collections.sort(datas, new Comparator<Point>() {  
  337.   
  338.                 @Override  
  339.                 public int compare(Point o1, Point o2) {  
  340.                     // TODO Auto-generated method stub  
  341.                     return o1.x.compareTo(o2.x);  
  342.                 }  
  343.             });  
  344.         } else {  
  345.             Collections.sort(datas, new Comparator<Point>() {  
  346.   
  347.                 @Override  
  348.                 public int compare(Point o1, Point o2) {  
  349.                     // TODO Auto-generated method stub  
  350.                     return o1.y.compareTo(o2.y);  
  351.                 }  
  352.             });  
  353.         }  
  354.   
  355.         // 取出中位数  
  356.         middlePoint = datas.get(index);  
  357.   
  358.         return middlePoint;  
  359.     }  
  360.   
  361.     /** 
  362.      * 根据方向得到原部分节点集合左侧的数据点 
  363.      *  
  364.      * @param datas 
  365.      *            原始数据点集合 
  366.      * @param nodeData 
  367.      *            数据矢量 
  368.      * @param dir 
  369.      *            分割方向 
  370.      * @return 
  371.      */  
  372.     private ArrayList<Point> getLeftSideDatas(ArrayList<Point> datas,  
  373.             Point nodeData, int dir) {  
  374.         ArrayList<Point> leftSideDatas = new ArrayList<>();  
  375.   
  376.         for (Point p : datas) {  
  377.             if (dir == DIRECTION_X && p.x < nodeData.x) {  
  378.                 leftSideDatas.add(p);  
  379.             } else if (dir == DIRECTION_Y && p.y < nodeData.y) {  
  380.                 leftSideDatas.add(p);  
  381.             }  
  382.         }  
  383.   
  384.         return leftSideDatas;  
  385.     }  
  386. }  

场景测试类Client.java:

[java] view plain copy
 
  1. package DataMining_KDTree;  
  2.   
  3. import java.text.MessageFormat;  
  4.   
  5. /** 
  6.  * KD树算法测试类 
  7.  *  
  8.  * @author lyq 
  9.  *  
  10.  */  
  11. public class Client {  
  12.     public static void main(String[] args) {  
  13.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";  
  14.         Point queryNode;  
  15.         Point searchedNode;  
  16.         KDTreeTool tool = new KDTreeTool(filePath);  
  17.   
  18.         // 进行KD树的构建  
  19.         tool.createKDTree();  
  20.   
  21.         // 通过KD树进行数据点的最近点查询  
  22.         queryNode = new Point(2.1, 3.1);  
  23.         searchedNode = tool.searchNearestData(queryNode);  
  24.         System.out.println(MessageFormat.format(  
  25.                 "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,  
  26.                 searchedNode.x, searchedNode.y));  
  27.           
  28.         //重新构造KD树,去除之前的访问记录  
  29.         tool.createKDTree();  
  30.         queryNode = new Point(2, 4.5);  
  31.         searchedNode = tool.searchNearestData(queryNode);  
  32.         System.out.println(MessageFormat.format(  
  33.                 "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,  
  34.                 searchedNode.x, searchedNode.y));  
  35.     }  
  36. }  

算法的输出结果:

 
  1. 距离查询点(2.1, 3.1)最近的坐标点为(2, 3)  
  2. 距离查询点(2, 4.5)最近的坐标点为(2, 3)  

算法的输出结果与期望值还是一致的。

目前KD-Tree的使用场景是SIFT算法做特征点匹配的时候使用到了,特征点匹配指的是通过距离函数在高维矢量空间进行相似性检索。

posted @ 2017-09-22 22:41  princessd8251  阅读(663)  评论(1)    收藏  举报