树链剖分 从入门到入门
本文在线更新中。
概念
这里头有很多定义,理解了就记住了。
- 重儿子
一个节点所有儿子里子树节点最多的节点(即子树最大),每个节点只有一个重儿子
- 轻儿子
除了重儿子都叫轻儿子
- 重边
节点和自己重儿子的连边
- 轻边
节点和自己轻儿子的连边
- 重链
由多条连续重边连成的链,每个节点都有一个所属的重链
- 轻链
由多条连续轻边连成的链
因此我们相当于把这个树给剖成了一堆链,下面用一些经典问题引入树链剖分思想。
例题一 : lca问题
这里有一个非常基础的复杂度证明。
证:树上任意两点之间的路径都由不超过 \(O(\log n)\) 条重链组成。
首先我们发现树上的边不是重边就是轻边,那么如果我们证明了任意两点间的路径都有不超过 \(O(\log n)\) 条轻边,就会把剩下的重边自然而然的分成 \(O(\log n)\) 条重链,就能得证这个命题。
考虑自下而上遍历的过程。我们的节点往上走的时候,每遇到一个轻边,就一定存在一条和它同深度的重边。而重儿子的子树大小一定是至少等于轻儿子的子树大小的。这就导致每遇到一条轻边,size至少 *= 2。
一共 \(n\) 个节点,显然最多能有 \(\log n\) 条轻边,于是命题得证。
所以,我们可以通过同时往上跳重链的方式来找 lca。
与倍增法相似的,我们看两个节点哪个深度深就让哪个节点往上跳,直到两个节点处于同一条重链,输出深度小的节点。
这个问题主要是让你了解一下具体的树链剖分编码中要预处理的东西。
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 5e5 + 7;
int x , y , n , m , root , cnt , head[N] , dep[N] , siz[N] , top[N] , son[N] , father[N];
struct edge {
int t , n;
}e[N << 1];
inline void add(int a , int b) {
e[++cnt].t = b , e[cnt].n = head[a] , head[a] = cnt;
}
void dfs(int s , int fa) {
siz[s] = 1;//子树大小
dep[s] = dep[fa] + 1;//深度
father[s] = fa;//父节点
for(register int i = head[s]; i; i = e[i].n) {
int to = e[i].t;
if(to == fa) {
continue;
}
dfs(to , s);
if(siz[son[s]] < siz[to]) {
son[s] = to;//重儿子
}
siz[s] += siz[to];
}
}
void Dfs(int s , int belong) {
top[s] = belong;//每个节点所在重链
if(!son[s]) {
return;//已经没有儿子就返回
}
Dfs(son[s] , belong);//首先处理重儿子,他的重链可以跟当前节点的连上
for(register int i = head[s]; i ; i = e[i].n) {
int to = e[i].t;
if(to == father[s] || to == son[s]) {
continue;
}
Dfs(to , to);//再处理轻儿子
}
}
inline int lca(int u , int v) {
while(top[u] != top[v]) {//往上跳
if(dep[top[u]] < dep[top[v]]) {
swap(u , v);
}
u = father[top[u]];//这里注意链头也包括在链里,所以要跳到链头的父节点
}
return dep[u] < dep[v] ? u : v;
}
int main() {
ios :: sync_with_stdio(0) , cin.tie(0) , cout.tie(0);
cin >> n >> m >> root;
for(register int i = 1; i < n; ++i) {
cin >> x >> y;
add(x , y) , add(y , x);
}
dfs(root , 0); //处理基础信息,siz,dep,father,son
Dfs(root , root);// 处理top,即每个节点所在重链的链头
while(m--) {
cin >> x >> y;
cout << lca(x , y) << '\n';
}
return 0;
}
复杂度 \(O(n + m \log n)\)
例题二: 有关子树与路径的操作
这个是很典的东西。我们用树链剖分与 dfn 来构建线段树,实现快速的修改与查询。
通过在第二次的 dfs 中设置 dfn,由于每次先遍历重儿子,我们可以实现每个节点的和它的重儿子的 dfn 都是连续的。
那么操作一就变成了每次从 \(dfn_{top_x}\) 修改到 \(dfn_x\)。
操作二和操作一只有修改和查询的区别。
操作三则是利用节点子树的 dfn 和节点的 dfn 是连续的特性,直接修改 \(dfn_x\) 到 \(dfn_{x + siz_x - 1}\)。
操作四和操作三也只有修改和查询的区别。
#include <bits/stdc++.h>
#define int long long
constexpr int N = 1e5 + 7;
using namespace std;
int tot , opt , cnt , x , y , z , n , m , root , mod , v[N] , a[N] , head[N] , dep[N] , top[N] , siz[N] , son[N] , dfn[N] , father[N];
struct edge {
int t , n;
}e[N << 1];
// --------------------------------------------------------------------------------------- 这一大坨都是线段树
struct Segment_Tree {
int sum[N << 2] , tag[N << 2];
inline void pushup(int i) {
sum[i] = (sum[i << 1] + sum[i << 1 | 1]) % mod;
}
inline void add_tag(int i , int l , int r , int k) {
tag[i] += k;
sum[i] = (sum[i] + (r - l + 1) * k) % mod;
}
inline void pushdown(int i , int l , int r) {
if(tag[i]) {
int mid = l + r >> 1;
add_tag(i << 1 , l , mid , tag[i]);
add_tag(i << 1 | 1 , mid + 1 , r , tag[i]);
tag[i] = 0;
}
}
inline void build(int i , int l , int r) {
if(l == r) {
sum[i] = a[l] % mod; return;
}
int mid = l + r >> 1;
build(i << 1 , l , mid);
build(i << 1 | 1 , mid + 1 , r);
pushup(i);
}
inline void modify(int i , int l , int r , int L , int R , int k) {
if(L <= l && r <= R) {
sum[i] = (sum[i] + k * (r - l + 1)) % mod;
tag[i] += k;
return;
}
pushdown(i , l , r);
int mid = l + r >> 1;
if(L <= mid) {
modify(i << 1 , l , mid , L , R , k);
}
if(R > mid) {
modify(i << 1 | 1 , mid + 1 , r , L , R , k);
}
pushup(i);
}
inline int query(int i , int l , int r , int L , int R) {
if(L <= l && r <= R) {
return sum[i];
}
pushdown(i , l , r);
int mid = l + r >> 1 , res = 0;
if(L <= mid) {
res = query(i << 1 , l , mid , L , R) % mod;
}
if(R > mid) {
res = (res + query(i << 1 | 1 , mid + 1 , r , L , R)) % mod;
}
return res;
}
}T;
//-----------------------------------------------------------------------------------------------------
inline void add(int x , int y) {
e[++cnt].t = y , e[cnt].n = head[x] , head[x] = cnt;
}
// ----------------------------------------------------------------------------------------------------
void dfs(int s , int fa) {
father[s] = fa;
dep[s] = dep[fa] + 1;
siz[s] = 1;
for(register int i = head[s]; i; i = e[i].n) {
if(e[i].t != fa) {
dfs(e[i].t , s);
siz[s] += siz[e[i].t];
if(siz[e[i].t] > siz[son[s]]) {
son[s] = e[i].t;
}
}
}
}
void Dfs(int s , int belong) {
top[s] = belong;
dfn[s] = ++tot;//其实跟刚才是一样的,只不过多了一个记录dfn
a[tot] = v[s];
if(!son[s]) {
return;
}
Dfs(son[s] , belong);
for(register int i = head[s]; i; i = e[i].n) {
if(e[i].t != father[s] && e[i].t != son[s]) {
Dfs(e[i].t , e[i].t);
}
}
}
//----------------------------------------------------------------------------------------------------
signed main() {
ios :: sync_with_stdio(0) , cin.tie(0) , cout.tie(0);
cin >> n >> m >> root >> mod;
for(register int i = 1; i <= n; ++i) {
cin >> v[i];
}
for(register int i = 1; i < n; ++i) {
cin >> x >> y;
add(x , y); add(y , x);
}
dfs(root , 0);
Dfs(root , root);
T.build(1 , 1 , n);
while(m--) {
cin >> opt;
switch(opt) {
case 1 : {
cin >> x >> y >> z;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) {
swap(x , y);
}
T.modify(1 , 1 , n , dfn[top[x]] , dfn[x] , z);
x = father[top[x]];
}
if(dep[x] > dep[y]) {
swap(x , y);
}
T.modify(1 , 1 , n , dfn[x] , dfn[y] , z);
break;
}
case 2 : {
cin >> x >> y;
int res = 0;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) {
swap(x , y);
}
res = (res + T.query(1 , 1 , n , dfn[top[x]] , dfn[x])) % mod;
x = father[top[x]];
}
if(dep[x] > dep[y]) {
swap(x , y);
}
res = (res + T.query(1 , 1 , n , dfn[x] , dfn[y])) % mod;
cout << res << '\n';
break;
}
case 3 : {
cin >> x >> z;
T.modify(1 , 1 , n , dfn[x] , dfn[x] + siz[x] - 1 , z);
break;
}
case 4 : {
cin >> x;
cout << (T.query(1 , 1 , n , dfn[x] , dfn[x] + siz[x] - 1)) % mod << '\n';
break;
}
}
}
return 0;
}
复杂度 \(O(m\log n\log n)\)
例题三:关于边权的转化
剖分本身是针对点权的。如果题目给我们的是边权,由于一个节点只有一个父亲,我们可以把值赋在子节点上。
只放第一道题的代码,另一个没啥区别。
#include <bits/stdc++.h>
#define m_p(x, y) make_pair(x, y)
#define pb(x) push_back(x)
#define int long long
#define ull unsigned long long
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define rep_(i, a, b) for(int i = a; i >= b; --i)
using namespace std;
namespace FYH {
struct Edge {
int to, weight, nxt;
};
int n, u, v, w, cnt, tot, c, d;
vector<pair<int, int>> record;
vector<int> father, head, son, siz, top, dep, a, dfn, rnk;
vector<Edge> e;
inline void add(int x, int y, int z) {
e[++cnt].to = y, e[cnt].weight = z, e[cnt].nxt = head[x];
head[x] = cnt;
}
void dfs(int s, int fa) {
dep[s] = dep[fa] + 1;
siz[s] = 1;
father[s] = fa;
for(register int i = head[s]; i; i = e[i].nxt) {
if(e[i].to == fa) {
continue;
}
a[e[i].to] = e[i].weight;
dfs(e[i].to, s);
siz[s] += siz[e[i].to];
if(siz[e[i].to] > siz[son[s]]) {
son[s] = e[i].to;
}
}
}
void Dfs(int s, int belong) {
dfn[s] = ++tot;
rnk[tot] = s;
top[s] = belong;
if(!son[s]) {return;}
Dfs(son[s], belong);
for(register int i = head[s]; i; i = e[i].nxt) {
if(e[i].to == father[s] || e[i].to == son[s]) {
continue;
}
Dfs(e[i].to, e[i].to);
}
}
struct Seg {
vector<int> Max;
#define ls (i << 1)
#define rs (i << 1 | 1)
inline void push_up(int i) {
Max[i] = max(Max[ls], Max[rs]);
}
void build(int i, int l, int r) {
if(l == r) {
Max[i] = a[rnk[l]];
return;
}
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
push_up(i);
}
void modify(int i, int l, int r, int pos, int k) {
if(l == r) {
Max[i] = k;
return;
}
int mid = (l + r) >> 1;
if(pos <= mid) {
modify(ls, l, mid, pos, k);
}
else {
modify(rs, mid + 1, r, pos, k);
}
push_up(i);
}
int query(int i, int l, int r, int L, int R) {
if(L <= l && r <= R) {
return Max[i];
}
int res = 0, mid = (l + r) >> 1;
if(mid >= L) {
res = max(res, query(ls, l, mid, L, R));
}
if(mid < R) {
res = max(res, query(rs, mid + 1, r, L, R));
}
return res;
}
}T;
inline int check(int l, int r) {
if(l == r) {
return 0;
}
int res = 0;
while(top[l] != top[r]) {
if(dep[top[l]] < dep[top[r]]) {
swap(l, r);
}
res = max(res, T.query(1, 1, n, dfn[top[l]], dfn[l]));
l = father[top[l]];
}
if(l == r) {
return res;
}
if(dep[l] > dep[r]) {
swap(l, r);
}
res = max(res, T.query(1, 1, n, dfn[son[l]], dfn[r]));
return res;
}
void main() {
cin >> n;
a.assign(n + 1, 0);
record.resize(n + 1);
dfn.resize(n + 1);
top.resize(n + 1);
siz.resize(n + 1);
rnk.resize(n + 1);
dep.assign(n + 1, 0);
son.assign(n + 1, 0);
T.Max.assign((n << 2) + 1, 0);
e.resize((n << 1) + 1);
father.resize(n + 1);
head.resize(n + 1);
rep(i, 1, n - 1) {
cin >> u >> v >> w;
add(u, v, w);
add(v, u, w);
record[i] = m_p(u, v);
}
dfs(1, 0);
Dfs(1, 1);
T.build(1, 1, n);
string opt;
while(cin >> opt) {
if(opt == "DONE") {break;}
cin >> c >> d;
if(opt == "CHANGE") {
int u = record[c].first, v = record[c].second;
if(dep[u] > dep[v]) {
swap(u, v);
}
T.modify(1, 1, n, dfn[v], d);
}
else {
cout << check(c, d) << '\n';
}
}
}
}
signed main() {
ios :: sync_with_stdio(0), cin.tie(0);
int t = 1;
//cin >> t;
while(t--) {
FYH :: main();
}
return 0;
}
一个需要注意的问题就是,边转点后我们跳重链的时候最后查询区间应该是 \([dfn_{x + 1}, dfn_y]\),其中 \(dep_x < dep_y\) 且二节点同属于一重链。因为 \(x\) 上方的边不属于查询范围。

浙公网安备 33010602011771号