树链剖分
树链剖分
树链剖分是一种将树结构化为线性结构的方法
配合线段树区间修改和区间查询的功能,实现树上的\(O(logn)\)级别的修改与查询
其实现方式是将树划分为若干条链,称为重链
而具体如何剖分呢?接下来要引入几个定义:
重儿子:对于一个父节点的若干个子节点,其中子树节点最多的子节点为重儿子
轻儿子:除重儿子以外的节点
重边:父节点和重儿子连成的边
重链:由多条重边连成的链
我们可以得出以下结论:
- 整棵树会被剖分成若干条重链
- 每条重链的顶端一定是轻儿子 (单独一个轻儿子也是一条重链,是特殊的重链)
- 任意一条路径被切分成不超过\(logn\)条重链,也是树链剖分能将时间复杂度控制在\(O(logn)\)的核心原因
我们按照上述原理,将树剖分为若干条重链存储进一个数组中,再用线段树维护区间的特性,将树上修改查询转化为区间的修改和查询;那么我们如何实现这个过程呢?
\(First\)
我们需要先开几个数组:
(这里推荐用全局变量,否则要么传递超多参数,要么用\(lambda\)表达式写函数,都非常麻烦;一般的出题人要考察树剖,大抵也不会是多实例测试,那样就太恶心了)
/*
g[u]: 和u相连的所有节点(存图)
w[u]: u的权值
fa[u]: u的父节点
depth[u]: u的深度
son[u]: u的重儿子
size[u]: 以u为根的子树节点数
top[u]: u所在重链的顶点
idx[u]: u剖分后的新编号
New[u]: 新编号在树中对应节点的权值,作为线段树维护的序列
*/
std::vector<int> g[N], w(N);
std::vector<int> fa(N), depth(N), size(N), son(N);
std::vector<int> top(N), idx(N), New(N);
int cnt = 0;
\(Second\)
接下来,我们要预处理一下这些数组,通过两次\(dfs\)实现
第一次\(dfs\),我们传递两个参数,一个是当前节点\(now\),一个是当前节点的父节点\(last\)
在每一次搜索时,更新\(fa、depth、size、son\)数组
其中,\(fa\)数组和\(depth\)数组都很好办,显然\(fa[now] = last, depth[now] = depth[last] + 1\)
而\(size\)数组和\(son\)数组相对麻烦一些,我们只能在计算完,也就是递归返回时才能正确更新
我们定义\(next\)为\(now\)节点的子节点,那么\(size[now] = ∑size[next]\)
那么\(son[now]\),也就是具有最大\(size[next]\)的子节点了,代码如下:
/* 预处理fa、depth、size、son数组 */
void dfs1 (int now, int last) {
fa[now] = last; // 存储当前节点的父亲节点
depth[now] = depth[last] + 1; // 更新当前节点的深度,为父亲节点的深度 + 1
size[now] = 1; // 先初始化当前节点的子树节点数为1,后面再进行更新
/* 接下来处理now节点的子节点 */
for (auto next : g[now]) {
if (next == last) continue; // 相当于判断next是否被访问过,访问过就直接跳过了
dfs1(next, now);
size[now] += size[next]; // 递归返回时再更新子树节点数,这也是不能用bfs替代的核心原因
if (size[son[now]] < size[next]) son[now] = next; // 更新重儿子,递归返回时才可以更新
}
}
下面需要处理\(top、idx、New\)数组,也用\(dfs\)实现,但传递的参数和上面不一样
这次要传递两个参数:当前节点\(now\),当前节点所在重链的顶端的轻儿子\(top\:dot\)
那么很显然,每次搜索时,\(top[now] = top\:dot\)
然后我们需要将这些点按顺序在\(idx\)数组中标记,实现方法是开一个变量\(cnt\),然后\(idx[now] =\) ++\(cnt\)
再按照新的编号,在\(New\)数组中对应的给带上权值,也就是\(New[cnt] = w[now]\)
在第一个\(dfs\)中,我们已经正确找到了每个结点的重儿子\(son[now]\)
在第二次\(dfs\)中,我们就要借助这个,先搜索重儿子,再搜索轻儿子,以保证\(dfs\)序的正确
/* 预处理top、idx、New数组 */
void dfs2 (int now, int top_dot) {
top[now] = top_dot; // 存储当前节点的顶端,也就是当前节点所在重链的顶端
idx[now] = ++cnt; // 当前节点剖分后的新编号
New[cnt] = w[now]; // 新编号对应的原来的权值
/* 先处理重儿子,保证dfs序 */
if (son[now]) dfs2(son[now], top_dot);
/* 再处理轻儿子 */
for (auto next : g[now]) {
if (next == fa[now] || next == son[now]) continue; // 这里也是判断next是否被访问过
dfs2(next, next);
}
}
至此预处理过程全部结束,接下来可以开始区间维护的部分了
\(Third\)
区间维护用到的线段树,是针对\(New\)数组建立的区间和线段树
区间的更新和查询也都是在\(New\)数组上进行,我们先进行建树过程,代码如下:
(这里的线段树就是最最基础的区间和线段树的板子,直接搬过来就可以用;但不会线段树还请先缓一缓,补完区间和线段树的知识再来学树剖吧)
/* 区间和线段树模板 */
struct SegTree {
#define lc u << 1
#define rc u << 1 | 1
struct node {
int l, r;
i64 val, tag;
} tree[4 * N];
void push_up (int u) {
tree[u].val = tree[lc].val + tree[rc].val;
}
void push_down (int u) {
if (tree[u].tag) {
// Left:
tree[lc].val += (tree[lc].r - tree[lc].l + 1) * tree[u].tag;
tree[lc].tag += tree[u].tag;
// Right:
tree[rc].val += (tree[rc].r - tree[rc].l + 1) * tree[u].tag;
tree[rc].tag += tree[u].tag;
// Final
tree[u].tag = 0;
}
}
void build (int u, int l, int r) {
tree[u].l = l;
tree[u].r = r;
tree[u].tag = 0;
if (l == r) {
tree[u].val = New[l]; // 对New数组进行建树
return;
}
int mid = tree[u].l + tree[u].r >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
push_up(u);
}
void update (int u, int x, int y, i64 k) {
if (x <= tree[u].l && tree[u].r <= y) {
tree[u].val += (tree[u].r - tree[u].l + 1) * k;
tree[u].tag += k;
return;
}
push_down(u);
int mid = tree[u].l + tree[u].r >> 1;
if (x <= mid) update(lc, x, y, k);
if (y > mid) update(rc, x, y, k);
push_up(u);
}
i64 query (int u, int x, int y) {
if (x <= tree[u].l && tree[u].r <= y) {
return tree[u].val;
}
push_down(u);
int mid = tree[u].l + tree[u].r >> 1;
i64 ans = 0;
if (x <= mid) ans += query(lc, x, y);
if (y > mid) ans += query(rc, x, y);
return ans;
}
} T;
\(Fourth\)
完成了建树过程,接下来我们开始写区间修改和查询的代码
我们回忆一下刚才的剖分过程:\(size[now]\)代表了\(now\)节点的子树大小,而我们也是按照\(dfs\)序存进了\(New\)数组中
因此,如果要修改 / 查询某个节点\(u\)子树的所有节点之和,
也就是利用线段树在\(New\)序列上修改 / 查询 \([idx[u],idx[u] + size[u] - 1]\) 这段区间!
那么,修改 / 查询子树的代码就非常简单了,如下所示:
/*
这里更新u的子树,所以点u对应的idx[u]就是根
而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/
void update_tree (int u, i64 k) {
T.update(1, idx[u], idx[u] + size[u] - 1, k);
}
i64 query_tree (int u) {
return T.query(1, idx[u], idx[u] + size[u] - 1);
}
但是还有另一种修改 / 查询,也是更常见的修改 / 查询方式
就是给定两个点\(u\)和\(v\),修改点\(u\)到点\(v\)的简单路径上的所有点的权值,或是查询点\(u\)到点\(v\)的简单路径上的所有点的权值之和
这时我们就要利用先前剖分的重链的特性了:
首先我们要判断一下,\(top[u]\)和\(top[v]\)是否相等
如果他们不相等,说明当前\(u\)和\(v\)所在的两条重链不相交,而此时,我们要分别处理这两条重链
处理完之后,再让点\(u\)和点\(v\)分别跳到他们所在重链的\(top\)节点的父节点,也就是\(fa[top[u]]\)和\(fa[top[v]]\)
但是为了防止重复的问题,我们先对点\(u\)和点\(v\)进行判断,谁深度大,处理谁
也就是 if (depth[top[u]] < depth[top[v]]) swap(u, v) 保证\(u\)是更深的节点
然后去处理\(u\)所在的重链,再让\(u\)跳到\(fa[top[u]]\)
当执行完所有的\(top[u] ≠ top[v]\)的情况后,也就是\(u\)和\(v\)所在重链顶端是同一个节点
此时到了最后一步,我们不需要再分开处理了,直接一步处理完,代码如下:
void update_path (int u, int v, i64 k) {
/*
当top[u] != top[v],说明此时的顶端不是最近公共祖先
我们进行while循环的目的是在找最近公共祖先
*/
while (top[u] != top[v]) {
if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
T.update(1, idx[top[u]], idx[u], k); // 更新u节点所在的重链
u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的更新
}
/* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步更新操作 */
if (depth[u] < depth[v]) std::swap(u, v);
T.update(1, idx[v], idx[u], k);
}
i64 query_path (int u, int v) {
i64 ans = 0;
/*
当top[u] != top[v],说明此时的顶端不是最近公共祖先
我们进行while循环的目的是在找最近公共祖先
*/
while (top[u] != top[v]) {
if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
ans += T.query(1, idx[top[u]], idx[u]); // 更新u节点所在的重链
u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的查询
}
/* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步查询操作 */
if (depth[u] < depth[v]) std::swap(u, v);
ans += T.query(1, idx[v], idx[u]);
return ans;
}
\(Final\)
最终代码如下:(可通过洛谷P3384,https://www.luogu.com.cn/problem/P3384)
#include <bits/stdc++.h>
#include <bits/extc++.h>
using i64 = int64_t;
using pii = std::pair<int, int>;
using namespace __gnu_pbds;
using ordered_set = tree<i64, null_type, std::less<i64>, rb_tree_tag, tree_order_statistics_node_update>;
using ordered_multiset = tree<i64, null_type, std::less_equal<i64>, rb_tree_tag, tree_order_statistics_node_update>;
constexpr int N = 1e5 + 10;
constexpr int inf = INT_MAX;
constexpr i64 INF = LLONG_MAX;
// constexpr int mod = 998244353;
constexpr int dx[] = {+0, -1, +0, +1, -1, +1, +1, -1};
constexpr int dy[] = {+1, +0, -1, +0, +1, +1, -1, -1};
/*
w[u]: u的权值
fa[u]: u的父节点
depth[u]: u的深度
son[u]: u的重儿子
size[u]: 以u为根的子树节点数
top[u]: u所在重链的顶点
idx[u]: u剖分后的新编号
New[u]: 新编号在树中对应节点的权值,作为线段树维护的序列
*/
std::vector<int> g[N], w(N);
std::vector<int> fa(N), depth(N), size(N), son(N);
std::vector<int> top(N), idx(N), New(N);
int cnt = 0;
/* 区间和线段树模板 */
struct SegTree {
#define lc u << 1
#define rc u << 1 | 1
struct node {
int l, r;
i64 val, tag;
} tree[4 * N];
void push_up (int u) {
tree[u].val = tree[lc].val + tree[rc].val;
}
void push_down (int u) {
if (tree[u].tag) {
// Left:
tree[lc].val += (tree[lc].r - tree[lc].l + 1) * tree[u].tag;
tree[lc].tag += tree[u].tag;
// Right:
tree[rc].val += (tree[rc].r - tree[rc].l + 1) * tree[u].tag;
tree[rc].tag += tree[u].tag;
// Final
tree[u].tag = 0;
}
}
void build (int u, int l, int r) {
tree[u].l = l;
tree[u].r = r;
tree[u].tag = 0;
if (l == r) {
tree[u].val = New[l];
return;
}
int mid = tree[u].l + tree[u].r >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
push_up(u);
}
void update (int u, int x, int y, i64 k) {
if (x <= tree[u].l && tree[u].r <= y) {
tree[u].val += (tree[u].r - tree[u].l + 1) * k;
tree[u].tag += k;
return;
}
push_down(u);
int mid = tree[u].l + tree[u].r >> 1;
if (x <= mid) update(lc, x, y, k);
if (y > mid) update(rc, x, y, k);
push_up(u);
}
i64 query (int u, int x, int y) {
if (x <= tree[u].l && tree[u].r <= y) {
return tree[u].val;
}
push_down(u);
int mid = tree[u].l + tree[u].r >> 1;
i64 ans = 0;
if (x <= mid) ans += query(lc, x, y);
if (y > mid) ans += query(rc, x, y);
return ans;
}
} T;
/* 预处理fa、depth、size、son数组 */
void dfs1 (int now, int last) {
fa[now] = last; // 存储当前节点的父亲节点
depth[now] = depth[last] + 1; // 更新当前节点的深度,为父亲节点的深度 + 1
size[now] = 1; // 先初始化当前节点的子树节点数为1,后面再进行更新
/* 接下来处理now节点的子节点 */
for (auto next : g[now]) {
if (next == last) continue; // 相当于判断next是否被访问过,访问过就直接跳过了
dfs1(next, now);
size[now] += size[next]; // 递归返回时再更新子树节点数,这也是不能用bfs替代的核心原因
if (size[son[now]] < size[next]) son[now] = next; // 更新重儿子,递归返回时才可以更新
}
}
/* 预处理top、idx、New数组 */
void dfs2 (int now, int top_dot) {
top[now] = top_dot; // 存储当前节点的顶端,也就是当前节点所在重链的顶端
idx[now] = ++cnt; // 当前节点剖分后的新编号
New[cnt] = w[now]; // 新编号对应的原来的权值
/* 先处理重儿子,保证dfs序 */
if (son[now]) dfs2(son[now], top_dot);
/* 再处理轻儿子 */
for (auto next : g[now]) {
if (next == fa[now] || next == son[now]) continue; // 这里也是判断next是否被访问过
dfs2(next, next);
}
}
void update_path (int u, int v, i64 k) {
/*
当top[u] != top[v],说明此时的顶端不是最近公共祖先
我们进行while循环的目的是在找最近公共祖先
*/
while (top[u] != top[v]) {
if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
T.update(1, idx[top[u]], idx[u], k); // 更新u节点所在的重链
u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的更新
}
/* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步更新操作 */
if (depth[u] < depth[v]) std::swap(u, v);
T.update(1, idx[v], idx[u], k);
}
void update_tree (int u, i64 k) {
/*
这里更新u的子树,所以点u对应的idx[u]就是根
而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/
T.update(1, idx[u], idx[u] + size[u] - 1, k);
}
i64 query_path (int u, int v) {
i64 ans = 0;
/*
当top[u] != top[v],说明此时的顶端不是最近公共祖先
我们进行while循环的目的是在找最近公共祖先
*/
while (top[u] != top[v]) {
if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
ans += T.query(1, idx[top[u]], idx[u]); // 更新u节点所在的重链
u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的查询
}
/* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步查询操作 */
if (depth[u] < depth[v]) std::swap(u, v);
ans += T.query(1, idx[v], idx[u]);
return ans;
}
i64 query_tree (int u) {
/*
这里查询u的子树,所以点u对应的idx[u]就是根
而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/
return T.query(1, idx[u], idx[u] + size[u] - 1);
}
int32_t main () {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n, m, root, mod;
std::cin >> n >> m >> root >> mod;
for (int i = 1; i <= n; i++) {
std::cin >> w[i];
}
for (int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(root, 0);
dfs2(root, root);
T.build(1, 1, n);
while (m--) {
i64 op, x, y, z;
std::cin >> op;
if (op == 1) {
std::cin >> x >> y >> z;
update_path(x, y, z);
} else if (op == 2) {
std::cin >> x >> y;
std::cout << query_path(x, y) % mod << '\n';
} else if (op == 3) {
std::cin >> x >> y;
update_tree(x, y);
} else {
std::cin >> x;
std::cout << query_tree(x) % mod << '\n';
}
}
return 0;
}

浙公网安备 33010602011771号