K-D Tree
K-D Tree
简介
K-D Tree 是一种能够维护多维数据的二叉树。一般来说,竞赛中用到的都是 2-D Tree,即维护二维数据,比如平面上的点的数据。
K-D Tree 能解决的问题用很多其他的数据结构也能解决,比如 CDQ 分治、树套树,抑或是一些计算几何知识。因为 K-D Tree 在 \(n\gg2^k\) 有较好的时间效率,且其好写好调、支持在线、空间消耗小,所以也是解决部分问题的专有手段。
但需要注意的是,K-D Tree 的复杂度是依赖于数据的。换句话说,它在极端情况下其实能被构造数据卡掉。只是目前卡它的人不多,所以它大多数时候都被用作一种较好的骗分手段。
静态建树
这种情况建立在所有的点都预先给定,然后我们需要把它们建成一棵 K-D Tree。
我们显然希望这棵 K-D Tree 尽可能是平衡的,即树高为 \(\log n\)。又因为它要维护二维的数据,所以我们尽可能让当前维度坐标的中位数为根。具体而言,步骤为:
- 第一次划分,找到所有点的 \(x\) 坐标的中位数作为当前根节点,分为左右两部分;
- 第二次划分,将左右两部分再按照 \(y\) 坐标的中位数作为当前根节点,继续递归;
- 第三次划分,将每个部分再找到 \(x\) 坐标的中位数作为根节点,再递归下去。
以此类推,直到建树完成。总之,按照 \(x\) 和 \(y\) 坐标的中值交替建树。
至于找出某一维度的中位数,这里我们合理使用 STL 中的 nth_element 函数即可。总建树复杂度为 \(O(N\log N)\)。
void build(int d,int s,int t,int&p){
if(s>t) return p=0,void();
int mid=(s+t)>>1;
nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
return x.v[d]<y.v[d];
});
p=++tot;
st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
build(d^1,s,mid-1,lp);
build(d^1,mid+1,t,rp);
pushup(p);
}
这里的 pushup 与一般的平衡树是类似的,都是用来维护子树信息。一般来说 K-D Tree 的一个节点都需要维护当前节点的最小/最大的 \(x\) 和 \(y\) 坐标。
动态建树
动态建树一般有根号重构和二进制分组两种方式,一般二进制分组会快一些。
我们维护若干棵大小为 \(2^i\) 的 K-D Tree,满足这些树的大小之和为 \(n\)。每插入的时候,就新增一棵大小为 \(2^0\) 的 K-D Tree,然后不断地向上合并。实际实现的时候可以先将合并在一起的树拍扁,然后只需重构一次即可。复杂度均摊 \(O(n\log^2n)\)。
struct{
#define lp st[p].lc
#define rp st[p].rc
struct KDT{
int lc,rc,v[2],mx[2],mn[2];
}st[MAXN];
int tot,used[MAXN],pt;
void del(int&p){
used[++pt]=p;
st[p]={0,0,0,0,0,0,0,0};
p=0;
}
int newnode(){
return pt?used[pt--]:++tot;
}
void pushup(int p){
for(int i=0;i<2;i++){
st[p].mx[i]=st[p].mn[i]=st[p].v[i];
if(lp){
st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
}
if(rp){
st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
}
}
}
void build(int d,int s,int t,int&p){
if(s>t) return;
int mid=(s+t)>>1;
nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
return x.v[d]<y.v[d];
});
p=newnode();
st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
build(d^1,s,mid-1,lp);
build(d^1,mid+1,t,rp);
pushup(p);
}
void redo(int&p){
if(!p) return;
a[++cnt]={st[p].v[0],st[p].v[1]};
redo(lp),redo(rp);
del(p);
}
void ins(int x,int y){
a[cnt=1]={x,y};
for(int i=0;i<=F;i++){
if(!rt[i]){
build(0,1,cnt,rt[i]);
break;
}else redo(rt[i]);
}
}
// 以上为插入部分函数
}T;
实际插入的时候调用 \(\operatorname{ins}(x,y)\) 即可。注意,二进制分组的组数(即代码中的 F)是等于 \(\log(\text{点数})\) 的。
查询操作
矩阵查询
这种查询方式一般要求查询一个矩阵范围内的点的相关信息。具体实现上,在递归途中,如果目标矩形和当前矩形无交点,则跳出;如果当前矩形被目标矩形完全包含,则直接返回当前信息;否则先判断当前矩形是否合法,然后继续递归搜索。
可以证明这样做的复杂度是 \(O(n^{1-1/k})\) 的。对于常见的 2-D Tree,那就是 \(O(n\sqrt n)\)。
典型例题:P4475 巧克力王国。
将每块巧克力看成在二维平面上的一个点,坐标 \((x,y)\),点权为 \(h\)。建出 K-D Tree 后,维护 K-D Tree 上每个节点的坐标最大/最小值。对于每次查询,代入系数 \(a,b\) 进行计算,对于当前节点维护的坐标极值分别算出答案判断是否小于 \(c\),然后再使用上文的判断方法继续递归即可。
实际上,这题并不是严格意义上的矩阵查询,所以它的复杂度是由随机数据保证的。
核心代码如下:
using ll=long long;
constexpr int MAXN=50005;
int n,m,rt;
ll A,B,C;
struct Node{
int v[2],vl;
}a[MAXN];
struct{
#define lp st[p].lc
#define rp st[p].rc
struct KDT{
int lc,rc,v[2],mx[2],mn[2];
ll sm,vl;
}st[MAXN];
int tot;
void pushup(int p){
st[p].sm=st[lp].sm+st[rp].sm+st[p].vl;
for(int i=0;i<2;i++){
st[p].mx[i]=st[p].mn[i]=st[p].v[i];
if(lp){
st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
}
if(rp){
st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
}
}
}
void build(int d,int s,int t,int&p){
if(s>t) return p=0,void();
int mid=(s+t)>>1;
nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
return x.v[d]<y.v[d];
});
p=++tot;
st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1],st[p].vl=a[mid].vl;
build(d^1,s,mid-1,lp);
build(d^1,mid+1,t,rp);
pushup(p);
}
int chk(int p){
int res=0;
if(st[p].mn[0]*A+st[p].mn[1]*B<C) res++;
if(st[p].mn[0]*A+st[p].mx[1]*B<C) res++;
if(st[p].mx[0]*A+st[p].mn[1]*B<C) res++;
if(st[p].mx[0]*A+st[p].mx[1]*B<C) res++;
return res;
}
ll query(int p){
switch(chk(p)){
case 0: return 0;
case 4: return st[p].sm;
default:{
ll res=0;
if(st[p].v[0]*A+st[p].v[1]*B<C) res+=st[p].vl;
if(lp) res+=query(lp);
if(rp) res+=query(rp);
return res;
}
}
}
}T;
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]={read(),read(),read()};
T.build(0,1,n,rt);
while(m--){
A=read(),B=read(),C=read();
write(T.query(rt));
}
return fw,0;
}
邻域查询
这种查询方式一般要求查询距离一个点最近/最远的点。我们一般需要贪心地为一棵子树设计一个估价函数,优先往估价函数优的一方去搜索。再加上最优性剪枝即可。
K-D Tree 在邻域查询上的时间复杂度依旧是由随机数据保证的。
典型例题:P2479 [SDOI2010] 捉迷藏。
查询最近点的板子题。代码如下:
constexpr int MAXN=1e5+5,INF=0x3f3f3f3f;
int n,rt;
struct Node{
int v[2];
}a[MAXN];
struct{
#define lp st[p].lc
#define rp st[p].rc
struct KDT{
int lc,rc,v[2],mx[2],mn[2];
}st[MAXN];
int tot;
void pushup(int p){
for(int i=0;i<2;i++){
st[p].mx[i]=st[p].mn[i]=st[p].v[i];
if(lp){
st[p].mx[i]=max(st[p].mx[i],st[lp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[lp].mn[i]);
}
if(rp){
st[p].mx[i]=max(st[p].mx[i],st[rp].mx[i]);
st[p].mn[i]=min(st[p].mn[i],st[rp].mn[i]);
}
}
}
void build(int d,int s,int t,int&p){
if(s>t) return p=0,void();
int mid=(s+t)>>1;
nth_element(a+s,a+mid,a+t+1,[d](const Node&x,const Node&y){
return x.v[d]<y.v[d];
});
p=++tot;
st[p].v[0]=a[mid].v[0],st[p].v[1]=a[mid].v[1];
build(d^1,s,mid-1,lp);
build(d^1,mid+1,t,rp);
pushup(p);
}
int dis(int x1,int y1,int x2,int y2){
return abs(x1-x2)+abs(y1-y2);
}
int fmin(int x,int y,int p){
int res=0;
if(x<st[p].mn[0]) res+=st[p].mn[0]-x;
if(x>st[p].mx[0]) res+=x-st[p].mx[0];
if(y<st[p].mn[1]) res+=st[p].mn[1]-y;
if(y>st[p].mx[1]) res+=y-st[p].mx[1];
return res;
}
int fmax(int x,int y,int p){
int res=0;
res+=max(abs(x-st[p].mn[0]),abs(x-st[p].mx[0]));
res+=max(abs(y-st[p].mn[1]),abs(y-st[p].mx[1]));
return res;
}
void qmin(int&mn,int x,int y,int p){
if(!p) return;
if(x!=st[p].v[0]||y!=st[p].v[1]) mn=min(mn,dis(x,y,st[p].v[0],st[p].v[1]));
int vl=INF,vr=INF;
if(lp) vl=fmin(x,y,lp);
if(rp) vr=fmin(x,y,rp);
if(vl<vr){
if(vl<mn) qmin(mn,x,y,lp);
if(vr<mn) qmin(mn,x,y,rp);
}else{
if(vr<mn) qmin(mn,x,y,rp);
if(vl<mn) qmin(mn,x,y,lp);
}
}
void qmax(int&mx,int x,int y,int p){
if(!p) return;
if(x!=st[p].v[0]||y!=st[p].v[1]) mx=max(mx,dis(x,y,st[p].v[0],st[p].v[1]));
int vl=-INF,vr=-INF;
if(lp) vl=fmax(x,y,lp);
if(rp) vr=fmax(x,y,rp);
if(vl>vr){
if(vl>mx) qmax(mx,x,y,lp);
if(vr>mx) qmax(mx,x,y,rp);
}else{
if(vr>mx) qmax(mx,x,y,rp);
if(vl>mx) qmax(mx,x,y,lp);
}
}
}T;
int main(){
n=read();
for(int i=1;i<=n;i++) a[i]={read(),read()};
T.build(0,1,n,rt);
int ans=INF;
for(int i=1,mx,mn;i<=n;i++){
mx=-INF,mn=INF;
T.qmax(mx,a[i].v[0],a[i].v[1],rt);
T.qmin(mn,a[i].v[0],a[i].v[1],rt);
ans=min(ans,mx-mn);
}
printf("%d\n",ans);
return 0;
}

浙公网安备 33010602011771号