[hdu4347]The Closest M Points(线段树形式kd-tree)

解题关键:kdtree模板题,距离某点最近的m个点。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<queue>
#define sq(x) (x)*(x)
using namespace std;
typedef long long ll;
const int N=55555;

int idx,k,n,m,q;
struct node{
    int x[5];
    bool operator<(const node &u)const{
        return x[idx]<u.x[idx];
    }
}P[N];
//线段树形式kd-tree
typedef pair<double,node>PDN;
priority_queue<PDN>que;
struct KD_Tree{
    int sz[N<<2];
    node p[N<<2];
    void build(int rt,int l,int r,int dep){
        if(l>r) return;
        int mid=(l+r)>>1;
        idx=dep%k;sz[rt]=r-l;
        sz[rt<<1]=sz[rt<<1|1]=-1;
        nth_element(P+l,P+mid,P+r+1);
        p[rt]=P[mid];
        build(rt<<1,l,mid-1,dep+1);
        build(rt<<1|1,mid+1,r,dep+1);
    }

    void query(int i,int m,int dep,node a){//寻找m个。
        if(sz[i]==-1) return;
        PDN tmp=PDN(0,p[i]);//新建的tmp,first更新为0了
        for(int j=0;j<k;j++)
            tmp.first+=sq(tmp.second.x[j]-a.x[j]);//距离目标点的距离
        int lc=i<<1,rc=i<<1|1,dim=dep%k,flag=0;
        if(a.x[dim]>=p[i].x[dim]) swap(lc,rc);
        if(~sz[lc]) query(lc,m,dep+1,a);//~sz,sz!=-1
        if(que.size()<m) que.push(tmp),flag=1;
        else{
            if(tmp.first<que.top().first) que.pop(),que.push(tmp);
            if(sq(a.x[dim]-p[i].x[dim])<que.top().first) flag=1;//求到面的距离,空间相交
        }
        if(~sz[rc]&&flag) query(rc,m,dep+1,a);
    }
}KDT;
 
int main(){
    while(~scanf("%d%d",&n,&k)){
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                scanf("%d",&P[i].x[j]);
        KDT.build(1,0,n-1,0);
        scanf("%d",&q);
        while(q--){
            node now;
            for(int i=0;i<k;i++)
                scanf("%d",&now.x[i]);
            scanf("%d",&m);
            KDT.query(1,m,0,now);
            node pp[21];
            int t=0;
            while(!que.empty()){
                pp[++t]=que.top().second;
                que.pop();
            }
            printf("the closest %d points are:\n",m);
            for(int i=m;i>0;i--) for(int j=0;j<k;j++) printf("%d%c",pp[i].x[j],j==k-1?'\n':' ');
        }
    }
    return 0;
}

 

posted @ 2019-03-07 16:06  Elpsywk  阅读(270)  评论(0编辑  收藏  举报