Loading

树链剖分-重链剖分

参考资料:

https://www.cnblogs.com/ivanovcraft/p/9019090.html
https://www.cnblogs.com/hanruyun/p/9577500.html


前置知识:

树的性质(深度,子树大小,距离,祖先等)
线段树,树状数组


树链剖分

树链剖分可以说是个数据结构,但是更是个存储、操作树上信息的方法。树链剖分的主要思想是把一棵树拆成若干条链,并建立数据结构进行存储、操作,根据拆的方法不同,分为重链剖分和长链剖分。重链剖分用得较多,蒟蒻也更喜欢重链剖分,这篇就讲重链剖分。


重链剖分

例题

为了让萌新对重链剖分有所了解,蒟蒻拿出了例题(几乎包括了重链剖分的所有操作):

【模板】轻重链剖分

有一棵 \(n\) 个节点的树,每个节点有权值,初始时为 \(a\{n\}\)。有 \(m\) 个如下操作:
操作 \(1\)\(1\ x\ y\ z\) 表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)
操作 \(2\)\(2\ x\ y\) 表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。
操作 \(3\)\(3\ x\ z\) 表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)
操作 \(4\)\(4\ x\) 表示求以 \(x\) 为根节点的子树内所有节点值之和。

数据范围:\(1\le n,m\le 10^5\)


重点

看了例题后,先看重链剖分关键概念:

重儿子:父亲节点的子节点中子树大小最大的。

轻儿子:父亲节点的子节点中除了重儿子的节点。

重边:父亲节点与重儿子连成的边。

轻边:父亲节点与轻儿子连成的边(即除了重边以外的边)。

重链:由重边构成的极长路径。

轻链:由轻边构成的极长路径。

链头:重链深度最小的节点,是个轻儿子

特殊的,根节点一般被看作轻儿子。

如下图:

黑色节点是轻儿子,橙色节点是重儿子。黑色边是轻边,橙色边是重边。底下有橙线的节点单独为一条重链。

图中一共有 \(5\) 条重链:

  1. \(1\to 3\to 8\to 9\)
  2. \(2\to 5\)
  3. \(4\)
  4. \(6\)
  5. \(7\)

每个轻儿子下面都有一条重链,所以重链数 \(=\) 轻儿子数

这里有一个关键性的性质,使得树链剖分有实际意义:

任何一个节点到根节点的路径中,包含不超过 \(\log n\) 条重链。

简证:重链与重链间用轻边连接,所以一个节点到根节点的路径中的重链数 \(=\) 轻边数 \(+1\)。轻边连的子树大小较小,所以必然 \(< \frac 12\) 父亲节点子树大小。所以一个节点到根节点的路径中最多有 \(\log_2n\) 条轻边,所以一个节点到根节点的路径中,包含不超过 \(\log n\) 条重链。

如果重节点优先\(\texttt{Dfs}\) 遍历整棵树,为每个节点标上序号,如下:

那么同个重链、同个子树的节点序号就是连续的了。

所以可以维护棵线段树,把修改查询子树转换为区间修改查询,把修改查询最短路径转换为几个区间修改查询(上文提到,任何一个节点到根节点的路径中,包含不超过 \(\log n\) 条重链。所以可以通过修改 \(\log n\) 个区间,达到 \(\Theta(\log n)\) 的修改查询最短路径的目的)。

然后树链剖分的思想就到这里了,如果你理解了,就可以看操作实现了。


操作实现

重链剖分的第一步都是两个 \(\texttt{Dfs}\)

第一个求出每个节点的深度、子树大小、父亲节点、重儿子

void Dfs1(int x){
	sz[x]=1,dep[x]=dep[fa[x]]+1;
	for(int to:e[x])if(to!=fa[x]){
		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
		if(sz[to]>sz[son[x]]) son[x]=to;
	}
}

第二个找每条重链,并且为节点标上序号

void Dfs2(int x,int an){
	tp[x]=an,dfn[x]=++ind,rk[ind]=x;
	if(son[x]) Dfs2(son[x],an);
	for(int to:e[x]) if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
}

\(\texttt{Dfs}\) 完后,可以造数据结构了。就拿例题来说,需要造一棵线段树。线段树的任务有维护区间和,区间修改。可以用 \(\texttt{lazytag+pushdown}\)

lng v[(N<<2)+7],ma[(N<<2)+7];
#define mid ((l+r)>>1)
void pushdown(int k,int l,int r){
	if(!ma[k]) return;
	(v[k<<1]+=ma[k]*(mid-l+1))%=mod;
	(v[k<<1|1]+=ma[k]*(r-mid))%=mod;
	(ma[k<<1]+=ma[k])%=mod;
	(ma[k<<1|1]+=ma[k])%=mod;
	ma[k]=0;
}
void build(int k=1,int l=1,int r=n){
	if(l==r){v[k]=a[rk[l]];return;}
	build(k<<1,l,mid),build(k<<1|1,mid+1,r),v[k]=v[k<<1]+v[k<<1|1];
}
void fixson(int x,int y,lng z,int k=1,int l=1,int r=n){ // 修改子树(序号区间)权值
	if(x<=l&&r<=y){(ma[k]+=z)%=mod,v[k]+=z*(r-l+1);return;}
	pushdown(k,l,r);
	if(mid>=x)fixson(x,y,z,k<<1,l,mid);
	if(mid<y)fixson(x,y,z,k<<1|1,mid+1,r);
	v[k]=v[k<<1]+v[k<<1|1];
}
lng sumson(int x,int y,int k=1,int l=1,int r=n){ // 求子树(序号区间)权值和
	if(x<=l&&r<=y) return v[k];
	lng res=0;
	pushdown(k,l,r);
	if(mid>=x) res+=sumson(x,y,k<<1,l,mid);
	if(mid<y)res+=sumson(x,y,k<<1|1,mid+1,r);
	return res%mod;
}

至于修改查询最短路径,依赖于序号区间修改查询。依次遍历最短路径上的每条重链,整个操作过程如同求 \(\texttt{LCA}\)

void fixdis(int x,int y,int z){ //修改最短路径权值
	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	fixson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),z);
}
lng sumdis(int x,int y){ //求最短路径权值和
	lng res=0;
	for(;tp[x]!=tp[y];(res+=sumson(dfn[tp[x]],dfn[x]))%=mod,x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	return (res+sumson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))%mod;
}

然后例题就迎刃而解了。


代码

#include <bits/stdc++.h>
using namespace std;

//Start
#define lng long long
#define db double
#define mk make_pair
#define pb push_back
#define fi first
#define se second
#define rz resize
const int inf=0x3f3f3f3f;
const lng INF=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1e5;
int n,m,rt; lng mod,a[N+7]; vector<int> e[N+7];

//Treesplit
int ind,fa[N+7],sz[N+7],dep[N+7],son[N+7],tp[N+7],dfn[N+7],rk[N+7];
void Dfs1(int x){
	sz[x]=1,dep[x]=dep[fa[x]]+1;
	for(int to:e[x])if(to!=fa[x]){
		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
		if(sz[to]>sz[son[x]]) son[x]=to;
	}
}
void Dfs2(int x,int an){
	tp[x]=an,dfn[x]=++ind,rk[ind]=x;
	if(son[x]) Dfs2(son[x],an);
	for(int to:e[x]) if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
}
lng v[(N<<2)+7],ma[(N<<2)+7];
#define mid ((l+r)>>1)
void pushdown(int k,int l,int r){
	if(!ma[k]) return;
	(v[k<<1]+=ma[k]*(mid-l+1))%=mod;
	(v[k<<1|1]+=ma[k]*(r-mid))%=mod;
	(ma[k<<1]+=ma[k])%=mod;
	(ma[k<<1|1]+=ma[k])%=mod;
	ma[k]=0;
}
void build(int k=1,int l=1,int r=n){
	if(l==r){v[k]=a[rk[l]];return;}
	build(k<<1,l,mid),build(k<<1|1,mid+1,r),v[k]=v[k<<1]+v[k<<1|1];
}
void fixson(int x,int y,lng z,int k=1,int l=1,int r=n){
	if(x<=l&&r<=y){(ma[k]+=z)%=mod,v[k]+=z*(r-l+1);return;}
	pushdown(k,l,r);
	if(mid>=x)fixson(x,y,z,k<<1,l,mid);
	if(mid<y)fixson(x,y,z,k<<1|1,mid+1,r);
	v[k]=v[k<<1]+v[k<<1|1];
}
lng sumson(int x,int y,int k=1,int l=1,int r=n){
	if(x<=l&&r<=y) return v[k];
	lng res=0;
	pushdown(k,l,r);
	if(mid>=x) res+=sumson(x,y,k<<1,l,mid);
	if(mid<y)res+=sumson(x,y,k<<1|1,mid+1,r);
	return res%mod;
}
void fixdis(int x,int y,int z){
	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	fixson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),z);
}
lng sumdis(int x,int y){
	lng res=0;
	for(;tp[x]!=tp[y];(res+=sumson(dfn[tp[x]],dfn[x]))%=mod,x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	return (res+sumson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))%mod;
}

//Main
int main(){
	scanf("%d%d%d%lld",&n,&m,&rt,&mod);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=1,x,y;i<=n-1;i++) scanf("%d%d",&x,&y),e[x].pb(y),e[y].pb(x);
	Dfs1(rt),Dfs2(rt,rt),build();
	for(int i=1;i<=m;i++){
		int o,x,y; lng z; scanf("%d",&o);
		if(o==1) scanf("%d%d%lld",&x,&y,&z),fixdis(x,y,z);
		else if(o==2) scanf("%d%d",&x,&y),printf("%lld\n",sumdis(x,y));
		else if(o==3) scanf("%d%lld",&x,&z),fixson(dfn[x],dfn[x]+sz[x]-1,z);
		else if(o==4) scanf("%d",&x),printf("%lld\n",sumson(dfn[x],dfn[x]+sz[x]-1));
	}
	return 0;
}

特例

有些题目要维护的是边的权值(例题维护的是点的权值),比如[USACO11DEC]Grass Planting G,方法是用点存储它与父亲节点相连的边的权值,具体看代码:

#include <bits/stdc++.h>
using namespace std;

//Start
#define lng long long
#define db double
#define mk make_pair
#define pb push_back
#define fi first
#define se second
#define rz resize
const int inf=0x3f3f3f3f;
const lng INF=0x3f3f3f3f3f3f3f3f;

//Data
const int N=1e5;
int n,m;
vector<int> e[N+7];

//Treesplit
int ind,fa[N+7],son[N+7],sz[N+7],dep[N+7],dfn[N+7],rk[N+7],tp[N+7];
void Dfs1(int x){
	sz[x]=1,dep[x]=dep[fa[x]]+1;
	for(int to:e[x])if(to!=fa[x]){
		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
		if(sz[to]>sz[son[x]]) son[x]=to;
	}
}
void Dfs2(int x,int f){
	tp[x]=f,dfn[x]=++ind,rk[ind]=x;
	if(son[x]) Dfs2(son[x],f);
	for(int to:e[x])if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
}
int v[(N<<2)+7],ma[(N<<2)+7];
#define mid ((l+r)>>1)
void pushdown(int k,int l,int r){
	if(!ma[k]) return;
	v[k<<1]+=ma[k],v[k<<1|1]+=ma[k];
	ma[k<<1]+=ma[k],ma[k<<1|1]+=ma[k],ma[k]=0;
}
void fixson(int x,int y,int z,int k=1,int l=1,int r=n){
	if(x<=l&&r<=y){v[k]+=z,ma[k]+=z;return;}
	if(l==r) return; // 因为可能出现无边情况
	pushdown(k,l,r);
	if(mid>=x) fixson(x,y,z,k<<1,l,mid);
	if(mid<y) fixson(x,y,z,k<<1|1,mid+1,r);
	v[k]=max(v[k<<1],v[k<<1|1]);
}
int sumson(int x,int y,int k=1,int l=1,int r=n){
	if(x<=l&&r<=y) return v[k];
	if(l==r) return 0; // 因为可能出现无边情况
	int res=0;
	pushdown(k,l,r);
	if(mid>=x) res=max(res,sumson(x,y,k<<1,l,mid));
	if(mid<y) res=max(res,sumson(x,y,k<<1|1,mid+1,r));
	return res;
}
void fixdis(int x,int y,int z){
	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	fixson(min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y]),z); //+1 表示最短路径的最近公共祖先和其父亲连的边不改
}
int sumdis(int x,int y){
	int res=0;
	for(;tp[x]!=tp[y];res+=sumson(dfn[tp[x]],dfn[x]),x=fa[tp[x]])
		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
	return res+sumson(min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y])); //+1 表示最短路径的最近公共祖先和其父亲连的边不算
}

//Main
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),e[x].pb(y),e[y].pb(x);
	Dfs1(1),Dfs2(1,1);
	for(int i=1;i<=m;i++){
		vector<char> s(5); int x,y;
		scanf("%s%d%d",&s[1],&x,&y);
		if(s[1]=='P') fixdis(x,y,1);
		else if(s[1]=='Q') printf("%d\n",sumdis(x,y));
	}
	return 0;
}

然后就这样讲完了,非常简单的东西。例题就不给了,到洛谷上随便撸几道即可。


祝大家学习愉快!

posted @ 2020-04-10 15:59  George1123  阅读(293)  评论(0编辑  收藏  举报