长链剖分

长链剖分

1. 概念

类似于轻重链剖分,定义 子树最深 的子节点为重儿子,长链为重边组成的链。

一些性质:

  1. 所有长链大小之和为 \(n\),终点都是叶子;

  2. 某一点向上跳长链,最多跳 \(O(\sqrt{n})\) 次。

    考虑从一个点 \(x\) 开始,每跳一条轻边,一定有另一个子节点深度 \(\geq maxdep[x]\)。相邻两个这样的 \(maxdep[x]\) 之前至少差 \(1\),所以最多 \(O(\sqrt{n})\) 个。

    这样看来,长链比轻重链要劣,所以一般不这样用。

2. 树上 k 级祖先

对于每条大小为 \(C\) 的长链,都处理出长为 \(2C\) 的序列,按从上到下的顺序储存 链首向上 \(C\) 个节点 与 链内的节点。

首先处理出 \(f_{i,j}\) 表示 \(i\)\(2^j\) 级祖先。

对于询问 \((x,k)\),若 \(k=0\),则答案为 \(x\);

否则,找到最大的 \(w\) 满足 \(2^w\leq k\),令 \(x'= f_{x,w}\)

\(x'\) 所在长链的长度一定 \(\geq 2^w\),所以接下来在序列中直接查 \(x'\)\(k-2^w\) 级祖先即可。

3. 长链剖分优化 DP

当树上 dp 的某一维与深度有关系时,可以应用长链剖分来优化。

处理每个点时,直接继承其重儿子的信息,暴力合并轻儿子。每条长链只在链头被合并一次,所以时间复杂度 \(O(n)\)

Problem A. [ARC086E] Smuggling Marbles

不同层之间互不影响,所以可以独立考虑每一层。

为了方便转移,考虑先计算概率,最后乘 \(2^{n+1}\)

\(f_{x,d}\)\(x\) 子树内与 \(x\) 相距 \(d\) 的节点到达 \(x\) 后有石子的概率。容易得到转移:

\[f_{x,d}=\sum_{v\in son(x)} (f_{v,d-1})\prod_{v'\in son(x)\land v'\neq v} (1-f_{v',d-1}) \]

长链剖分优化即可 \(O(n)\)

int n,fa[N];
int tot,head[N],dep[N],mxd[N],son[N],siz[N];
ll _f[N],*f[N],*fp=_f,g[N],pw[N];

const ll mod=1e9+7,inv2=(mod+1)>>1;

inline ll Mod(ll x){return (x>=mod)?(x-mod):(x);}

struct Edge{
	int to,nxt;
}edge[N];

void Add(int u,int v){
	edge[++tot]={v,head[u]};
	head[u]=tot;
}

void dfs1(int x){
	mxd[x]=1; int mx=0; siz[x]=1;
	for(int i=head[x];i;i=edge[i].nxt){
		int y=edge[i].to;
		dep[y]=dep[x]+1;
		dfs1(y);
		siz[x]+=siz[y];
		Ckmax(mxd[x],mxd[y]+1);
		if(mxd[y]>mx) mx=mxd[y],son[x]=y;
	}
}

void Allocate(int x){
	f[x]=fp;
	fp+=mxd[x];
}

void dfs2(int x,int tp){
	if(x==tp) Allocate(x);
	if(son[x]){
		f[son[x]]=f[x]+1;
		dfs2(son[x],tp);
	}
	int mx=0;
	for(int i=head[x];i;i=edge[i].nxt){
		int y=edge[i].to;
		if(y==son[x]) continue;
		dfs2(y,y);
		Ckmax(mx,mxd[y]);
	}
	f[x][0]=inv2;
	for(int i=1;i<=mx;i++) g[i]=Mod(1-f[x][i]+mod);
	for(int i=head[x];i;i=edge[i].nxt){
		int y=edge[i].to;
		if(y==son[x]) continue;
		for(int d=0;d<mxd[y];d++){
			f[x][d+1]=(f[x][d+1]*Mod(1-f[y][d]+mod)+f[y][d]*g[d+1])%mod;
			(g[d+1]*=Mod(1-f[y][d]+mod))%=mod;
		}
	}
}

signed main(){
	read(n);
	for(int i=1;i<=n;i++){
		read(fa[i]);
		Add(fa[i],i);
	}
	pw[0]=1;
	for(int i=1;i<=n+1;i++) pw[i]=Mod(pw[i-1]<<1);
	dep[0]=1; dfs1(0); dfs2(0,0);
	ll ans=0;
	for(int i=0;i<mxd[0];i++) ans=Mod(ans+f[0][i]);
	(ans*=pw[n+1])%=mod;
	printf("%lld\n",ans);
    return 0;
}
posted @ 2025-08-05 21:27  XP3301_Pipi  阅读(39)  评论(0)    收藏  举报
Title