peiwenjun's blog 没有知识的荒原

K-D Tree 学习笔记

一、简介

此处强烈推荐 oi-wiki

\(\texttt{K-D Tree}\) ,全称 \(\texttt{K-Dimension Tree}\) ,是维护 \(k\) 维空间信息的数据结构。

二、静态建树

操作步骤:每次选择一个维度,取该维度的中位数作为根节点,左右递归建树。

两种常见切割方式:

  • 旋转切割:第 \(i\) 次选择第 \(i\bmod k\) 个维度进行切割。
  • 方差切割:选择方差最大的一维进行切割。

由于可以构造数据使方差切割全沿着一个维度切割,容易导致复杂度退化,而且旋转切割更好写,所以下面只介绍旋转切割。


对于求中位数的操作,贴心的 \(\texttt{STL}\) 已经帮我们封装好了。调用 nth_element(a+l,a+x,a+r+1,cmp) 可以按 cmp 函数将 \(a[l\sim r]\) 中排名为 \(x\) 的元素归位,保证 \(a[x]\) 左边更小,右边更大。

底层原理是每次快速排序只处理一边,则有 \(T(n)=T(\frac n2)+\mathcal O(n)\) ,解得 \(T(n)=\mathcal O(n)\)

对于 \(\texttt{K-D Tree}\) 的建树部分, \(T(n)=2\cdot T(\frac n2)+O(n)\) ,解得 \(T(n)=O(n\log n)\)

struct poi
{
    int v[3];///长为 k 的数组存储坐标
    int val;///每个点的其他信息(比如权值)
}a[maxn];
struct node
{
    int ls,rs,sum;///左右儿子、其他信息(比如权值和)
    poi q;
    /*int mn[3],mx[3];*/ ///矩形查询需要维护每维坐标最小最大值,后面会讲
}f[maxn];
int build(int l,int r,int d=0)
{///对下标a[l],...,a[r]建立KDT,当前按照第d维切割,返回根节点
    if(l>r) return 0;
    int p=++tot,mid=(l+r+1)>>1;///+1使得KDT形态左偏,如果不加是右偏
    nth_element(a+l,a+mid,a+r+1,[&](const poi &x,const poi &y){return x.v[d]<y.v[d];});
    f[p].q=a[mid];
    f[p].ls=build(l,mid-1,(d+1)%k),f[p].rs=build(mid+1,r,(d+1)%k);
    return pushup(p),p;
}

温馨提示:

  • 如果 \(k=2\) ,笔者建议用 pair 或两个 int 变量存储坐标。
  • 如果 \(k\ge 3\) ,为方便查询边界坐标传参,也可以用 array 存储坐标。

三、矩形查询

在建树时对每个节点 \(p\) 维护子树中每一维度坐标的最小、最大值。

  • 若当前节点与询问矩形无交,直接返回。
  • 若当前节点子树完全被询问矩形包含,统计信息后返回。
  • 否则判断根节点是否在询问矩形内,并递归左右子树。

时间复杂度分析:

计算递归节点数本质上是计算有多少个矩形被询问矩形 "切开" ,可以对询问矩形的每条边分别分析。

连续 \(k\) 轮切割会将整个超矩形分成 \(2^k\) 个部分,其中一条直线至多穿过其中 \(2^{k-1}\) 个部分。

因此 \(T(n)=2^{k-1}T(n/2^k)+\mathcal O(1)\) ,解得 \(T(n)=\mathcal O(n^{1-1/k})\)


\(\texttt{K-D Tree}\) 建树常数很小,但矩形查询常数非常大,主要瓶颈在于递归本身和三类判断自带 \(3k\) 倍常数。

假设节点数与查询数同阶,则 \(k=2\) 时时间复杂度 \(\mathcal O(n\sqrt n)\)\(k=3\) 时时间复杂度 \(\mathcal O(n^{5/3})\) ,当 \(k\) 更大时实际运行效率可能比不过 \(\mathcal O(n)\) 循环,因此目前的竞赛题大多满足 \(k\in\{2,3\}\)

为了让读者有更直观的感受,这里和分块做个对比:当 \(n=10^5\) 时,\(n\sqrt n\approx 3.1\cdot 10^7\) 。一般分块题只需要跑 \(\mathcal O(n\sqrt n)\) 长度的循环,但 \(\texttt{K-D Tree}\) 需要做 \(\mathcal O(n\sqrt n)\) 次递归,更别提 \(4\) 条边自带 \(4\) 倍常数,递归内部还有很多复杂操作。

对于 \(\texttt{K-D Tree}\) 模板题,笔者的代码实现在 \(k=2,m=1.5\cdot 10^5\)\(k=3,m=10^5\) 时单个测试点用时大约在 \(2\texttt{s}\sim 2.1\texttt{s}\) ,数据仅供参考。

以查询矩形内节点数量为例:

#define arr array<int,3>/// 为方便传参,这里用 array 存储节点信息
int query(int p,const arr l,const arr r)
{
    if(!p) return 0;
    bool f1=1,f2=1,f3=1;/// f1 包含, f2 无交, f3 根节点在矩形内
    for(int i=0;i<k;i++)
    {
        f1&=l[i]<=f[p].mn[i]&&f[p].mx[i]<=r[i];
        f2&=l[i]<=f[p].mx[i]&&f[p].mn[i]<=r[i];
        f3&=l[i]<=f[p].q.v[i]&&f[p].q.v[i]<=r[i];
    }
    if(f1) return f[p].sz;
    if(!f2) return 0;
    return f3+query(f[p].ls,l,r)+query(f[p].rs,l,r);
}

四、动态插入

替罪羊树(不推荐)

同二叉平衡树维持平衡的方法,设定平衡因子 \(\alpha\in [0,7,0.8]\) ,如果左右某棵子树大小超过了总大小乘以 \(\alpha\) ,则将整棵树重构。

int top,st[maxn];
int newnode(const poi &q)
{///节点回收
    int p=top?st[top--]:++tot;
    return f[p].q={/*由点 q 构成的子树信息*/},q;
}
/// build 函数中通过 newnode 函数新建节点
bool bad(int p)
{
    return max(f[f[p].ls].sz,f[f[p].rs].sz)>0.7*f[p].sz;
}
void dfs(int p)
{
    if(!p) return ;
    a[++cnt]=f[p].q,st[++top]=p,dfs(f[p].ls),dfs(f[p].rs);
}
void insert(int &p,int d,const poi &q)
{
    if(!p) return f[p=++tot].q=q,pushup(p),p;
    insert(q.v[d]<f[p].q.v[d]?f[p].ls:f[p].rs,(d+1)%k,q);
    if(bad(p)) cnt=0,dfs(p),p=build(1,cnt,d);
}

可以证明插入时间复杂度 \(\mathcal O(n\log n)\) ,但由于替罪羊树树高 \(\mathcal O(\log n)\) 并非严格 \(\log n\) ,所以查询时间复杂度无法保证。

根号分治(推荐)

\(B\) 次插入后重构一次,时间复杂度 \(\mathcal O(\frac nB\cdot n\log n+n(B+\sqrt n))\) ,取 \(B=\sqrt{n\log n}\) ,则时间复杂度为 \(\mathcal O(n\sqrt{n\log n})\) 。由于查询部分常数较小,所以可以适当将 \(B\) 调大。

void insert(const poi &q)
{
    a[++cnt]=q;
    if(cnt%B==0) build(1,lst=cnt,tot=0);
}
int ask(int l,int r,int d,int u)
{
    int res=query(1,l,r,d,u);
    for(int i=lst+1;i<=cnt;i++) res+=/* a[i] 贡献 */;
    return res;
}

二进制分组(推荐)

维护若干棵大小为 \(2\) 的幂次的 \(\texttt{K-D Tree}\) ,插入新节点时,假如存在大小为 \(2^0,\cdots,2^{t-1}\)\(\texttt{K-D Tree}\) ,且不存在大小为 \(2^t\)\(\texttt{K-D Tree}\) ,则将新节点连同这些 \(\texttt{K-D Tree}\) 中的节点(共 \(2^t\) 个)构建成一棵新的 \(\texttt{K-D Tree}\)

vector<pii> T;///pair<根节点,子树大小>
int top,st[maxn];
int newnode(const poi &q)
{///节点回收
    int p=top?st[top--]:++tot;
    return f[p].q={/*由点 q 构成的子树信息*/},q;
}
/// build 函数中通过 newnode 函数新建节点
void clean(int p)
{
    if(!p) return ;
    st[++top]=p,a[++cnt]=f[p].q,clean(f[p].ls),clean(f[p].rs);
}
void insert(const poi &q)
{
    a[cnt=1]=q;
    while(!T.empty()&&T.back().se==cnt) clean(T.back().fi),T.pop_back();
    T.push_back(mp(build(1,n,0),cnt));
}
int ask(int l,int r,int d,int u)
{
    int res=0;
    for(auto p:T) res+=query(p.fi,l,r,d,u);
    return res;
}

由于大小为 \(2^t\) 的树只需构建 \(\frac{n}{2^t}\) 次,因此建树时间复杂度 \(\mathcal O(\sum_{t=1}^{\log n}\frac n{2^t}\cdot 2^t\log 2^t)=\mathcal O(n\log^2n)\)

矩形查询合并 \(\mathcal O(\log n)\)\(\texttt{K-D Tree}\) 的查询结果即可,时间复杂度 \(\mathcal O(\sum_{t=1}^{\log n}\sqrt{2^t})=\mathcal O(\sqrt n)\)

效率对比

以最简单的动态加点、矩形数点为例:

替罪羊树 根号分治 二进制分组
插入 \(\mathcal O(n\log n)\) \(\mathcal O(\frac nB\cdot n\log n)\) \(\mathcal O(n\log^2n)\)
查询 无法保证 \(\mathcal O(\sqrt n+B)\) \(\mathcal O(\sqrt n)\)
\(n=10^5\) 平均单次查询访问节点数 860~880 (800~850)+1600 1250~1300
\(n=2\cdot 10^5\) 平均单次查询访问节点数 1200~1300 (1150~1200)+2400 1800~1900
\(10^5\) 插入 + \(10^5\) 查询 1.4s 1.4s 1.4s
代码量 2.28k 2.02k 2.16k

注:第 \(3,4\) 项数据通过先插入 \(n\) 个节点再查询 \(1000\) 次取平均得到。(\(\alpha=0.7,B=1600\) ,忽略不交的节点,坐标和矩形边界在 \([0,10^9]\) 内随机生成)

  • 替罪羊树做法在随机数据下表现良好,但可以构造数据(虽然博主不会)将其卡掉。
  • 根号分治做法整体复杂度较高,但插入常数较小,而且查询用的 \(\texttt{KD-Tree}\) 最优,因此常数上有优势。调参发现插入、查询各 \(10^5\) 次时,取 \(B=1600\) 最优。
  • 二进制分组做法虽然插入复杂度高于静态 \(\texttt{KD-Tree}\) ,但常数较小且并非复杂度瓶颈,所以问题不大,反倒是查询部分需要访问约 \(4\sqrt n\) 个节点。

五、最近邻域查询

注:

  • 多次查询某点到 \(n\) 个固定点的最短距离,使用 \(\texttt{K-D Tree}\) 加剪枝,随机数据下是 \(\mathcal O(\log n)\) ,但可以被卡到 \(\mathcal O(n)\)
  • 借鉴 \(\texttt{K-D Tree}\) 的思想,分治做法可以在 \(\mathcal O(n\log^2n)\) 的时间内解决平面最近点对问题。

初始将所有点按横坐标排序,维护当前最短距离 \(d\) ,先递归左右两边并更新 \(d\)

设中心点横坐标为 \(x_0\) ,我们需要统计跨过竖直线 \(x=x_0\) 的最短距离,显然只有横坐标在 \([x_0-d,x_0+d]\) 内的点是有用的,下面称为关键点。

如果直接枚举关键点对,时间复杂度会退化至 \(\mathcal O(n^2)\)

将关键点按纵坐标排序,对每个关键点 \((x,y)\) ,我们只需枚举 \(y'\in (y-d,y]\) 的关键点 \((x',y')\)

这样我们需要枚举的关键点被限制在了如下图所示的 \(2d\times d\) 的矩形内,且红色和绿色正方形内任意两个点距离都不超过 \(d\)

如果不算重复点的话矩形内至多只能塞下 \(6\) 个点,所以往前枚举 \(5\) 个点即可。为了克服重复点的影响,我们可以在排序上动手脚,如果 \(y\) 相同按 \(x\) 排序即可。

image

使用归并排序,时间复杂度 \(\mathcal O(n\log n)\)

///P7883
#include<bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define mp make_pair
#define pii pair<int,int>
using namespace std;
const int maxn=4e5+5;
int n;
ll d=1e18;
pii a[maxn],p[maxn];
ll dis(pii x,pii y)
{
    return 1ll*(x.fi-y.fi)*(x.fi-y.fi)+1ll*(x.se-y.se)*(x.se-y.se);
}
void solve(int l,int r)
{
    if(l>=r) return ;
    int mid=(l+r)>>1,x=p[mid].fi;
    solve(l,mid),solve(mid+1,r);
    inplace_merge(p+l,p+mid+1,p+r+1,[&](pii x,pii y){return x.se!=y.se?x.se<y.se:x.fi<y.fi;});
    for(int i=l,j=0;i<=r;i++)
        if(1ll*(p[i].fi-x)*(p[i].fi-x)<d)/// d 为平方量级,注意不要写成 abs(p[i].fi-x)<d
        {
            a[++j]=p[i];
            for(int k=max(j-5,1);k<j;k++) d=min(d,dis(a[k],a[j]));
        }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d%d",&p[i].fi,&p[i].se);
    sort(p+1,p+n+1,[&](pii x,pii y){return x.fi<y.fi;});
    solve(1,n);
    printf("%lld\n",d);
    return 0;
}

例题

例1、\(\texttt{P14312 【模板】K-D Tree}\)

题目描述

维护一个 \(k\) 维空间的可重点集 \(S\)\(m\) 次操作:

  • 1 x_1 .. x_k v :向 \(S\) 插入坐标为 \((x_1,\cdots,x_k)\) ,权值为 \(v\) 的点。
  • 2 x_1 .. x_k y_1 .. y_k v :将 \(S\) 中所有在以 \(A=(x_1,\cdots,x_k)\)\(B=(y_1,\cdots,y_k)\) 为顶点的高维矩形中的点权值增加 \(v\)
  • 3 x_1 .. x_k y_1 .. y_k :询问 \(S\) 中所有在以 \(A=(x_1,\cdots,x_k)\)\(B=(y_1,\cdots,y_k)\) 为顶点的高维矩形中的点的权值和。

强制在线。

数据范围

  • \(k\in\{2,3\}\) 。当 \(k=2\) 时, \(1\le m\le 1.5\cdot 10^5\) ;当 \(k=3\) 时, \(1\le m\le 10^5\)
  • \(1\le x_i\le y_i\le 10^{18},1\le v\le 10^5\)

时间限制 \(\texttt{5s}\) ,空间限制 \(\texttt{32MB}\)

分析

本例用于展示二进制分组写法的代码。

平衡树版本的标记下传较为复杂,敲代码时一定要明确各变量含义。

时间复杂度 \(\mathcal O(m^{2-\frac 1k})\)

温馨提示:如果查询代价过高,请检查 f[0] 是否赋了初始值。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=1.5e5+5;
const ll inf=1e18;
int k,m,n,op,tot;
ll x,lst;
int top,st[maxn];
vector<int> T;
struct poi
{
    array<ll,3> v;
    ll val;
}a[maxn];
struct node
{
    int ls,rs,sz;
    ll sum,tag;/// sum 为不考虑 [rt,p) 标记时的子树和
    poi q;/// q.val 为不考虑 [rt,p] 标记时的节点权值
    array<ll,3> mn,mx;
}f[maxn];
int newnode(const poi &q)
{
    int p=top?st[top--]:++tot;
    return f[p]={0,0,1,q.val,0,q,q.v,q.v},p;
}
void pushup(int p)
{
    f[p].sz=f[f[p].ls].sz+1+f[f[p].rs].sz;
    f[p].sum=f[f[p].ls].sum+f[p].q.val+f[f[p].rs].sum+f[p].tag*f[p].sz;///关键!!! 需要配合 sum,val 的定义理解
    for(int i=0;i<k;i++)
    {
        f[p].mn[i]=min({f[f[p].ls].mn[i],f[p].q.v[i],f[f[p].rs].mn[i]});
        f[p].mx[i]=max({f[f[p].ls].mx[i],f[p].q.v[i],f[f[p].rs].mx[i]});
    }
}
int build(int l,int r,int d)
{
    if(l>r) return 0;
    int mid=(l+r+1)>>1;
    nth_element(a+l,a+mid,a+r+1,[&](const poi &x,const poi &y){return x.v[d]<y.v[d];});
    int p=newnode(a[mid]);
    f[p].ls=build(l,mid-1,(d+1)%k),f[p].rs=build(mid+1,r,(d+1)%k);
    return pushup(p),p;
}
void clean(int p,ll add=0)
{
    if(!p) return ;
    st[++top]=p,add+=f[p].tag,f[p].q.val+=add,a[++n]=f[p].q;
    clean(f[p].ls,add),clean(f[p].rs,add);
}
void modify(int p,array<ll,3> l,array<ll,3> r,ll add)
{
    if(!p) return ;
    bool f1=1,f2=1,f3=1;/// f1 表示完全包含在矩形内, f2 表示与矩形有交, f3 表示根节点在矩形内
    for(int i=0;i<k;i++)
    {
        f1&=l[i]<=f[p].mn[i]&&f[p].mx[i]<=r[i];
        f2&=l[i]<=f[p].mx[i]&&f[p].mn[i]<=r[i];
        f3&=l[i]<=f[p].q.v[i]&&f[p].q.v[i]<=r[i];
    }
    if(f1) return f[p].tag+=add,f[p].sum+=add*f[p].sz,void();
    if(!f2) return ;
    if(f3) f[p].q.val+=add;
    modify(f[p].ls,l,r,add),modify(f[p].rs,l,r,add);
    pushup(p);
}
ll query(int p,array<ll,3> l,array<ll,3> r,ll add=0)
{
    if(!p) return 0;
    bool f1=1,f2=1,f3=1;
    for(int i=0;i<k;i++)
    {
        f1&=l[i]<=f[p].mn[i]&&f[p].mx[i]<=r[i];
        f2&=l[i]<=f[p].mx[i]&&f[p].mn[i]<=r[i];
        f3&=l[i]<=f[p].q.v[i]&&f[p].q.v[i]<=r[i];
    }
    if(f1) return f[p].sum+add*f[p].sz;
    add+=f[p].tag;
    if(!f2) return 0;
    return (f3?f[p].q.val+add:0ll)+query(f[p].ls,l,r,add)+query(f[p].rs,l,r,add);
}
void read(array<ll,3> &x)
{
    for(int i=0;i<k;i++) scanf("%lld",&x[i]),x[i]^=lst;
}
int main()
{
    scanf("%d%d",&k,&m),f[0].mn={inf,inf,inf};
    for(array<ll,3> l,r;m--;)
    {
        scanf("%d",&op);
        if(op==1)
        {
            read(a[n=1].v),scanf("%lld",&a[1].val),a[1].val^=lst;
            while(T.size()&&f[T.back()].sz==n) clean(T.back()),T.pop_back();
            T.push_back(build(1,n,0));
        }
        if(op==2)
        {
            read(l),read(r),scanf("%lld",&x),x^=lst;
            for(auto p:T) modify(p,l,r,x);
        }
        if(op==3)
        {
            read(l),read(r),x=0;
            for(auto p:T) x+=query(p,l,r);
            printf("%lld\n",lst=x);
        }
    }
    return 0;
}

例2、\(\texttt{P3810 【模板】三维偏序 / 陌上花开}\)

题目描述

\(n\) 个元素 \((a_i,b_i,c_i)\) ,记 \(f(i)=|\{j\mid a_j\le a_i,b_j\le b_i,c_j\le c_i\}|\)

\(\forall d\in [0,n)\) ,求 \(f(i)=d\) 的元素数量。

数据范围

  • \(1\le n\le 10^5,1\le a_i,b_i,c_i\le 2\cdot 10^5\)

时间限制 \(\texttt{1s}\) ,空间限制 \(\texttt{500MB}\)

分析

本例用于展示根号分治写法的代码。

三维的 \(\texttt{K-D Tree}\) 是过不去的,问就是博主替大家试过了。

将所有点按 \(a\) 排序并去重,问题转化为二维平面上的动态加点,矩形查询。

时间复杂度 \(\mathcal O(n\sqrt n)\)

温馨提示:如果查询代价过高,请检查 f[0] 是否赋了初始值。

#include<bits/stdc++.h>
using namespace std;
const int B=1600,maxn=1e5+5;
int m,n,lst,tot;
int buc[maxn];
struct poi
{
    int a,b,c,v;
}e[maxn];
struct node
{
    int ls,rs,sum;
    int l1,r1,l2,r2;
    poi q;
}f[maxn];
void pushup(int p)
{
    f[p].sum=f[f[p].ls].sum+f[p].q.v+f[f[p].rs].sum;
    f[p].l1=min({f[f[p].ls].l1,f[p].q.b,f[f[p].rs].l1});
    f[p].r1=max({f[f[p].ls].r1,f[p].q.b,f[f[p].rs].r1});
    f[p].l2=min({f[f[p].ls].l2,f[p].q.c,f[f[p].rs].l2});
    f[p].r2=max({f[f[p].ls].r2,f[p].q.c,f[f[p].rs].r2});
}
int newnode(const poi &q)
{
    return f[++tot]={0,0,q.v,q.b,q.b,q.c,q.c,q},tot;
}
int build(int l,int r,int d)
{
    if(l>r) return 0;
    int mid=(l+r+1)>>1;
    nth_element(e+l,e+mid,e+r+1,[&](const poi &x,const poi &y){return !d?x.b<y.b:x.c<y.c;});
    int p=newnode(e[mid]);
    f[p].ls=build(l,mid-1,d^1),f[p].rs=build(mid+1,r,d^1);
    return pushup(p),p;
}
int query(int p,int r1,int r2)
{
    if(!p||f[p].l1>r1||f[p].l2>r2) return 0;
    if(f[p].r1<=r1&&f[p].r2<=r2) return f[p].sum;
    return query(f[p].ls,r1,r2)+(f[p].q.b<=r1&&f[p].q.c<=r2?f[p].q.v:0)+query(f[p].rs,r1,r2);
}
int main()
{
    scanf("%d%*d",&n),f[0].l1=f[0].l2=1e9;
    for(int i=1;i<=n;i++) scanf("%d%d%d",&e[i].a,&e[i].b,&e[i].c);
    sort(e+1,e+n+1,[](poi x,poi y){return x.a!=y.a?x.a<y.a:(x.b!=y.b?x.b<y.b:x.c<y.c);});
    for(int i=1,j;i<=n;)
    {
        for(j=i;j<=n&&e[i].a==e[j].a&&e[i].b==e[j].b&&e[i].c==e[j].c;j++) ;
        e[++m]=e[i],e[m].v=j-i,i=j;
    }
    for(int i=1;i<=m;i++)
    {
        int res=query(1,e[i].b,e[i].c);
        for(int j=lst+1;j<i;j++) if(e[j].b<=e[i].b&&e[j].c<=e[i].c) res+=e[j].v;
        buc[res+e[i].v-1]+=e[i].v;
        if(i%B==0) build(1,lst=i,tot=0);
    }
    for(int i=0;i<n;i++) printf("%d\n",buc[i]);
    return 0;
}

posted on 2026-02-15 11:13  peiwenjun  阅读(13)  评论(0)    收藏  举报

导航