【BZOJ3053】The Closest M Points-KD树

测试地址:The Closest M Points
题目大意:kk维空间里有nn个点,qq次询问,每次询问这nn个点中与某个点的欧式距离前mm小的是哪些。
做法: 本题需要用到KD树。
KD树的一道比较复杂的模板题,复杂度我也不懂,好像是随机情况下O(nlogn)O(n\log n)?总之能过。于是KD树就照常建,在询问的时候,答案用一个堆来做会快一点。里面会涉及到求一个点到一个超立方体的最小距离,直接从代数形式看,求出每一维的最小距离再相加即可。这样我们就完成了这一题。
以下是本人代码

#include <bits/stdc++.h>
using namespace std;
int n,k,q,m,now[5],rt,d;
struct point
{
	int l,r;
	int x[5],mn[5],mx[5],dis;
	bool operator < (point a) const
	{
		return dis<a.dis;
	}
}p[50010],inf,anss[20];
priority_queue <point> ans;

bool cmp(point a,point b)
{
	return a.x[d]<b.x[d];
}

void pushup(int v)
{
	for(int i=0;i<k;i++)
	{
		if (p[v].l)
		{
			p[v].mn[i]=min(p[v].mn[i],p[p[v].l].mn[i]);
			p[v].mx[i]=max(p[v].mx[i],p[p[v].l].mx[i]);
		}
		if (p[v].r)
		{
			p[v].mn[i]=min(p[v].mn[i],p[p[v].r].mn[i]);
			p[v].mx[i]=max(p[v].mx[i],p[p[v].r].mx[i]);
		}
	}
}

int build(int l,int r,int now)
{
	if (l==r) return l;
	if (l>r) return 0;
	int mid=(l+r)>>1;
	d=now%k;
	nth_element(p+l,p+mid,p+r+1,cmp);
	p[mid].l=build(l,mid-1,now+1);
	p[mid].r=build(mid+1,r,now+1);
	pushup(mid);
	return mid;
}

int getdis(int v)
{
	int ans=0;
	for(int i=0;i<k;i++)
	{
		int s1=p[v].mn[i]-now[i];
		int s2=p[v].mx[i]-now[i];
		if (!((s1<0&&s2>0)||(s1>0&&s2<0)))
			ans+=min(s1*s1,s2*s2);
	}
	return ans;
}

void solve(int v)
{
	p[v].dis=0;
	for(int i=0;i<k;i++)
		p[v].dis+=(now[i]-p[v].x[i])*(now[i]-p[v].x[i]);
	if (ans.size()<m)
		ans.push(p[v]);
	else if (p[v].dis<ans.top().dis)
	{
		ans.pop();
		ans.push(p[v]);
	}
	
	int lp=p[v].l,rp=p[v].r;
	int dl=getdis(lp),dr=getdis(rp);
	if (dl>dr) swap(lp,rp),swap(dl,dr);
	if (lp&&dl<ans.top().dis) solve(lp);
	if (rp&&dr<ans.top().dis) solve(rp);
}

int main()
{
	inf.dis=2147483647;
	
	while(scanf("%d%d",&n,&k)!=EOF)
	{
		for(int i=1;i<=n;i++)
		{
			p[i].l=p[i].r=0;
			for(int j=0;j<k;j++)
			{
				scanf("%d",&p[i].x[j]);		
				p[i].mn[j]=p[i].mx[j]=p[i].x[j];
			}
		}
		rt=build(1,n,0);
		
		scanf("%d",&q);
		for(int i=1;i<=q;i++)
		{
			for(int j=0;j<k;j++)
				scanf("%d",&now[j]);
			scanf("%d",&m);
			while(!ans.empty()) ans.pop();
			ans.push(inf);
			solve(rt);
			for(int i=m;i;i--)
			{
				anss[i]=ans.top();
				ans.pop();
			}
			printf("the closest %d points are:\n",m);
			for(int i=1;i<=m;i++)
			{
				for(int j=0;j<k;j++)
					printf("%d ",anss[i].x[j]);
				printf("\n");
			}
		}
	}
	
	return 0;
}
posted @ 2018-09-18 13:40  Maxwei_wzj  阅读(111)  评论(0编辑  收藏  举报