kd树 C++实现

参考:百度百科 kd-tree(k-dimensional tree)词条

  1 /*
  2  * kdtree.h
  3  *
  4  *  Created on: Mar 3, 2017
  5  *      Author: wxquare
  6  */
  7 
  8 #ifndef KDTREE_H_
  9 #define KDTREE_H_
 10 
 11 #include <vector>
 12 #include <cmath>
 13 #include <algorithm>
 14 #include <iostream>
 15 #include <stack>
 16 
 17 template<typename T>
 18 class KdTree {
 19     struct kdNode {
 20         std::vector<T> vec;  //data
 21         //split attribute,-1 means leftNode,no split attribute
 22         int splitAttribute;
 23         kdNode* lChild;
 24         kdNode* rChild;
 25         kdNode* parent;
 26 
 27         kdNode(std::vector<T> v = { }, int split = 0, kdNode* lch = nullptr,
 28                 kdNode* rch = nullptr, kdNode* par = nullptr) :
 29                 vec(v), splitAttribute(split), lChild(lch), rChild(rch), parent(par) {}
 30     };
 31 
 32 private:
 33     kdNode *root;
 34 
 35 public:
 36     KdTree() {
 37         root = nullptr;
 38     }
 39 
 40     KdTree(std::vector<std::vector<T>>& data) {
 41         root = createKdTree(data);
 42     }
 43 
 44 
 45     //matrix transpose
 46     std::vector<std::vector<T>> transpose(std::vector<std::vector<T>>& data) {
 47         int m = data.size();
 48         int n = data[0].size();
 49         std::vector<std::vector<T>> trans(n, std::vector<T>(m, 0));
 50         for (int i = 0; i < n; i++) {
 51             for (int j = 0; j < m; j++) {
 52                 trans[i][j] = data[j][i];
 53             }
 54         }
 55         return trans;
 56     }
 57 
 58     //get variance of a vector
 59     double getVariance(std::vector<T>& vec) {
 60         int n = vec.size();
 61         double sum = 0;
 62         for (int i = 0; i < n; i++) {
 63             sum = sum + vec[i];
 64         }
 65         double avg = sum / n;
 66         sum = 0; //sum of squaNN
 67         for (int i = 0; i < n; i++) {
 68             sum += pow(vec[i] - avg, 2); //#include<cmath>
 69         }
 70         return sum / n;
 71     }
 72 
 73     //According to maximum variance get split attribute.
 74     int getSplitAttribute(const std::vector<std::vector<T>>& data) {
 75         int k = data.size();
 76         int splitAttribute = 0;
 77         double maxVar = getVariance(data[0]);
 78         for (int i = 1; i < k; i++) {
 79             double temp = getVariance(data[i]);
 80             if (temp > maxVar) {
 81                 splitAttribute = i;
 82                 maxVar = temp;
 83             }
 84         }
 85         return splitAttribute;
 86     }
 87 
 88     //find middle value
 89     T getSplitValue(std::vector<T>& vec) {
 90         std::sort(vec.begin(), vec.end());
 91         return vec[vec.size() / 2];
 92     }
 93 
 94     //compute distance of two vector
 95     static double getDistance(std::vector<T>& v1, std::vector<T>& v2) {
 96         double sum = 0;
 97         for (size_t i = 0; i < v1.size(); i++) {
 98             sum += pow(v1[i] - v2[i], 2);
 99         }
100         return sqrt(sum) / v1.size();
101     }
102 
103     kdNode* createKdTree(std::vector<std::vector<T>>& data) {
104         //the number of samples(data)
105         if (data.empty()) return nullptr;
106         int n = data.size();
107         if (n == 1) {
108             return new kdNode(data[0], -1); //叶子节点
109         }
110 
111         //get split attribute and value
112         std::vector<std::vector<T>> data_T = transpose(data);
113         int splitAttribute = getSplitAttribute(data_T);
114         int splitValue = getSplitValue(data_T[splitAttribute]);
115 
116         //split data according splitAttribute and splitValue
117         std::vector<std::vector<T>> left;
118         std::vector<std::vector<T>> right;
119 
120         int flag = 0; //the first sample's splitValue become splitnode
121         kdNode *splitNode;
122         for (int i = 0; i < n; i++) {
123             if (flag == 0 && data[i][splitAttribute] == splitValue) {
124                 splitNode = new kdNode(data[i]);
125                 splitNode->splitAttribute = splitAttribute;
126                 flag = 1;
127                 continue;
128             }
129             if (data[i][splitAttribute] <= splitValue) {
130                 left.push_back(data[i]);
131             } else {
132                 right.push_back(data[i]);
133             }
134         }
135 
136         splitNode->lChild = createKdTree(left);
137         splitNode->rChild = createKdTree(right);
138         return splitNode;
139     }
140 
141     //search nearest neighbor
142     /* 参考百度百科
143      * 从root节点开始,DFS搜索直到叶子节点,同时在stack中顺序存储已经访问的节点。
144        如果搜索到叶子节点,当前的叶子节点被设为最近邻节点。
145        然后通过stack回溯:
146        如果当前点的距离比最近邻点距离近,更新最近邻节点.
147        然后检查以最近距离为半径的圆是否和父节点的超平面相交.
148        如果相交,则必须到父节点的另外一侧,用同样的DFS搜索法,开始检查最近邻节点。
149        如果不相交,则继续往上回溯,而父节点的另一侧子节点都被淘汰,不再考虑的范围中.
150        当搜索回到root节点时,搜索完成,得到最近邻节点。
151      */
152     std::vector<T> searchNearestNeighbor(std::vector<T>& target,kdNode* start) {
153         std::vector<T> NN;
154         std::stack<kdNode*> searchPath;
155         kdNode* p = start;
156         while (p->splitAttribute != -1) {
157             searchPath.push(p);
158             int splitAttribute = p->splitAttribute;
159             if (target[splitAttribute] <= p->vec[splitAttribute]) {
160                 p = p->lChild;
161             } else {
162                 p = p->rChild;
163             }
164         }
165         NN = p->vec;
166         double mindis = KdTree::getDistance(target, NN);
167 
168         kdNode* cur;
169         double dis;
170         while (!searchPath.empty()) {
171             cur = searchPath.top();
172             searchPath.pop();
173             dis = KdTree::getDistance(target, cur->vec);
174             if (dis < mindis) {
175                 mindis = dis;
176                 NN = cur->vec;
177                 //判断以target为中心,以dis为半径的球是否和节点的超平面相交
178                 if (cur->vec[cur->splitAttribute]
179                         >= target[cur->splitAttribute] - dis
180                         && cur->vec[cur->splitAttribute]
181                                 <= target[cur->splitAttribute] + dis) {
182                     std::vector<T> nn = searchNearestNeighbor(target,
183                             cur->lChild);
184                     if (KdTree::getDistance(target, nn)
185                             < KdTree::getDistance(target, NN)) {
186                         NN = nn;
187                     }
188                 }
189             }
190         }
191         return NN;
192     }
193 
194     std::vector<T> searchNearestNeighbor(std::vector<T>& target) {
195         std::vector<T> NN;
196         NN = searchNearestNeighbor(target, root);
197         return NN;
198     }
199 
200     void print(kdNode* root) {
201         std::cout << "[";
202         if (root->lChild) {
203             std::cout << "left:";
204             print(root->lChild);
205         }
206 
207         if (root) {
208             std::cout << "(";
209             for (size_t i = 0; i < root->vec.size(); i++) {
210                 std::cout << root->vec[i];
211                 if (i != (root->vec.size() - 1))
212                     std::cout << ",";
213             }
214             std::cout << ")";
215         }
216 
217         if (root->rChild) {
218             std::cout << "right:";
219             print(root->rChild);
220         }
221         std::cout << "]";
222     }
223 
224 };
225 
226 #endif /* KDTREE_H_ */

 

posted @ 2017-03-03 16:26  wxquare  阅读(1057)  评论(0编辑  收藏  举报