树链剖分学习笔记
因为被 N总 D了好久,所以痛下决心想要学习一下树链剖分。
强烈推荐 N总对树链剖分的优质讲解博客(感觉比我写得好多了。。。。)。
树链剖分
树链剖分可以通过给树上的点重新编号(重新编号后就变成若干条链),使得可以将树中的任意一条路径转化成 \(O(\log n)\) 段连续的区间。那么树中路径上的所有操作都可以转化为 \(O(\log n)\) 段连续区间中的操作,也就可以用线段树来维护(也可以用其他支持区间维护的数据结构维护)。
如何将树上的路径变成区间?可以用 DFS 序来实现。
一些定义
重儿子:对于每一个非叶子节点,以它的所有儿子为根的子树中,大小最大的那个子树的根节点被称为该节点的重儿子(如果有多个子树大小最大的儿子,任选一个作为重儿子即可)。
轻儿子:对于每一个非叶子节点,除了重儿子以外的所有儿子被称为轻儿子。
重边:重儿子和它父亲之间的连边被称为重边。
轻边:轻儿子和它父亲之间的连边被称为轻边。
重链:由重边构成的路径被称为重链。(每个点都属于且仅属于一条重链,对于一些叶子节点,它们自己本身单独是一条重链)
如何判断每个点属于哪一条重链?对于每一个重儿子,它所在的重链就是它父亲所在的重链;对于每一个轻儿子,它所在的重链就是以它开头的往下走的重链。
在 DFS 求 DFS 序时,优先遍历重儿子。这样可以保证每一条重链上所有点的编号都是连续的。
如下图所示:

定理
树中的任意一条路径均可拆分成 \(O(\log n)\) 个连续的区间,即 \(O(\log n)\) 条重链(不一定完整)。
树链剖分的基础操作
通常情况下,可以通过两遍 DFS 求出所有的信息。
在第一次 DFS 时,求出所有的重儿子。
在第二次 DFS 时,求出 DFS 序,以及每个点所在重链的顶点。
将树上的路径转化为区间
有点类似于倍增求 LCA 的算法。如下图所示,假设要将 \(x,y\) 之间的路径转化为区间。记 \(x\) 所在重链的顶端为 \(top[x]\),\(y\) 所在重链的顶端为 \(top[y]\)。
比较 \(depth[top[x]]\) 与 \(depth[top[y]]\)。将深度较大的那个点跳到链顶点的父节点上。同时记录下这一段区间。
重复上述操作,直到 \(top[x]=top[y]\)。此时再比较一下两个节点的深度大小即可算出最后一段区间(这就体现了为什么要优先遍历重儿子)。
具体过程如下图所示:

而通常情况下题目就是要求对这些区间进行一些修改,用线段树维护的时间复杂度就是 \(O(n \log^2 n)\)。但是常数比较小。
模板题
给定一棵包含 \(N\) 个结点的树,每个节点上包含一个数值,需要支持以下操作:
\(1\) \(x\) \(y\) \(z\),表示将树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值都加上 \(z\)。
\(2\) \(x\) \(y\),表示求树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值之和。
\(3\) \(x\) \(z\),表示将以 xx 为根节点的子树内所有节点值都加上 \(z\)。
\(4\) \(x\) 表示求以 \(x\) 为根节点的子树内所有节点值之和
数据范围
\(1 \leq N \leq 10^5,1 \leq M \leq 10^5\)。
思路
区间修改,区间查询,显然本题用线段树维护信息是最好的。
对于 \(1,2\) 两种操作来说,上面都已经介绍过了。而对于一棵子树内的操作,注意到一棵子树内的 DFS 序一定是连续的,那么这棵子树在序列上就是 \([dfn[x],dfn[x]+size[x]-1]\) 这一段连续的区间。直接区间修改即可。
code:
#include<cstdio>
using namespace std;
const int N=1e5+10;
#define LL long long
LL w[N],a[N];//a是树上的点转化到序列上的点后,这个点在原树的权值
struct edge{
int v,nex;
}e[N<<1];
int h[N],idx,son[N],f[N],siz[N],dfn[N],num,top[N];
int depth[N],n,m,R,mod;
void add(int u,int v){e[++idx].v=v,e[idx].nex=h[u],h[u]=idx;}
struct tree{
int l,r,mid;
LL val,tag;
}tr[N<<2];
void swap(int &a,int &b){int t=a;a=b,b=t;}
void push_up(int p){tr[p].val=(tr[p<<1].val+tr[p<<1|1].val)%mod;}
void push_down(int p)
{
if(tr[p].tag)
{
int lenl=tr[p<<1].r-tr[p<<1].l+1;int lenr=tr[p<<1|1].r-tr[p<<1|1].l+1;
tr[p<<1].val=(tr[p<<1].val+tr[p].tag*lenl)%mod;
tr[p<<1|1].val=(tr[p<<1|1].val+tr[p].tag*lenr)%mod;
tr[p<<1].tag+=tr[p].tag;
tr[p<<1|1].tag+=tr[p].tag;
tr[p].tag=0;
}
}
void build_tree(int p,int l,int r)
{
tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
if(l==r)
{
tr[p].val=a[l];
return ;
}
build_tree(p<<1,tr[p].l,tr[p].mid);
build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
push_up(p);
}
void updata(int p,int l,int r,LL k)
{
if(l<=tr[p].l&&tr[p].r<=r)
{
tr[p].val+=(tr[p].r-tr[p].l+1)*k;
tr[p].tag+=k;
return ;
}
push_down(p);
if(l<=tr[p].mid) updata(p<<1,l,r,k);
if(r>tr[p].mid) updata(p<<1|1,l,r,k);
push_up(p);
}
LL query(int p,int l,int r)
{
if(l<=tr[p].l&&tr[p].r<=r) return tr[p].val;
LL res=0;
push_down(p);
if(l<=tr[p].mid) res+=query(p<<1,l,r);
if(r>tr[p].mid) res+=query(p<<1|1,l,r);
return res;
}
void dfs_yy(int u,int fa)
{
depth[u]=depth[fa]+1;
siz[u]=1;
f[u]=fa;
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
if(v==fa) continue;
dfs_yy(v,u);
siz[u]+=siz[v];
if(!son[u]||siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs_nlc(int u,int t)
{
dfn[u]=++num;top[u]=t;a[num]=w[u];
if(siz[u]==1) return ;
dfs_nlc(son[u],t);
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
if(v==f[u]||v==son[u]) continue;
dfs_nlc(v,v);
}
}
void updata_tree(int x,int y,LL k)
{
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]]) swap(x,y);
updata(1,dfn[top[x]],dfn[x],k);
x=f[top[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,dfn[x],dfn[y],k);
}
LL query_tree(int x,int y)
{
LL res=0;
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]]) swap(x,y);
res=(res+query(1,dfn[top[x]],dfn[x]))%mod;
x=f[top[x]];
}
if(depth[x]>depth[y]) swap(x,y);
res=(res+query(1,dfn[x],dfn[y]))%mod;
return res;
}
int main()
{
scanf("%d%d%d%d",&n,&m,&R,&mod);
for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
for(int u,v,i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs_yy(R,0);
dfs_nlc(R,R);
build_tree(1,1,n);
while(m--)
{
int opt,x,y;
LL k;
scanf("%d%d",&opt,&x);
if(opt==1) scanf("%d%lld",&y,&k),updata_tree(x,y,k);
if(opt==3) scanf("%lld",&k),updata(1,dfn[x],dfn[x]+siz[x]-1,k);
if(opt==2) scanf("%d",&y),printf("%lld\n",query_tree(x,y)%mod);
if(opt==4) printf("%lld\n",query(1,dfn[x],dfn[x]+siz[x]-1)%mod);
}
return 0;
}
应用 [NOI2015] 软件包管理器
有 \(n\) 个软件安装包,它们的依赖关系构成一棵树,其中 \(1\) 号软件不依赖任何软件,为根节点。每次有两种操作:
install \(x\) 表示安装 \(x\) 号软件包
uninstall \(x\) 表示卸载 \(x\) 号软件包
若 \(A\) 依赖 \(B\),则安装 \(A\) 之前必须先安装 \(B\),卸载 \(B\) 之前必须先卸载 \(A\)。
求每次操作更改了多少个软件的状态。
数据范围

思路
通过观察可以发现,如果将软件包的状态标记为 \(0/1\) (未安装/已安装)。那么对于每一次安装操作,就是把 \(x\) 到根节点的路径上的 \(0\) 全部修改成 \(1\)。对于每一次卸载操作,就是把以 \(x\) 为根的子树内部的点全部修改成 \(0\)。
要统计每一次更改多少软件的状态,其实就是比较一下更改前和更改后,树中 \(1\) 的数量的变化。那么也就可以先将原树进行树链剖分操作,用线段树维护区间内 \(1\) 的数量。每次操作的答案就是更改的 \(1\) 的数量。
code:
#include<cstdio>
#include<cstring>
using namespace std;
const int N=1e5+10;
int h[N],idx,n,m,top[N],depth[N],siz[N],dfn[N],num,son[N],f[N];
struct tree{
int l,r,mid;
int val,tag;
}tr[N<<2];
struct edge{
int v,nex;
}e[N];
void swap(int &a,int &b){int t=a;a=b,b=t;}
void add(int u,int v){e[++idx].v=v;e[idx].nex=h[u];h[u]=idx;}
void push_up(int p){tr[p].val=tr[p<<1].val+tr[p<<1|1].val;}
void push_down(int p)
{
if(tr[p].tag!=-1)
{
tr[p<<1].val=(tr[p<<1].r-tr[p<<1].l+1)*tr[p].tag;
tr[p<<1|1].val=(tr[p<<1|1].r-tr[p<<1|1].l+1)*tr[p].tag;
tr[p<<1].tag=tr[p].tag;
tr[p<<1|1].tag=tr[p].tag;
tr[p].tag=-1;
}
}
void build_tree(int p,int l,int r)
{
tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
if(l==r) return ;
build_tree(p<<1,tr[p].l,tr[p].mid);
build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
push_up(p);
}
void updata(int p,int l,int r,int k)
{
if(l<=tr[p].l&&tr[p].r<=r)
{
tr[p].val=(tr[p].r-tr[p].l+1)*k;
tr[p].tag=k;
return ;
}
push_down(p);
if(l<=tr[p].mid) updata(p<<1,l,r,k);
if(r>tr[p].mid) updata(p<<1|1,l,r,k);
push_up(p);
}
int query(int p,int l,int r)
{
if(l<=tr[p].l&&tr[p].r<=r) return tr[p].val;
push_down(p);
int res=0;
if(l<=tr[p].mid) res+=query(p<<1,l,r);
if(r>tr[p].mid) res+=query(p<<1|1,l,r);
return res;
}
void dfs_yy(int u,int fa)
{
f[u]=fa,depth[u]=depth[fa]+1;siz[u]=1;
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
dfs_yy(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs_nlc(int u,int t)
{
dfn[u]=++num;top[u]=t;
if(siz[u]==1) return;
dfs_nlc(son[u],t);
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
if(v!=son[u]) dfs_nlc(v,v);
}
}
void updata_tree(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]]) swap(x,y);
updata(1,dfn[top[x]],dfn[x],k);
x=f[top[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,dfn[x],dfn[y],k);
}
int main()
{
scanf("%d",&n);
for(int u,i=2;i<=n;i++)
{
scanf("%d",&u);
add(u+1,i);
}
dfs_yy(1,0);
dfs_nlc(1,1);
build_tree(1,1,n);
scanf("%d",&m);
for(int x,i=1;i<=m;i++)
{
char s[110];
scanf("%s%d",s,&x);
x++;
if(s[0]=='i')
{
int t=tr[1].val;
updata_tree(x,1,1);
printf("%d\n",tr[1].val-t);
}
else
{
int t=tr[1].val;
updata(1,dfn[x],dfn[x]+siz[x]-1,0);
printf("%d\n",t-tr[1].val);
}
}
return 0;
}
应用 [SDOI2011]染色
给定一棵有 \(n\) 个节点的无根树和 \(m\) 个操作,操作共两类。
将节点 \(a\) 到节点 \(b\) 路径上的所有节点都染上颜色;
询问节点 \(a\) 到节点 \(b\) 路径上的颜色段数量。
连续相同颜色的认为是同一段,例如 \(112221\) 由三段组成:\(11,222,1\)。
数据范围
\(1 \leq n,m\leq10^5\),\(0 \leq c \leq 10^9\),\(1 \leq a,b \leq n\)
思路
看到树上操作,显然可以用树链剖分求解。本题要维护的是颜色的段数,可以用线段树直接维护。具体操作就是记录区间左右端点的颜色,在合并时比较一下两段中间的颜色是否相同即可。难点在于查询。
对于查询树上的路径,因为要被分为 \(O(\log N)\) 段连续的区间,所以不能和线段树一样简单的合并。但是可以发现,由于树链剖分的性质。将一条链转化为一段连续的区间时,右端点一定是深度最大的那个点。所以可以在向上跳的时候记录当前链的最顶点的颜色。再向上跳的时候,判断这一次跳的链的深度最大的点和上一次跳的顶点颜色是否相同。
但是问题又来了,\(x,y\) 往上跳,但是每次只会跳一边,如果无脑合并显然是错误的。所以需要记录两边往上跳的顶点。
而在最后跳到一条链的时候,显然需要特判一下。
后面跳的那个点深度更低。因为两个点最后一次跳的时候,是顶点深度更大的点先跳,那么也就一定跳上最终链更深的地方(最后的判断深度来交换的操作是特判两个点在一条链的情况)。设最终链的左右端点颜色为 \(l,r\),先跳上最终链的顶点为 \(v2\),后跳上的为 \(v1\),那么只需判断一下 \(l\) 和 \(v1\),\(r\) 和 \(v2\) 是否相同即可。
code:
#include<cstdio>
using namespace std;
const int N=1e5+10;
int a[N],f[N],w[N],n,q,top[N],son[N],siz[N],depth[N],h[N],idx;
struct edge{
int v,nex;
}e[N<<1];
int dfn[N],num;
void add(int u,int v){e[++idx].v=v;e[idx].nex=h[u];h[u]=idx;}
struct tree{
int l,r,mid;
int lv,rv,sum,tag;
}tr[N<<2];
int max(int a,int b){return a>b?a:b;}
void swap(int &a,int &b){int t=a;a=b,b=t;}
void push_up(int p)
{
tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum-(tr[p<<1].rv==tr[p<<1|1].lv);
tr[p].lv=tr[p<<1].lv;
tr[p].rv=tr[p<<1|1].rv;
}
void push_down(int p)
{
if(tr[p].tag)
{
tr[p<<1].rv=tr[p<<1].lv=tr[p<<1|1].lv=tr[p<<1|1].rv=tr[p].tag;
tr[p<<1].sum=tr[p<<1|1].sum=1;
tr[p<<1].tag=tr[p].tag;
tr[p<<1|1].tag=tr[p].tag;
tr[p].tag=0;
}
}
void build_tree(int p,int l,int r)
{
tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
if(l==r)
{
tr[p].sum=1;
tr[p].lv=tr[p].rv=a[l];
tr[p].tag=0;
return ;
}
build_tree(p<<1,tr[p].l,tr[p].mid);
build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
push_up(p);
}
void updata(int p,int l,int r,int k)
{
if(l<=tr[p].l&&tr[p].r<=r)
{
tr[p].sum=1;
tr[p].lv=tr[p].rv=k;
tr[p].tag=k;
return ;
}
push_down(p);
if(l<=tr[p].mid) updata(p<<1,l,r,k);
if(r>tr[p].mid) updata(p<<1|1,l,r,k);
push_up(p);
}
tree query(int p,int l,int r)
{
if(l<=tr[p].l&&tr[p].r<=r) return tr[p];
push_down(p);
tree res1,res2,res;
if(l<=tr[p].mid&&r>tr[p].mid)
{
res1=query(p<<1,l,r);res2=query(p<<1|1,l,r);
res.sum=res1.sum+res2.sum-(res1.rv==res2.lv);
res.lv=res1.lv;res.rv=res2.rv;
}
else if(l<=tr[p].mid) res=query(p<<1,l,r);
else res=query(p<<1|1,l,r);
return res;
}
void dfs_yy(int u,int fa)
{
f[u]=fa,depth[u]=depth[fa]+1;siz[u]=1;
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
if(v==fa) continue;
dfs_yy(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs_nlc(int u,int t)
{
dfn[u]=++num;top[u]=t;a[num]=w[u];
if(siz[u]==1) return ;
dfs_nlc(son[u],t);
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;
if(v==f[u]||v==son[u]) continue;
dfs_nlc(v,v);
}
}
void updata_tree(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]]) swap(x,y);
updata(1,dfn[top[x]],dfn[x],k);
x=f[top[x]];
}
if(depth[x]>depth[y]) swap(x,y);
updata(1,dfn[x],dfn[y],k);
}
int query_sum(int x,int y)
{
tree res;
int lv1=-1,lv2=-1,ans=0;
while(top[x]!=top[y])
{
if(depth[top[x]]<depth[top[y]]) swap(x,y),swap(lv1,lv2);
res=query(1,dfn[top[x]],dfn[x]);
ans+=res.sum-(lv1==res.rv);
lv1=res.lv;
x=f[top[x]];
}
if(depth[x]>depth[y]) swap(x,y),swap(lv1,lv2);
res=query(1,dfn[x],dfn[y]);
ans+=res.sum-(res.rv==lv2)-(res.lv==lv1);
return ans;
}
int main()
{
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int u,v,i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs_yy(1,0);
dfs_nlc(1,1);
build_tree(1,1,n);
while(q--)
{
int x,y,k;
char op[2];
scanf("%s%d%d",op,&x,&y);
if(op[0]=='C') scanf("%d",&k),updata_tree(x,y,k);
else if(op[0]=='Q') printf("%d\n",query_sum(x,y));
}
return 0;
}

浙公网安备 33010602011771号