K-D Tree

K-D Tree

简介

K-D Tree 是一种能够维护多维数据的二叉树。一般来说,竞赛中用到的都是 2-D Tree,即维护二维数据,比如平面上的点的数据。

K-D Tree 能解决的问题用很多其他的数据结构也能解决,比如 CDQ 分治、树套树,抑或是一些计算几何知识。因为 K-D Tree 在 \(n\gg2^k\) 有较好的时间效率,且其好写好调、支持在线、空间消耗小,所以也是解决部分问题的专有手段。

但需要注意的是,K-D Tree 的复杂度是依赖于数据的。换句话说,它在极端情况下其实能被构造数据卡掉。只是目前卡它的人不多,所以它大多数时候都被用作一种较好的骗分手段。

静态建树

这种情况建立在所有的点都预先给定,然后我们需要把它们建成一棵 K-D Tree。

我们显然希望这棵 K-D Tree 尽可能是平衡的,即树高为 \(\log n\)。又因为它要维护二维的数据,所以我们尽可能让当前维度坐标的中位数为根。具体而言,步骤为:

  1. 第一次划分,找到所有点的 \(x\) 坐标的中位数作为当前根节点,分为左右两部分;
  2. 第二次划分,将左右两部分再按照 \(y\) 坐标的中位数作为当前根节点,继续递归;
  3. 第三次划分,将每个部分再找到 \(x\) 坐标的中位数作为根节点,再递归下去。

以此类推,直到建树完成。总之,按照 \(x\)\(y\) 坐标的中值交替建树。

至于找出某一维度的中位数,这里我们合理使用 STL 中的 nth_element 函数即可。总建树复杂度为 \(O(N\log N)\)

void build(int d,int s,int t,int&p){
    if(s>t) return p=0,void();
    int mid=(s+t)>>1;
    nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
        return x.v[d]<y.v[d];
    });
    p=++tot;
    st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
    build(d^1,s,mid-1,lp);
    build(d^1,mid+1,t,rp);
    pushup(p);
}

这里的 pushup 与一般的平衡树是类似的,都是用来维护子树信息。一般来说 K-D Tree 的一个节点都需要维护当前节点的最小/最大的 \(x\)\(y\) 坐标。

动态建树

动态建树一般有根号重构和二进制分组两种方式,一般二进制分组会快一些。

我们维护若干棵大小为 \(2^i\) 的 K-D Tree,满足这些树的大小之和为 \(n\)。每插入的时候,就新增一棵大小为 \(2^0\) 的 K-D Tree,然后不断地向上合并。实际实现的时候可以先将合并在一起的树拍扁,然后只需重构一次即可。复杂度均摊 \(O(n\log^2n)\)

struct{
	#define lp st[p].lc
	#define rp st[p].rc
	struct KDT{
		int lc,rc,v[2],mx[2],mn[2];
	}st[MAXN];
	int tot,used[MAXN],pt;
	void del(int&p){
		used[++pt]=p;
		st[p]={0,0,0,0,0,0,0,0};
		p=0;
	}
	int newnode(){
		return pt?used[pt--]:++tot;
	}
	void pushup(int p){
		for(int i=0;i<2;i++){
			st[p].mx[i]=st[p].mn[i]=st[p].v[i];
			if(lp){
				st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
			}
			if(rp){
				st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
			}
		}
	}
	void build(int d,int s,int t,int&p){
		if(s>t) return;
		int mid=(s+t)>>1;
		nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
			return x.v[d]<y.v[d];
		});
		p=newnode();
		st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
		build(d^1,s,mid-1,lp);
		build(d^1,mid+1,t,rp);
		pushup(p);
	}
	void redo(int&p){
		if(!p) return;
		a[++cnt]={st[p].v[0],st[p].v[1]};
		redo(lp),redo(rp);
		del(p);
	}
	void ins(int x,int y){
		a[cnt=1]={x,y};
		for(int i=0;i<=F;i++){
			if(!rt[i]){
				build(0,1,cnt,rt[i]);
				break;
			}else redo(rt[i]);
		}
	}
    // 以上为插入部分函数
}T;

实际插入的时候调用 \(\operatorname{ins}(x,y)\) 即可。注意,二进制分组的组数(即代码中的 F)是等于 \(\log(\text{点数})\) 的。

查询操作

矩阵查询

这种查询方式一般要求查询一个矩阵范围内的点的相关信息。具体实现上,在递归途中,如果目标矩形和当前矩形无交点,则跳出;如果当前矩形被目标矩形完全包含,则直接返回当前信息;否则先判断当前矩形是否合法,然后继续递归搜索。

可以证明这样做的复杂度是 \(O(n^{1-1/k})\) 的。对于常见的 2-D Tree,那就是 \(O(n\sqrt n)\)

典型例题:P4475 巧克力王国

将每块巧克力看成在二维平面上的一个点,坐标 \((x,y)\),点权为 \(h\)。建出 K-D Tree 后,维护 K-D Tree 上每个节点的坐标最大/最小值。对于每次查询,代入系数 \(a,b\) 进行计算,对于当前节点维护的坐标极值分别算出答案判断是否小于 \(c\),然后再使用上文的判断方法继续递归即可。

实际上,这题并不是严格意义上的矩阵查询,所以它的复杂度是由随机数据保证的。

核心代码如下:

using ll=long long;
constexpr int MAXN=50005;
int n,m,rt;
ll A,B,C;
struct Node{
	int v[2],vl;
}a[MAXN];
struct{
	#define lp st[p].lc
	#define rp st[p].rc
	struct KDT{
		int lc,rc,v[2],mx[2],mn[2];
		ll sm,vl;
	}st[MAXN];
	int tot;
	void pushup(int p){
		st[p].sm=st[lp].sm+st[rp].sm+st[p].vl;
		for(int i=0;i<2;i++){
			st[p].mx[i]=st[p].mn[i]=st[p].v[i];
			if(lp){
				st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
			}
			if(rp){
				st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
			}
		}
	}
	void build(int d,int s,int t,int&p){
		if(s>t) return p=0,void();
		int mid=(s+t)>>1;
		nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
			return x.v[d]<y.v[d];
		});
		p=++tot;
		st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1],st[p].vl=a[mid].vl;
		build(d^1,s,mid-1,lp);
		build(d^1,mid+1,t,rp);
		pushup(p);
	}
	int chk(int p){
		int res=0;
		if(st[p].mn[0]*A+st[p].mn[1]*B<C) res++;
		if(st[p].mn[0]*A+st[p].mx[1]*B<C) res++;
		if(st[p].mx[0]*A+st[p].mn[1]*B<C) res++;
		if(st[p].mx[0]*A+st[p].mx[1]*B<C) res++;
		return res;
	}
	ll query(int p){
		switch(chk(p)){
			case 0: return 0;
			case 4: return st[p].sm;
			default:{
				ll res=0;
				if(st[p].v[0]*A+st[p].v[1]*B<C) res+=st[p].vl;
				if(lp) res+=query(lp);
				if(rp) res+=query(rp);
				return res;
			}
		}
	}
}T;

int main(){
	n=read(),m=read();
	for(int i=1;i<=n;i++) a[i]={read(),read(),read()};
	T.build(0,1,n,rt);
	while(m--){
		A=read(),B=read(),C=read();
		write(T.query(rt));
	}
	return fw,0;
}

邻域查询

这种查询方式一般要求查询距离一个点最近/最远的点。我们一般需要贪心地为一棵子树设计一个估价函数,优先往估价函数优的一方去搜索。再加上最优性剪枝即可。

K-D Tree 在邻域查询上的时间复杂度依旧是由随机数据保证的。

典型例题:P2479 [SDOI2010] 捉迷藏

查询最近点的板子题。代码如下:

constexpr int MAXN=1e5+5,INF=0x3f3f3f3f;
int n,rt;
struct Node{
	int v[2];
}a[MAXN];
struct{
	#define lp st[p].lc
	#define rp st[p].rc
	struct KDT{
		int lc,rc,v[2],mx[2],mn[2];
	}st[MAXN];
	int tot;
	void pushup(int p){
		for(int i=0;i<2;i++){
			st[p].mx[i]=st[p].mn[i]=st[p].v[i];
			if(lp){
				st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
			}
			if(rp){
				st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
				st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
			}
		}
	}
	void build(int d,int s,int t,int&p){
		if(s>t) return p=0,void();
		int mid=(s+t)>>1;
		nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
			return x.v[d]<y.v[d];
		});
		p=++tot;
		st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
		build(d^1,s,mid-1,lp);
		build(d^1,mid+1,t,rp);
		pushup(p);
	}
	int dis(int x1,int y1,int x2,int y2){
		return abs(x1-x2)+abs(y1-y2);
	}
	int fmin(int x,int y,int p){
		int res=0;
		if(x<st[p].mn[0]) res+=st[p].mn[0]-x;
		if(x>st[p].mx[0]) res+=x-st[p].mx[0];
		if(y<st[p].mn[1]) res+=st[p].mn[1]-y;
		if(y>st[p].mx[1]) res+=y-st[p].mx[1];
		return res;
	}
	int fmax(int x,int y,int p){
		int res=0;
		res+=max(abs(x-st[p].mn[0]),abs(x-st[p].mx[0]));
		res+=max(abs(y-st[p].mn[1]),abs(y-st[p].mx[1]));
		return res;
	}
	void qmin(int&mn,int x,int y,int p){
		if(!p) return;
		if(x!=st[p].v[0]||y!=st[p].v[1]) mn=min(mn,dis(x,y,st[p].v[0],st[p].v[1]));
		int vl=INF,vr=INF;
		if(lp) vl=fmin(x,y,lp);
		if(rp) vr=fmin(x,y,rp);
		if(vl<vr){
			if(vl<mn) qmin(mn,x,y,lp);
			if(vr<mn) qmin(mn,x,y,rp);
		}else{
			if(vr<mn) qmin(mn,x,y,rp);
			if(vl<mn) qmin(mn,x,y,lp);
		}
	}
	void qmax(int&mx,int x,int y,int p){
		if(!p) return;
		if(x!=st[p].v[0]||y!=st[p].v[1]) mx=max(mx,dis(x,y,st[p].v[0],st[p].v[1]));
		int vl=-INF,vr=-INF;
		if(lp) vl=fmax(x,y,lp);
		if(rp) vr=fmax(x,y,rp);
		if(vl>vr){
			if(vl>mx) qmax(mx,x,y,lp);
			if(vr>mx) qmax(mx,x,y,rp);
		}else{
			if(vr>mx) qmax(mx,x,y,rp);
			if(vl>mx) qmax(mx,x,y,lp);
		}
	}
}T;

int main(){
	n=read();
	for(int i=1;i<=n;i++) a[i]={read(),read()};
	T.build(0,1,n,rt);
	int ans=INF;
	for(int i=1,mx,mn;i<=n;i++){
		mx=-INF,mn=INF;
		T.qmax(mx,a[i].v[0],a[i].v[1],rt);
		T.qmin(mn,a[i].v[0],a[i].v[1],rt);
		ans=min(ans,mx-mn);
	}
	printf("%d\n",ans);
	return 0;
}
posted @ 2025-05-03 16:49  Laoshan_PLUS  阅读(42)  评论(0)    收藏  举报