树链剖分

学了好长时间没写题,忘了,所以写 blog 以记之。

树链剖分是把一棵树拆成若干条不相交的链,每条链内点的时间戳都是连续的,再用数据结构去维护这些链。

先了解一些术语

  • 重儿子:一个点所有儿子中,子树最大的那个,有多个任选一个。
  • 重边:连接一个点和它重儿子的边。
  • 轻儿子:一个点除了重儿子外的所有点。
  • 轻边:连接一个点与它轻儿子的边。
  • 重链:一条由重边构成的链,特殊地,将落单的点也看作重链。所有的重链包括了树的所有点。

树剖的核心是求出重链,想要求重链要先求重儿子 s o n son son,想求重儿子要求每个点子树的大小 s z sz sz。为了让重链的节点编号连续,在打时间戳 d f n dfn dfn 的 dfs 中,要优先遍历重儿子。为了维护每条链的信息,还需要知道每条重链深度最小的节点 t o p top top。另外,还需要用到深度 d e p dep dep 和每个点的父亲 f a fa fa

这些信息要分两次 dfs 求,dfs1 求 s z sz sz d e p dep dep s o n son son f a fa fa。dfs 2 求 t o p top top d f n dfn dfn

在找出重链后,对于一条路径的操作,目的是找出包含哪些重链。假设路径为 ( u , v ) (u,v) (u,v),要让 u u u v v v 跳到一个重链上,所以比较 d e p t o p x dep_{top_x} deptopx d e p t o p y dep_{top_y} deptopy,让深度大的跳到 t o p top top 的父亲,直到 t o p x = t o p y top_x = top_y topx=topy,将路上完整的重链,和两点跳到的同一重链用数据结构维护即可。

树剖的过程可以看作一个常数小的 log ⁡ \log log

#include <bits/stdc++.h>
using namespace std; 
#define lson (x << 1)
#define rson (x << 1 | 1)
typedef long long ll;  
const int N = 1e5 + 5; 
inline ll read() {
    ll X = 0, w = 0;  char ch = 0; 
    while(!isdigit(ch)) {w |= ch == '-'; ch = getchar(); }
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar(); 
    return w ? -X : X; 
}
int n, m, r, p; 
struct edge{
    int to, nxt; 
}e[N << 1]; 
int head[N], tot; 
void addedge(int x, int y){
    e[++tot].to = y, e[tot].nxt = head[x], head[x] = tot; 
}
int fa[N], sz[N], son[N], dfn[N], b[N], a[N], dep[N], top[N]; 
void dfs1(int x, int f){
    dep[x] = dep[f] + 1; fa[x] = f; sz[x] = 1; 
    for(int i = head[x]; i; i = e[i].nxt){
    	int y = e[i].to; 
        if(e[i].to != f){
            dfs1(y, x); 
            sz[x] += sz[y]; 
            if(sz[y] > sz[son[x]]) son[x] = y; 
        }
    }
}
void dfs2(int x, int topf){
    dfn[x] = ++tot; b[tot] = x; top[x] = topf; 
    if(son[x]) {
		dfs2(son[x], topf); 
    	for(int i = head[x]; i; i = e[i].nxt) {
    		int y = e[i].to; 
    		if(!dfn[y]) dfs2(y, y); 
		}	 
	}
}
struct node{
    int l, r; 
    ll val, add; 
}t[N << 2]; 
void pushdown(int x){
    if(t[x].add){
        t[lson].val += t[x].add * (t[lson].r - t[lson].l + 1); 
        t[rson].val += t[x].add * (t[rson].r - t[rson].l + 1); 
        t[lson].add += t[x].add; 
        t[rson].add += t[x].add; 
        t[x].add = 0; 
    }
}
void pushup(int x){
    t[x].val = (t[lson].val % p + t[rson].val % p) % p; 
}
void build(int x, int l, int r){
    t[x].l = l;  t[x].r = r; 
    if(l == r){
        t[x].val = a[b[l]]; 
        return; 
    }
    int mid = (l + r) >> 1;
    build(lson, l, mid);  build(rson, mid + 1, r); 
    pushup(x); 
}
void update(int x, int l, int r, int k){
    if(l <= t[x].l && r >= t[x].r){
        t[x].val += (t[x].r - t[x].l + 1) * k; t[x].add += k; 
        t[x].val %= p; t[x].add %= p; 
        return; 
    }
    pushdown(x); 
    int mid = (t[x].r + t[x].l) >> 1;
    if(l <= mid) update(lson, l, r, k); 
    if(r > mid) update(rson, l, r, k); 
    pushup(x);  
}
ll query(int x, int l, int r){
    ll ans = 0; 
    if(l <= t[x].l && r >= t[x].r) return t[x].val;   
    pushdown(x); 
    int mid = (t[x].r + t[x].l) >> 1;
    if(l <= mid) ans += query(lson, l, r); 
    if(r > mid) ans += query(rson, l, r); 
    return ans % p; 
}
void uptree(int x, int y, int k){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y); 
        update(1, dfn[top[x]], dfn[x], k); 
        x = fa[top[x]]; 
    }
    if(dep[x] > dep[y]) swap(x, y); 
    update(1, dfn[x], dfn[y], k); 
}
ll quetree(int x, int y){
    ll ans = 0; 
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y); 
        ans += query(1, dfn[top[x]], dfn[x]); 
        ans %= p; 
        x = fa[top[x]]; 
    }
    if(dep[x] > dep[y]) swap(x, y); 
    ans += query(1, dfn[x], dfn[y]); 
    return ans % p; 
}
int main(){
    n = read(), m = read(), r = read(), p = read(); 
    for(int i = 1; i <= n; i++) a[i] = read(); 
    for(int i = 1; i < n; i++){
        int x = read(), y = read(); 
        addedge(x, y); 
        addedge(y, x); 
    }
    tot = 0;  
    dfs1(r, 0);  dfs2(r, 0); 
    build(1, 1, n); 
    for(int i = 1; i <= m; i++){
        int opt = read(), x, y, k; 
        if(opt == 1){
            x = read(), y = read(), k = read(); 
            uptree(x, y, k);         
        }
        if(opt == 2){
            x = read(), y = read();  
            printf("%lld\n", quetree(x, y)); 
        }
        if(opt == 3){
            x = read(), k = read(); 
            update(1, dfn[x], dfn[x] + sz[x] - 1, k); 
        }
        if(opt == 4){
            x = read(); 
            printf("%lld\n", query(1, dfn[x], dfn[x] + sz[x] - 1)); 
        }
    }
    return 0; 
}

值得一提的是,树链剖分是在线求 lca 最快的算法。

int sz[N], fa[N], son[N], dep[N], top[N];
void dfs1(int x, int f) {
    fa[x] = f, dep[x] = dep[f] + 1, sz[x] = 1;
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != f){
            dfs1(y, x); sz[x] += sz[y];
            if (sz[y] > sz[son[x]]) son[x] = y;
        }
    }
}
void dfs2(int x, int rt) {
    top[x] = rt;
    if (son[x]) dfs2(son[x], rt);
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != son[x] && y != fa[x]) dfs2(y, y);
    }
}
int lca(int x, int y) {
    for (; top[x] != top[y]; x = fa[top[x]]) if (dep[top[x]] < dep[top[y]]) swap(x, y);
    return dep[x] < dep[y] ? x : y;
}
posted @ 2020-07-31 17:46  ylxmf2005  阅读(29)  评论(0)    收藏  举报