树链剖分
前言
树剖代码往往很长,但是逻辑并不复杂,理清每一步在干什么就可以。
基础概念
重儿子与轻儿子
重儿子就是树上一个节点的儿子中,子树节点数最多的那一个儿子(多个任选一个),其余的都是轻儿子。
重边与轻边
由一个节点连向它重儿子的边就是重边,另外的边就是轻边。
重链
重链是首尾相接的重边组成的链(也可以是不是任何一条重边端点的一个单独的点)。
重链的数量等于叶子节点的数量,感性理解一下,从上往下出现的重链都会在叶子节点才会结束,并且不会共用也不会遗漏。
重链剖分(树链剖分)
由于每个点都会在一条重链上,我们就可以把树看成由很多条重链和一些轻边组成,而我们可以把在树上路径上点的问题变成一段一段的重链上的问题。
初始化
为了方便对重链上的一段进行处理,我们可以把一条重链放到线段树上。
一种方法是对每一条重链都建一颗线段树,但是有些麻烦。
所以我们可以把所有节点放到同一颗线段树上,只要保证同一条重链上的节点是一个连续区间就可以了。
我们通过 dfs 序建树,如果我们每次先遍历重儿子就可以保证一条重链是连续的一段。
所以我们通过两个 dfs,一个用于求重儿子,一个求 dfs 序。
求好后我们就可以愉快的建树了。
路径修改/查询
对于从 \(x\) 到 \(y\) 的路径上的修改,我们可以一边往上跳,找他们的 LCA,一边对经过的重链修改。
我们可以记录下每一个点 \(i\) 所在重链上深度最小的节点 \(anc_i\),往上跳时从当前节点到此节点做一次区间修改。
\(x\) 和 \(y\) 一起往上跳,每次跳深度 \(d\) 大的,保证跳的过程中两者都还未到 LCA,如果 \(anc_x=anc_y\),那么最后跳一次,使 \(x=y\),结束。
子树修改/查询
因为在 dfs 的过程中,只有遍历完一棵子树才会离开这颗子树,所以一棵子树内的所有节点的 dfs 序也是连续的。
记录下每颗子树内 dfs 序最大是多少,最小肯定是子树根的 dfs 序。
总结
树剖的每一步都很好理解也很好实现,但步骤繁多,对于刚开始写长代码的同学不太友好,需要练习。
以下是 P3384 【模板】重链剖分/树链剖分 的代码。
#include <bits/stdc++.h>
#define int long long
#define debug() cout << "--,--" << endl
using namespace std;
const int N = 100010;
int n, m, root, mod, timestamp;
int a[N];
vector <int> v[N];
int son[N], siz[N], dfn[N], maxn[N], anc[N], anti[N], d[N], father[N];
void dfs1(int u, int fa)
{
father[u] = fa;
siz[u] = 1;
int maxn = 0;
for (int j : v[u])
{
if (j == fa) continue;
d[j] = d[u] + 1;
dfs1 (j, u);
siz[u] += siz[j];
if (siz[j] > siz[son[u]]) son[u] = j;
}
}
void dfs2(int u, int fa)
{
if (u != son[fa]) anc[u] = u;
else anc[u] = anc[fa];
maxn[u] = dfn[u] = ++timestamp;
anti[timestamp] = u;
if (son[u])
dfs2 (son[u], u), maxn[u] = max (maxn[u], maxn[son[u]]);
for (int j : v[u])
{
if (j == fa || j == son[u]) continue;
dfs2 (j, u);
maxn[u] = max (maxn[u], maxn[j]);
}
}
struct node
{
int l, r, val, tag;
} tr[N << 2];
int len(int u)
{
return tr[u].r - tr[u].l + 1;
}
void pushup(int u)
{
tr[u].val = (tr[u << 1].val + tr[u << 1 | 1].val) % mod;
}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if (l == r)
{
tr[u].val = a[anti[l]];
return;
}
int mid = l + r >> 1;
build (u << 1, l, mid);
build (u << 1 | 1, mid + 1, r);
pushup (u);
}
void pushdown(int u)
{
int& tag = tr[u].tag;
if (!tag) return;
tr[u << 1].tag = (tr[u << 1].tag + tag) % mod;
tr[u << 1 | 1].tag = (tr[u << 1 | 1].tag + tag) % mod;
tr[u << 1].val = (tr[u << 1].val + len (u << 1) * tag % mod) % mod;
tr[u << 1 | 1].val = (tr[u << 1 | 1].val + len (u << 1 | 1) * tag % mod) % mod;
tag = 0;
}
void add(int u, int l, int r, int x)
{
if (l <= tr[u].l && tr[u].r <= r)
{
tr[u].tag = (tr[u].tag + x + mod) % mod;
tr[u].val = (tr[u].val + len (u) * x % mod + mod) % mod;
return;
}
pushdown (u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) add (u << 1, l, r, x);
if (r > mid) add (u << 1 | 1, l, r, x);
pushup (u);
}
int find(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r)
return tr[u].val;
pushdown (u);
int mid = tr[u].l + tr[u].r >> 1;
int ans = 0;
if (l <= mid) ans += find (u << 1, l, r);
if (r > mid) ans += find (u << 1 | 1, l, r);
return ans % mod;
}
signed main()
{
cin >> n >> m >> root >> mod;
for (int i = 1; i <= n; i++)
cin >> a[i], a[i] %= mod;
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
v[x].push_back (y);
v[y].push_back (x);
}
dfs1 (root, 0);
dfs2 (root, 0);
build (1, 1, n);
for (int i = 1; i <= m; i++)
{
int op, x, y, z;
cin >> op;
if (op == 1)
{
cin >> x >> y >> z;
z %= mod;
while (anc[x] != anc[y])
{
if (d[anc[x]] > d[anc[y]])
add (1, dfn[anc[x]], dfn[x], z), x = father[anc[x]];
else add (1, dfn[anc[y]], dfn[y], z), y = father[anc[y]];
}
add (1, min (dfn[x], dfn[y]), max (dfn[x], dfn[y]), z);
}
else if (op == 2)
{
cin >> x >> y;
int ans = 0;
while (anc[x] != anc[y])
{
if (d[anc[x]] > d[anc[y]])
ans = (ans + find (1, dfn[anc[x]], dfn[x])) % mod, x = father[anc[x]];
else ans = (ans + find (1, dfn[anc[y]], dfn[y])) % mod, y = father[anc[y]];
}
ans = (ans + find (1, min (dfn[x], dfn[y]), max (dfn[x], dfn[y]))) % mod;
cout << ans << endl;
}
else if (op == 3)
{
cin >> x >> z;
add (1, dfn[x], maxn[x], z);
}
else
{
cin >> x;
cout << find (1, dfn[x], maxn[x]) << endl;
}
}
return 0;
}