hud 4347 The Closest M Points(KD-Tree)

传送门

解题思路

  \(KD-Tree\)模板题,\(KD-Tree\)解决的是多维问题,它是一个可以储存\(K\)维数据的二叉树,每一层都被一维所分割。它的插入删除复杂度为\(log^2 n\),它查询最近点对的复杂度为\(O(n^{\frac{k-1}{k}}\)\(k\)代表维数。用堆维护最近点,查询时就先找到它属于的区域,然后回溯时判断一下它到父节点的距离和堆顶的大小,如果比堆顶还大就不递归它的兄弟节点。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<queue>
#include<algorithm>

using namespace std;
const int N=50005;

inline int rd(){
	int x=0,f=1; char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?0:1,ch=getchar();
	while(isdigit(ch)) x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
	return f?x:-x;
}

inline int pw(int x){
	return x*x;
}

int n,q,K,t;

struct Node{
	int a[7];
	void init() {
		memset(a,0,sizeof(a));
	}
	friend bool operator<(const Node A,const Node B){
		return A.a[t]<B.a[t];
	}
}node[N],pt[N<<2],ans[25];
priority_queue<pair<double,Node> > Q;

struct KD_Tree{
	#define mid ((l+r)>>1)
	int end[N<<2];
	void build(int x,int l,int r,int dep){
		if(l>r) return; t=dep%K;
		end[x]=0; end[x<<1]=end[x<<1|1]=1;
		nth_element(node+l,node+mid,node+r+1);
		pt[x]=node[mid];
		build(x<<1,l,mid-1,dep+1); build(x<<1|1,mid+1,r,dep+1);
	}
	void query(int x,int dep,int lim,Node now){
		if(end[x]) return;
		pair<double,Node> tmp=make_pair(0,pt[x]);
		for(int i=0;i<K;i++) tmp.first+=pw(pt[x].a[i]-now.a[i]);
		int ls=x<<1,rs=x<<1|1,t=dep%K,flag=0;
		if(now.a[t]>=pt[x].a[t]) swap(ls,rs);
		if(!end[ls]) query(ls,dep+1,lim,now);
		if(Q.size()<lim) Q.push(tmp),flag=1;
		else {
			if(Q.top().first>tmp.first) Q.pop(),Q.push(tmp);
			if(pw(pt[x].a[t]-now.a[t])<Q.top().first) flag=1;
		}
		if(!end[rs] && flag) query(rs,dep+1,lim,now);
	}
	#undef mid
}tree;

int main(){
	while(~scanf("%d%d",&n,&K)){
		for(int i=1;i<=n;i++)
			for(int j=0;j<K;j++) node[i].a[j]=rd();
		tree.build(1,1,n,0);
		for(q=rd();q;q--){
			Node now; now.init();
			for(int i=0;i<K;i++) now.a[i]=rd();
			int t=rd(); tree.query(1,0,t,now); 
			for(int i=1;!Q.empty();i++)
				ans[i]=Q.top().second,Q.pop();
			printf("the closest %d points are:\n",t);
			for(int i=t;i;i--){
				printf("%d",ans[i].a[0]);
				for(int j=1;j<K;j++)
					printf(" %d",ans[i].a[j]);
				putchar('\n');
			}
		}
	}	
	return 0;
}
posted @ 2019-03-07 17:54  Monster_Qi  阅读(144)  评论(0编辑  收藏  举报