【模板】【计几】KD树

KD树的作用就是:给你一个点集,然后对这个点集建立一颗KD树,然后可以在logn — 根号n,的范围内查询距离 一个给定点 最近的、第k近的,前k近的点。

建立KD树不仅仅只是两维,它可以多维。假设我们建树到第dep层,那么我们就以第dep % D (D是总维数)为基准,像普通的二叉树那样,选一个分裂点 m,然后 [ l , m-1] 的点的 第dep%D维都是小于 m 这个点,[ m + 1,r ]大于这个点,(用到了nth_element 这个函数)然后往下建树。(其实有另一种方法是按照方差最大的维度为基准,但是代码量++,而且跑的可能还没有这种交替维度来的快。。。。)

然后查询给定点就是:假如到达第dep层,这里的中间节点是 m ,然后我们先对 m 和给定点的距离更新一下答案,然后假如 第 dep % D 维 中,给定点 < m ,就查左子树,反之查右子树。然后假如查了一次子树之后,假设 给定点 第 dep % D 维 和 m 点 差是deta , 那么 如果ans  <= deta * deta ,(这里ans记录距离的平方),那么就不用查另一颗子树了,反之要查。这个和分治法求平面最近点对是一样的,因为另一颗子数的距离肯定比 deta * deta 来得大。

所以,KD树其实就是个暴力,给定点先往和和它接近的平面走,更新完答案再按需求查另一个子树。

而对于第k大,只需要在之前的操作中加一个优先队列,每一次更新答案变成往优先队列里面插入,维护优先队列的大小就可。

 

 

例题:hdu4347(第k近):http://acm.hdu.edu.cn/showproblem.php?pid=4347

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int M = 7;
 5 const int N = 5e4 + 9;
 6 int n,D,cmp_d,K;
 7 struct Point{
 8     ll x[M],dis;
 9     int son[2];
10     void print(){
11         for(int i = 0;i<D;++i) printf("%lld%c",x[i],i == D-1 ? '\n' : ' ');
12     } 
13     bool operator < (const Point& b)const{
14         return dis < b.dis;
15     }
16 }tr[N],ans[N],Q;
17 priority_queue<Point> pq;
18 bool cmp(Point a,Point b){
19     return a.x[cmp_d] < b.x[cmp_d];
20 }
21 ll distance(Point a,Point b){
22     ll res = 0;
23     for(int i = 0; i < D;++i) res += (a.x[i] - b.x[i]) * (a.x[i] - b.x[i]);
24     return res;
25 }
26 int build(int l,int r,int now_d){
27     if( l > r ) return 0;
28     int m = (l+r)>>1;
29     cmp_d = now_d;
30     nth_element(tr+l,tr+m,tr+r+1,cmp);
31     tr[m].son[0] = build(l,m-1,(now_d+1)%D );
32     tr[m].son[1] = build(m+1,r,(now_d+1)%D );
33     return m;
34 }
35 void query(int pos,int now_d){
36     if( !pos ) return;
37     ll cur_dis = distance( tr[pos],Q );
38     tr[pos].dis = cur_dis;
39     ll deta = tr[pos].x[now_d] - Q.x[now_d];
40     int which = (deta < 0);
41     query(tr[pos].son[which],(now_d + 1)%D );
42 
43     if( pq.size() < K ) pq.push(tr[pos]);
44     else{
45         if(cur_dis < pq.top().dis ){
46             pq.pop();
47             pq.push(tr[pos]);
48         }
49     }
50     if( pq.size() < K || pq.top().dis > deta*deta ) query(tr[pos].son[which^1],(now_d+1)%D);
51 }
52 int main(){
53     while(~scanf("%d%d",&n,&D)){
54         for(int i = 1;i<=n;++i){
55             tr[i].son[0] = tr[i].son[1] = 0;
56             for(int j = 0;j<D;++j) scanf("%lld",&tr[i].x[j]);
57         }
58         int root = build(1,n,0);
59         int m; scanf("%d",&m);
60         while(m--){
61             for(int j = 0;j<D;++j) scanf("%lld",&Q.x[j]);
62             scanf("%d",&K);
63             query(root,0);
64             printf("the closest %d points are:\n",K);
65             for(int i = K;i>=1;--i) ans[i] = pq.top(),pq.pop();
66             for(int i = 1;i<=K;++i) ans[i].print();
67         }
68     }
69 
70 }
View Code

 

例题:hdu2966(最近):http://acm.hdu.edu.cn/showproblem.php?pid=2966

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const ll inf = 1000000000000000000;
 5 const int M = 7;
 6 const int N = 1e5 + 9;
 7 int n,D,cmp_d;
 8 ll ans;
 9 struct Point{
10     ll x[M];
11     int son[2];
12     int id;
13     void print(){
14         for(int i = 0;i<D;++i) printf("%lld%c",x[i],i == D-1 ? '\n' : ' ');
15     }
16 }tr[N],Q,query_p[N];
17 bool cmp(Point a,Point b){
18     return a.x[cmp_d] < b.x[cmp_d];
19 }
20 ll distance(Point a,Point b){
21     if( a.id == b.id ) return inf;
22     ll res = 0;
23     for(int i = 0; i < D;++i) res += (a.x[i] - b.x[i]) * (a.x[i] - b.x[i]);
24     return res;
25 }
26 int build(int l,int r,int now_d){
27     if( l > r ) return 0;
28     int m = (l+r)>>1;
29     cmp_d = now_d;
30     nth_element(tr+l,tr+m,tr+r+1,cmp);
31     tr[m].son[0] = build(l,m-1,(now_d+1)%D );
32     tr[m].son[1] = build(m+1,r,(now_d+1)%D );
33     return m;
34 }
35 void query(int pos,int now_d){
36     if( !pos ) return;
37     ll cur_dis = distance( tr[pos],Q );
38     ans = min(cur_dis,ans);
39     ll deta = tr[pos].x[now_d] - Q.x[now_d];
40     int which = (deta < 0);
41     query(tr[pos].son[which],(now_d + 1)%D );
42     if(ans > deta*deta ) query(tr[pos].son[which^1],(now_d+1)%D);
43 }
44 int main(){
45     int T; scanf("%d",&T);
46     D = 2;
47     while(T--){
48         int n; scanf("%d",&n);
49         for(int i = 1;i<=n;++i){
50             tr[i].son[0] = tr[i].son[1] = 0;
51             tr[i].id = i;
52             for(int j = 0;j<D;++j) scanf("%lld",&tr[i].x[j]);
53             query_p[i] = tr[i];
54         }
55         int root = build(1,n,0);
56         for(int i = 1;i<=n;++i){
57             Q = query_p[i];
58             ans = 1000000000000000000;
59             query(root,0);
60             printf("%lld\n",ans);
61         }
62     }
63 }
View Code

 

posted @ 2020-02-05 21:32  小布鞋  阅读(350)  评论(0编辑  收藏  举报