长链剖分
长链剖分
1. 概念
类似于轻重链剖分,定义 子树最深 的子节点为重儿子,长链为重边组成的链。
一些性质:
-
所有长链大小之和为 \(n\),终点都是叶子;
-
某一点向上跳长链,最多跳 \(O(\sqrt{n})\) 次。
考虑从一个点 \(x\) 开始,每跳一条轻边,一定有另一个子节点深度 \(\geq maxdep[x]\)。相邻两个这样的 \(maxdep[x]\) 之前至少差 \(1\),所以最多 \(O(\sqrt{n})\) 个。
这样看来,长链比轻重链要劣,所以一般不这样用。
2. 树上 k 级祖先
对于每条大小为 \(C\) 的长链,都处理出长为 \(2C\) 的序列,按从上到下的顺序储存 链首向上 \(C\) 个节点 与 链内的节点。
首先处理出 \(f_{i,j}\) 表示 \(i\) 的 \(2^j\) 级祖先。
对于询问 \((x,k)\),若 \(k=0\),则答案为 \(x\);
否则,找到最大的 \(w\) 满足 \(2^w\leq k\),令 \(x'= f_{x,w}\)。
\(x'\) 所在长链的长度一定 \(\geq 2^w\),所以接下来在序列中直接查 \(x'\) 的 \(k-2^w\) 级祖先即可。
3. 长链剖分优化 DP
当树上 dp 的某一维与深度有关系时,可以应用长链剖分来优化。
处理每个点时,直接继承其重儿子的信息,暴力合并轻儿子。每条长链只在链头被合并一次,所以时间复杂度 \(O(n)\)。
Problem A. [ARC086E] Smuggling Marbles
不同层之间互不影响,所以可以独立考虑每一层。
为了方便转移,考虑先计算概率,最后乘 \(2^{n+1}\)。
设 \(f_{x,d}\) 为 \(x\) 子树内与 \(x\) 相距 \(d\) 的节点到达 \(x\) 后有石子的概率。容易得到转移:
\[f_{x,d}=\sum_{v\in son(x)} (f_{v,d-1})\prod_{v'\in son(x)\land v'\neq v} (1-f_{v',d-1})
\]
长链剖分优化即可 \(O(n)\)。
int n,fa[N];
int tot,head[N],dep[N],mxd[N],son[N],siz[N];
ll _f[N],*f[N],*fp=_f,g[N],pw[N];
const ll mod=1e9+7,inv2=(mod+1)>>1;
inline ll Mod(ll x){return (x>=mod)?(x-mod):(x);}
struct Edge{
int to,nxt;
}edge[N];
void Add(int u,int v){
edge[++tot]={v,head[u]};
head[u]=tot;
}
void dfs1(int x){
mxd[x]=1; int mx=0; siz[x]=1;
for(int i=head[x];i;i=edge[i].nxt){
int y=edge[i].to;
dep[y]=dep[x]+1;
dfs1(y);
siz[x]+=siz[y];
Ckmax(mxd[x],mxd[y]+1);
if(mxd[y]>mx) mx=mxd[y],son[x]=y;
}
}
void Allocate(int x){
f[x]=fp;
fp+=mxd[x];
}
void dfs2(int x,int tp){
if(x==tp) Allocate(x);
if(son[x]){
f[son[x]]=f[x]+1;
dfs2(son[x],tp);
}
int mx=0;
for(int i=head[x];i;i=edge[i].nxt){
int y=edge[i].to;
if(y==son[x]) continue;
dfs2(y,y);
Ckmax(mx,mxd[y]);
}
f[x][0]=inv2;
for(int i=1;i<=mx;i++) g[i]=Mod(1-f[x][i]+mod);
for(int i=head[x];i;i=edge[i].nxt){
int y=edge[i].to;
if(y==son[x]) continue;
for(int d=0;d<mxd[y];d++){
f[x][d+1]=(f[x][d+1]*Mod(1-f[y][d]+mod)+f[y][d]*g[d+1])%mod;
(g[d+1]*=Mod(1-f[y][d]+mod))%=mod;
}
}
}
signed main(){
read(n);
for(int i=1;i<=n;i++){
read(fa[i]);
Add(fa[i],i);
}
pw[0]=1;
for(int i=1;i<=n+1;i++) pw[i]=Mod(pw[i-1]<<1);
dep[0]=1; dfs1(0); dfs2(0,0);
ll ans=0;
for(int i=0;i<mxd[0];i++) ans=Mod(ans+f[0][i]);
(ans*=pw[n+1])%=mod;
printf("%lld\n",ans);
return 0;
}

浙公网安备 33010602011771号