树上问题
树链剖分
多指重链剖分
可以解决的问题:
\(1\). 修改 树上两点之间的路径上 所有点的值。
\(2\). 查询 树上两点之间的路径上 节点权值的 和/极值/其它在序列上可以用数据结构维护,便于合并的信息。
概念:
- 重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多 的那一个儿子为该节点的重儿子
- 轻儿子:对于每一个非叶子节点,它的儿子中非 重儿子 的剩下所有儿子即为轻儿子
- 叶子节点没有重儿子也没有轻儿子(因为它没有儿子)
- 重边:连接任意两个重儿子的边叫做重边
- 轻边:剩下的即为轻边
- 重链:相邻重边连起来的 连接一条重儿子 的链叫重链
- 对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为 \(1\) 的链
- 每一条重链以轻儿子为起点
实现:
第一个 \(dfs\)
- 标记每个点的深度 \(dep\)
- 标记每个点的父亲 \(fa\)
- 标记每个非叶子节点的子树大小(含它自己)
- 标记每个非叶子节点的重儿子编号 \(hes\)
void dfs(int x){
siz[x]=1;
dep[x]=dep[fa[x]]+1;
for(int i=0;i<q[x].size();i++){
int y=q[x][i];
if(y==fa[x]) continue;
fa[y]=x;
dfs(y);
siz[x]+=siz[y];
if(!hes[x]||siz[hes[x]]<siz[y]) hes[x]=y;
}
}
第二个 \(dfs\)
- 标记每个点的新编号
- 赋值每个点的初始值到新编号上
- 处理每个点所在链的顶端
- 处理每条链
void dfs1(int x,int v){
top[x]=v;
if(hes[x]) dfs1(hes[x],v);
for(int i=0;i<q[x].size();i++){
int y=q[x][i];
if(y==fa[x]||y==hes[x]) continue;
dfs1(y,y);
}
}
因为用的是 \(dfs\) 序遍历,所以每条链的编号都是连续的。
接下来可以在此基础上做一些操作来解决其他问题。
比如说下面两道题:
P3379 【模板】最近公共祖先(LCA),P3384 【模板】重链剖分/树链剖分
一个就是纯树链剖分,另一个还要用线段树进行区间查询。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int N=2e6+10;
vector<int> q[N];
struct Node{
int l,r,val,siz,f;
}tr[N];
int p,cnt,a[N],b[N];
int dep[N],fa[N],hes[N],siz[N],top[N],idx[N];
void dfs(int x){
dep[x]=dep[fa[x]]+1;
siz[x]=1;
for(int i=0;i<q[x].size();i++){
int y=q[x][i];
if(y==fa[x]) continue;
fa[y]=x;
dfs(y);
siz[x]+=siz[y];
if(!hes[x]||siz[hes[x]]<siz[y]) hes[x]=y;
}
}
void update(int x){
tr[x].val=(tr[x*2].val+tr[x*2+1].val+p)%p;
}
void build(int x,int l,int r){
tr[x].l=l,tr[x].r=r;
tr[x].siz=r-l+1;
if(l==r){
tr[x].val=a[l]%p;
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid);
build(x*2+1,mid+1,r);
update(x);
}
void dfs1(int x,int v){
idx[x]=++cnt;
a[cnt]=b[x];
top[x]=v;
if(!hes[x]) return;
dfs1(hes[x],v);
for(int i=0;i<q[x].size();i++){
int y=q[x][i];
if(!idx[y]) dfs1(y,y);
}
}
void pushdown(int x){
if(!tr[x].f) return;
tr[x*2].val=(tr[x*2].val+tr[x*2].siz*tr[x].f)%p;
tr[x*2+1].val=(tr[x*2+1].val+tr[x*2+1].siz*tr[x].f)%p;
tr[x*2].f=(tr[x*2].f+tr[x].f)%p;
tr[x*2+1].f=(tr[x*2+1].f+tr[x].f)%p;
tr[x].f=0;
}
void intervaladd(int x,int l,int r,int val){
if(l<=tr[x].l&&tr[x].r<=r){
tr[x].val+=tr[x].siz*val;
tr[x].f+=val;
return;
}
pushdown(x);
int mid=(tr[x].l+tr[x].r)>>1;
if(l<=mid) intervaladd(x*2,l,r,val);
if(r>mid) intervaladd(x*2+1,l,r,val);
update(x);
}
void treeadd(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
intervaladd(1,idx[top[x]],idx[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
intervaladd(1,idx[x],idx[y],val);
}
int intervalsum(int x,int l,int r){
int ans=0;
if(l<=tr[x].l&&tr[x].r<=r) return tr[x].val;
pushdown(x);
int mid=(tr[x].l+tr[x].r)>>1;
if(l<=mid) ans=(ans+intervalsum(x*2,l,r))%p;
if(r>mid) ans=(ans+intervalsum(x*2+1,l,r))%p;
return ans;
}
void treesum(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+intervalsum(1,idx[top[x]],idx[x]))%p;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+intervalsum(1,idx[x],idx[y]))%p;
cout<<ans<<endl;
}
int main(){
int n,m,s,x,y,z,opt;
cin>>n>>m>>s>>p;
for(int i=1;i<=n;i++) cin>>b[i];
for(int i=1;i<n;i++){
cin>>x>>y;
q[x].push_back(y);
q[y].push_back(x);
}
dfs(s);
dfs1(s,s);
build(1,1,n);
while(m--){
cin>>opt;
if(opt==1){
cin>>x>>y>>z,z=z%p;
treeadd(x,y,z);
}
else if(opt==2){
cin>>x>>y;
treesum(x,y);
}
else if(opt==3){
cin>>x>>z;
intervaladd(1,idx[x],idx[x]+siz[x]-1,z%p);
}
else if(opt==4){
cin>>x;
cout<<intervalsum(1,idx[x],idx[x]+siz[x]-1)<<endl;
}
}
return 0;
}

浙公网安备 33010602011771号