树链剖分
学了好长时间没写题,忘了,所以写 blog 以记之。
树链剖分是把一棵树拆成若干条不相交的链,每条链内点的时间戳都是连续的,再用数据结构去维护这些链。
先了解一些术语
- 重儿子:一个点所有儿子中,子树最大的那个,有多个任选一个。
- 重边:连接一个点和它重儿子的边。
- 轻儿子:一个点除了重儿子外的所有点。
- 轻边:连接一个点与它轻儿子的边。
- 重链:一条由重边构成的链,特殊地,将落单的点也看作重链。所有的重链包括了树的所有点。
树剖的核心是求出重链,想要求重链要先求重儿子 s o n son son,想求重儿子要求每个点子树的大小 s z sz sz。为了让重链的节点编号连续,在打时间戳 d f n dfn dfn 的 dfs 中,要优先遍历重儿子。为了维护每条链的信息,还需要知道每条重链深度最小的节点 t o p top top。另外,还需要用到深度 d e p dep dep 和每个点的父亲 f a fa fa。
这些信息要分两次 dfs 求,dfs1 求 s z sz sz、 d e p dep dep、 s o n son son、 f a fa fa。dfs 2 求 t o p top top 和 d f n dfn dfn。
在找出重链后,对于一条路径的操作,目的是找出包含哪些重链。假设路径为 ( u , v ) (u,v) (u,v),要让 u u u 和 v v v 跳到一个重链上,所以比较 d e p t o p x dep_{top_x} deptopx 与 d e p t o p y dep_{top_y} deptopy,让深度大的跳到 t o p top top 的父亲,直到 t o p x = t o p y top_x = top_y topx=topy,将路上完整的重链,和两点跳到的同一重链用数据结构维护即可。
树剖的过程可以看作一个常数小的 log \log log。
#include <bits/stdc++.h>
using namespace std;
#define lson (x << 1)
#define rson (x << 1 | 1)
typedef long long ll;
const int N = 1e5 + 5;
inline ll read() {
ll X = 0, w = 0; char ch = 0;
while(!isdigit(ch)) {w |= ch == '-'; ch = getchar(); }
while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
return w ? -X : X;
}
int n, m, r, p;
struct edge{
int to, nxt;
}e[N << 1];
int head[N], tot;
void addedge(int x, int y){
e[++tot].to = y, e[tot].nxt = head[x], head[x] = tot;
}
int fa[N], sz[N], son[N], dfn[N], b[N], a[N], dep[N], top[N];
void dfs1(int x, int f){
dep[x] = dep[f] + 1; fa[x] = f; sz[x] = 1;
for(int i = head[x]; i; i = e[i].nxt){
int y = e[i].to;
if(e[i].to != f){
dfs1(y, x);
sz[x] += sz[y];
if(sz[y] > sz[son[x]]) son[x] = y;
}
}
}
void dfs2(int x, int topf){
dfn[x] = ++tot; b[tot] = x; top[x] = topf;
if(son[x]) {
dfs2(son[x], topf);
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to;
if(!dfn[y]) dfs2(y, y);
}
}
}
struct node{
int l, r;
ll val, add;
}t[N << 2];
void pushdown(int x){
if(t[x].add){
t[lson].val += t[x].add * (t[lson].r - t[lson].l + 1);
t[rson].val += t[x].add * (t[rson].r - t[rson].l + 1);
t[lson].add += t[x].add;
t[rson].add += t[x].add;
t[x].add = 0;
}
}
void pushup(int x){
t[x].val = (t[lson].val % p + t[rson].val % p) % p;
}
void build(int x, int l, int r){
t[x].l = l; t[x].r = r;
if(l == r){
t[x].val = a[b[l]];
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid); build(rson, mid + 1, r);
pushup(x);
}
void update(int x, int l, int r, int k){
if(l <= t[x].l && r >= t[x].r){
t[x].val += (t[x].r - t[x].l + 1) * k; t[x].add += k;
t[x].val %= p; t[x].add %= p;
return;
}
pushdown(x);
int mid = (t[x].r + t[x].l) >> 1;
if(l <= mid) update(lson, l, r, k);
if(r > mid) update(rson, l, r, k);
pushup(x);
}
ll query(int x, int l, int r){
ll ans = 0;
if(l <= t[x].l && r >= t[x].r) return t[x].val;
pushdown(x);
int mid = (t[x].r + t[x].l) >> 1;
if(l <= mid) ans += query(lson, l, r);
if(r > mid) ans += query(rson, l, r);
return ans % p;
}
void uptree(int x, int y, int k){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
update(1, dfn[x], dfn[y], k);
}
ll quetree(int x, int y){
ll ans = 0;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
ans += query(1, dfn[top[x]], dfn[x]);
ans %= p;
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
ans += query(1, dfn[x], dfn[y]);
return ans % p;
}
int main(){
n = read(), m = read(), r = read(), p = read();
for(int i = 1; i <= n; i++) a[i] = read();
for(int i = 1; i < n; i++){
int x = read(), y = read();
addedge(x, y);
addedge(y, x);
}
tot = 0;
dfs1(r, 0); dfs2(r, 0);
build(1, 1, n);
for(int i = 1; i <= m; i++){
int opt = read(), x, y, k;
if(opt == 1){
x = read(), y = read(), k = read();
uptree(x, y, k);
}
if(opt == 2){
x = read(), y = read();
printf("%lld\n", quetree(x, y));
}
if(opt == 3){
x = read(), k = read();
update(1, dfn[x], dfn[x] + sz[x] - 1, k);
}
if(opt == 4){
x = read();
printf("%lld\n", query(1, dfn[x], dfn[x] + sz[x] - 1));
}
}
return 0;
}
值得一提的是,树链剖分是在线求 lca 最快的算法。
int sz[N], fa[N], son[N], dep[N], top[N];
void dfs1(int x, int f) {
fa[x] = f, dep[x] = dep[f] + 1, sz[x] = 1;
for (int i = head[x]; i; i = e[i].nxt){
int y = e[i].to;
if (y != f){
dfs1(y, x); sz[x] += sz[y];
if (sz[y] > sz[son[x]]) son[x] = y;
}
}
}
void dfs2(int x, int rt) {
top[x] = rt;
if (son[x]) dfs2(son[x], rt);
for (int i = head[x]; i; i = e[i].nxt){
int y = e[i].to;
if (y != son[x] && y != fa[x]) dfs2(y, y);
}
}
int lca(int x, int y) {
for (; top[x] != top[y]; x = fa[top[x]]) if (dep[top[x]] < dep[top[y]]) swap(x, y);
return dep[x] < dep[y] ? x : y;
}

浙公网安备 33010602011771号