[模板] 重链剖分/树链剖分
前言:
不知道这是第几次说要学树链剖分了,这次终于学会了。
题目描述:
如题,已知一棵包含 \(N\) 个结点的树(连通且无环),根结点为 \(R\),每个结点上包含一个数值,需要支持以下操作(所有求和结果对 \(P\) 取模):
- `1 x y z`,表示将树从 \(x\) 到 \(y\) 结点最短路径上所有结点的值都加上 \(z\)。
- `2 x y`,表示求树从 \(x\) 到 \(y\) 结点最短路径上所有结点的值之和。
- `3 x z`,表示将以 \(x\) 为根结点的子树内所有结点值都加上 \(z\)。
- `4 x`,表示求以 \(x\) 为根结点的子树内所有结点值之和。
代码实现:
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
// Array capacity; presumably sized for at most 1e5 nodes plus slack — confirm against the problem limits.
const int N = 1e5 + 10;
// n: node count, Q: number of operations, rt: root, mod: modulus for all sums (all read in main).
int n, Q, rt, mod;
// tot: DFS-order counter shared by dfs/dfs1; cnt: chain counter (the root's chain gets id 1);
// a: node values (main permutes them into dfn order); b: scratch buffer for that permutation.
int tot, cnt = 1, a[N], b[N];
// Adjacency lists of the tree (each edge stored in both directions).
vector<int> v[N];
// Per node: siz = subtree size, idx = id of the chain containing the node, dfn = DFS order,
// dep = depth (root has depth 1), fa = parent (0 for the root).
// top is indexed by chain id, not by node: dfs1 writes top[c] = y when light child y starts chain c.
int siz[N], idx[N], dfn[N], dep[N], top[N], fa[N];
// First pass: fill siz[] with subtree sizes for the tree hanging below x (parent `fa` is skipped).
// It also writes dfn[], but main() resets tot and dfs1 overwrites those values.
void dfs(int x, int fa){
    siz[x] = 1;
    dfn[x] = ++tot;
    for(auto it = v[x].begin(); it != v[x].end(); ++it){
        int child = *it;
        if(child != fa){
            dfs(child, x);
            siz[x] += siz[child];
        }
    }
}
// Sort predicate: the neighbour with the larger subtree comes first, so after sorting
// the first non-parent entry of each adjacency list is the heavy child.
bool cmp(int A, int B){
    return siz[B] < siz[A];
}
// mn[x] / mx[x]: smallest / largest dfn inside x's subtree (filled by dfs1). Because dfs1 numbers
// nodes in preorder, the subtree is exactly the dfn range [mn[x], mx[x]] and mn[x] equals dfn[x].
int mx[N], mn[N];
// In-place modular addition: x <- (x + y) % mod.
void add(int &x, int y){
    x += y;
    x %= mod;
}
// Second pass: heavy-light decomposition of the subtree rooted at x (parent `fa`, 0 for the root).
// Must run after every adjacency list has been sorted with cmp, so the first non-parent
// neighbour is the heavy child; visiting it first keeps each chain a contiguous dfn range.
// Writes:
//   dep[x]          depth (root = 1)
//   dfn[x]          preorder index
//   idx[x]          id of the chain containing x
//   top[c]          topmost node of chain c
//   mn[x], mx[x]    dfn range of x's subtree
//   fa[x]           parent
void dfs1(int x, int fa){
    dep[x] = dep[fa] + 1, dfn[x] = ++tot, idx[x] = cnt;
    mx[x] = mn[x] = dfn[x], ::fa[x] = fa;
    // The root opens the first chain. Record its head explicitly; previously top[] for the root
    // chain stayed 0, so top[idx[x]] pointed at the non-existent node 0 for every root-chain node.
    if(fa == 0) top[cnt] = x;
    bool heavy_taken = false;
    for(int y : v[x]){
        if(y == fa) continue;
        if(heavy_taken) top[++cnt] = y;   // light child: starts a new chain headed by y
        else heavy_taken = true;          // heavy child: continues x's chain
        dfs1(y, x);
        mx[x] = max(mx[x], mx[y]);
        mn[x] = min(mn[x], mn[y]);
    }
}
// Range-add / range-sum segment tree over positions 1..n, with all arithmetic taken modulo `mod`.
// Node `op` covers [l, r]; its children are op*2 (left half) and op*2+1 (right half).
struct Segment_Tree{
    // t[op]: sum of node op's range (mod `mod`).
    // tag[op]: per-element addend already counted in t[op] but not yet passed to op's children.
    int t[N << 2], tag[N << 2];

    // Lazily adds `val` to each of the `len` elements under node `op`.
    void mark(int op, int len, int val){
        add(t[op], val * len % mod);
        add(tag[op], val);
    }

    // Recomputes node op's sum from its two children.
    void pushup(int op){
        t[op] = (t[op << 1] + t[op << 1 | 1]) % mod;
    }

    // Hands node op's pending addend down to both children, then clears it.
    void pushdown(int l, int r, int op){
        if(!tag[op]) return;
        int mid = (l + r) >> 1;
        mark(op << 1, mid - l + 1, tag[op]);
        mark(op << 1 | 1, r - mid, tag[op]);
        tag[op] = 0;
    }

    // Builds node op over [l, r] from the global array a[].
    void build(int l, int r, int op){
        if(l != r){
            int mid = (l + r) >> 1;
            build(l, mid, op << 1);
            build(mid + 1, r, op << 1 | 1);
            pushup(op);
        }else{
            t[op] = a[l];
        }
    }

    // Adds val to every position in [x, y]; node op covers [l, r].
    void update(int l, int r, int op, int x, int y, int val){
        if(x <= l && r <= y){
            mark(op, r - l + 1, val);
            return;
        }
        pushdown(l, r, op);
        int mid = (l + r) >> 1;
        if(x <= mid) update(l, mid, op << 1, x, y, val);
        if(mid < y) update(mid + 1, r, op << 1 | 1, x, y, val);
        pushup(op);
    }

    // Returns the sum over positions [x, y] modulo `mod`; node op covers [l, r].
    int query(int l, int r, int op, int x, int y){
        if(x <= l && r <= y) return t[op];
        pushdown(l, r, op);
        int mid = (l + r) >> 1;
        int lsum = x <= mid ? query(l, mid, op << 1, x, y) : 0;
        int rsum = mid < y ? query(mid + 1, r, op << 1 | 1, x, y) : 0;
        return (lsum % mod + rsum) % mod;
    }
}T;
void solve(){
int op, x, y, z;
cin >> op;
if(op == 1){
cin >> x >> y >> z, z %= mod;
while(top[idx[x]] != top[idx[y]]){
if(dep[top[idx[x]]] < dep[top[idx[y]]]) swap(x, y);
T.update(1, n, 1, dfn[top[idx[x]]], dfn[x], z);
x = fa[top[idx[x]]];
}
T.update(1, n, 1, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]), z);
}else if(op == 2){
int ans = 0;
cin >> x >> y;
while(top[idx[x]] != top[idx[y]]){
if(dep[top[idx[x]]] < dep[top[idx[y]]]) swap(x, y);
add(ans, T.query(1, n, 1, dfn[top[idx[x]]], dfn[x]));
x = fa[top[idx[x]]];
}
add(ans, T.query(1, n, 1, min(dfn[x], dfn[y]), max(dfn[x], dfn[y])));
cout << ans << endl;
}else if(op == 3){
cin >> x >> z, z %= mod;
T.update(1, n, 1, mn[x], mx[x], z);
}else{
cin >> x;
cout << T.query(1, n, 1, mn[x], mx[x]) << endl;
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> Q >> rt >> mod;
for(int i = 1; i <= n; i++) cin >> a[i], a[i] %= mod;
for(int i = 1, x, y; i < n; i++){
cin >> x >> y;
v[x].push_back(y), v[y].push_back(x);
}
dfs(rt, 0);
for(int i = 1; i <= n; i++) sort(v[i].begin(), v[i].end(), cmp);
tot = 0, dfs1(rt, 0);
for(int i = 1; i <= n; i++) b[dfn[i]] = a[i];
for(int i = 1; i <= n; i++) a[i] = b[i];
T.build(1, n, 1);
while(Q--) solve();
return 0;
}

浙公网安备 33010602011771号