【重链剖分】洛谷 P3384 【模板】轻重链剖分
学习自大型同性交友网站,这个B
站up主讲的非常好!
终于学会了第一个维护树上问题的方法,我好兴奋啊!
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define ll long long
#define ULL unsigned long long
#define Pair pair<LL,LL>
#define ls rt<<1
#define rs rt<<1|1
#define Pi acos(-1.0)
#define eps 1e-6
#define DBINF 1e100
//#define mod 998244353
#define MAXN 1e18
#define MS 100009
LL mod;
int n,m,rt,k;
vector<int > vc[MS];
LL w[MS]; // w[i] = 点 i 权值
int fa[MS]; // fa[i] = 点 i 父亲
int sz[MS]; // sz[i] = 包含点 i 的子树大小
int zson[MS]; // zson[i] = 点 i 的重子节点
int dep[MS]; // dep[i] = 点 i 的深度
int tim; // dfs序时间戳
int top[MS]; // top[i] = 点 i 的链头
int dfn[MS]; // dfn[i] = 点 i 的时间戳
LL val[MS]; // val[i] = 时间戳为 i 对应点的权值
LL p[MS<<2]; // 用于线段树维护区间总和
LL la[MS<<2];
void dfs1(int u,int f){ // 第一遍 dfs => fa[] ,sz[] ,zson[] ,dep[]
fa[u] = f;
sz[u] = 1;
dep[u] = dep[f] + 1;
zson[u] = 0;
int maxn_zson = 0;
for(auto &v:vc[u]){
if(v == f) continue;
dfs1(v,u);
sz[u] += sz[v];
if(sz[v] > maxn_zson){
zson[u] = v;
maxn_zson = sz[v];
}
}
}
void dfs2(int u,int tp){ //第二遍 dfs => top[] ,dfn[] ,val[]
dfn[u] = ++tim;
top[u] = tp;
val[tim] = w[u];
if(zson[u]) dfs2(zson[u],tp);
for(auto &v:vc[u]){
if(v != fa[u] && v != zson[u]){
dfs2(v,v);
}
}
}
void push_up(int rt){
p[rt] = p[ls] + p[rs];
p[rt] %= mod;
}
void push_down(int rt,int l,int r){
if(la[rt]){
int m = l+r>>1;
p[ls] += (m-l+1)*la[rt]; p[ls] %= mod;
p[rs] += (r-m)*la[rt]; p[rs] %= mod;
la[ls] += la[rt]; la[ls] %= mod;
la[rs] += la[rt]; la[rs] %= mod;
la[rt] = 0;
}
}
void build(int l,int r,int rt){
if(l == r){
p[rt] = val[l];
return;
}
int m = l+r>>1;
build(l,m,ls);
build(m+1,r,rs);
push_up(rt);
}
void update(int L,int R,int l,int r,int rt,LL tar){ // 区间更新
if(L <= l && r <= R){
p[rt] += (r-l+1)*tar; p[rt] %= mod;
la[rt] += tar; la[rt] %= mod;
return;
}
push_down(rt,l,r);
int m = l+r>>1;
if(m >= L) update(L,R,l,m,ls,tar);
if(m < R) update(L,R,m+1,r,rs,tar);
push_up(rt);
}
LL query(int L,int R,int l,int r,int rt){ // 区间查询
if(L <= l && r <= R){
return p[rt] % mod;
}
push_down(rt,l,r);
LL ans = 0;
int m = l+r>>1;
if(m >= L) ans += query(L,R,l,m,ls) ,ans %= mod;
if(m < R) ans += query(L,R,m+1,r,rs) ,ans %= mod;
return ans%mod;
}
void op1(int x,int y,LL z){
while(top[x] != top[y]){
if(dep[ top[x] ] < dep[ top[y] ]) swap(x,y); // ***选择 *链头深度大 的向上跳
update(dfn[top[x]],dfn[x],1,n,1,z);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
update(dfn[x],dfn[y],1,n,1,z);
}
LL op2(int x,int y){
LL ans = 0;
while(top[x] != top[y]){
if(dep[ top[x] ] < dep[ top[y] ]) swap(x,y); // ***选择 *链头深度大 的向上跳
ans += query(dfn[top[x]],dfn[x],1,n,1);
ans %= mod;
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
ans += query(dfn[x],dfn[y],1,n,1);
return ans%mod;
}
void op3(int x,LL z){
update(dfn[x],dfn[x]+sz[x]-1,1,n,1,z);
}
LL op4(int x){
return query(dfn[x],dfn[x]+sz[x]-1,1,n,1) % mod;
}
int main() {
ios::sync_with_stdio(false);
cin >> n >> m >> rt >> mod;
for(int i=1;i<=n;i++) cin >> w[i] ,w[i] %= mod;
for(int i=1;i<=n-1;i++){ // 建图
int u,v;
cin >> u >> v;
vc[u].push_back(v);
vc[v].push_back(u);
}
dfs1(rt,rt);
dfs2(rt,rt);
build(1,n,1);
while(m--){
int op,x,y;
LL z;
cin >> op;
if(op == 1){
cin >> x >> y >> z;
op1(x,y,z);
}
else if(op == 2){
cin >> x >> y;
cout << op2(x,y) << "\n";
}
else if(op == 3){
cin >> x >> z;
op3(x,z);
}
else if(op == 4){
cin >> x;
cout << op4(x) << "\n";
}
}
return 0;
}