全局二叉平衡树
作用
能够在 \(O(\log n)\) 的复杂度内修改一条链、查询一条链、求最近公共祖先、子树修改、子树查询……
性质
- 全局平衡二叉树由若干个二叉树通过轻边连接,每个二叉树维护了原树的一条重链,其中序遍历就是这条重链深度单调递增的顺序。
- 每个节点只出现在一颗二叉树中。
- 轻边连接的一对点,认父不认子,即只能从子结点访问到父节点。
- 树高 \(O(\log n)\) 级别。
建树
首先,先求出每个节点的子树大小、重儿子、父亲节点、深度等基础信息。
因为这些都是原树上的信息,所以封装了一下。
struct Tree {
vector< int> G[N];
int siz[N], fa[N], son[N], dep[N];
void dfs( int u, int fu) {
fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;
for ( auto v : G[u]) {
if (v == fu) continue ;
dfs(v, u);
siz[u] += siz[v];
son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
}
}
} T;
要建全局平衡二叉树时,先找到一条重链,然后把这条重链建成一颗二叉树的形式。
要建立二叉树,先把这条重链中的点存下来,然后每次找这些点中的带权中点为根,再递归进左、右儿子继续处理。这里带的权为原树上轻儿子的子树大小之和 \(+1\),也就是自己的子树大小减去重儿子的子树大小。
int build1( int l, int r, int fu) {
int L = l, R = r;
while (L < R) {
int mid = (L + R + 1) >> 1;
if (2 * (sums[mid] - sums[l - 1]) <= sums[r] - sums[l - 1]) L = mid;
else R = mid - 1;
}
// 二分计算带权中点
int u = nd[L];
siz[u] = r - l + 1, fa[u] = fu, dep[u] = dep[fu] + 1;
if (l <= L - 1) ls[u] = build1(l, L - 1, u);
if (L + 1 <= r) rs[u] = build1(L + 1, r, u);
// 这里的 siz、fa、dep、ls、rs 都是二叉树上的信息
return u;
}
int build2( int u) {
int p = u, cnt = 0;
do {
for ( auto v : T.G[p]) {
if (v == T.fa[p] || v == T.son[p]) continue ;
fa[build2(v)] = p;
}
} while (p = T.son[p]);
p = u;
do {
nd[++ cnt] = p, top[p] = u;
// 这里 top 表示点 p 在原树重链的链顶
sums[cnt] = sums[cnt - 1] + T.siz[p] - T.siz[T.son[p]];
} while (p = T.son[p]);
return build1(1, cnt, 0);
}
这里建树的复杂度是 \(O(n\log n)\) 的,给一条重链建二叉树带了一个 \(\log\)。
建出来的全局平衡二叉树的树高是 \(O(\log n)\) 级别。考虑一个点一直跳自己父亲,如果跳过一条轻边,由重链剖分的性质可得跳轻边最多 \(O(\log n)\) 条;如果跳过一条重边,因为建二叉树时根节点是带权中点,那么跳一次重边算上轻儿子的大小至少翻倍,所以重边也只会跳 \(O(\log n)\) 次。
一些操作
【模板】最近公共祖先(LCA)
求最近公共祖先和数剖差不多,对于两个点 \(u,v\),让它们在全局平衡二叉树上面跳,直到 \(top_u=top_v\),此时深度小的点就是最近公共祖先。
int lca( int u, int v) {
while (top[u] != top[v]) {
if (T.dep[top[u]] < T.dep[top[v]]) v = fa[v];
else u = fa[u];
}
return (T.dep[u] < T.dep[v] ? u : v);
}
[ZJOI2008] 树的统计
这里要单点赋值、路径求最大值、路径求和。
考虑对每个点 \(u\) 维护三个东西:\(val_u\) 表示点 \(u\) 的值、\(sum_u\) 表示点 \(u\) 在二叉树上的子树权值和、\(mx_u\) 表示点 \(u\) 在二叉树上的子树最大值。
维护这三个东西是简单的,假设修改了 \(u\),那么 \(u\) 到二叉树的根上的所有点都会改,直接跳父亲即可。
int chk( int u) {
return (u == ls[fa[u]] || u == rs[fa[u]]);
}
// 这个用来判断 u 是否是一颗二叉树的根,是根返回 0
void upd( int u, int x) {
int p = u; val[u] = x;
do {
sum[p] = val[p] + sum[ls[p]] + sum[rs[p]];
mx[p] = max({val[p], mx[ls[p]], mx[rs[p]]});
} while (chk(p) && (p = fa[p]));
}
// 单点赋值
现在考虑路径的询问,以查询路径和为例子。
对于路径的两个端点 \(u\)、\(v\),它们两的路径上肯定由若干条重链构成。
那么路径权值和也就转化成了每条重链权值和之和,所以考虑对单独每条重链进行计算,最后把每条重链的答案加起来就是最终答案。
容易发现,一条重链的权值和就是这条重链所建出的二叉树的根的 \(sum\)。
当然,有三条特殊的重链,它们的权值和就和上述不一样了:包含 \(u\) 的重链、包含 \(v\) 的重链、包含 \(u\)、\(v\) 最近公共祖先的重链。
对于包含 \(u\) 的重链来讲,只要深度小于等于 \(u\) 的点的权值之和。
要维护这个,需要用到上面的一条性质:二叉树的中序遍历就是这条重链深度单调递增的顺序,由这个可以知道只需要求二叉树上在 \(u\) 左边(包含 \(u\))的所有点的权值之和。
求这个就很简单了,让 \(u\) 往上跳,如果 \(u\) 不是从左儿子跳上去的就给答案加上 \(sum_{ls_u}+val_u\)。
对于包含 \(v\) 的重链是一样的。
最后只需要考虑包含 \(u\)、\(v\) 最近公共祖先的重链了。
肯定要先求一个最近公共祖先出来,像上面说的求最近公共祖先一样,让 \(u\)、\(v\) 往上跳,直到 \(top_u=top_v\),这边就假设 \(u\) 是最近公共祖先。
那么要求的就是在二叉树上在 \(u\)、\(v\) 两点中间的所有点(包含 \(u、v\))的权值之和。
考虑在二叉树上求出 \(u\)、\(v\) 两点的最近公共祖先,那么然后让 \(u\)、\(v\) 往上跳,跳到最近公共祖先就停止。
每次跳的 \(v\) 时候,给答案加上 \(sum_{ls_v}+val_v\)(这里还是一样,如果 \(v\) 是从左儿子跳上去的就不加);跳 \(u\) 的时候相反,给答案加上 \(sum_{rs_u}+val_u\)(如果 \(u\) 是从右儿子跳上去的就不加)。
最后让答案加上二叉树上的最近公共祖先的权值。
代码实现比较巧妙,对于普通重链、包含 \(u\) 的重链、包含 \(v\) 的重链,这三种重链的权值之和,可以求 \(u\)、\(v\) 它两原树上的最近公共祖先时就统计完。
求 \(u\)、\(v\) 二叉树上的最近公共祖先时,可以先让 \(u\)、\(v\) 在二叉树上跳到同一高度,最后在一起往上跳,\(u=v\) 时就跳到最近公共祖先了。
void jpls( int & u, int & op, int & res) {
if (op) res += sum[ls[u]] + val[u];
op = (u != ls[fa[u]]);
u = fa[u];
}
// 取左儿子权值的跳
void jprs( int & u, int & op, int & res) {
if (op) res += sum[rs[u]] + val[u];
op = (u != rs[fa[u]]);
u = fa[u];
}
// 取右儿子权值的跳
int asksum( int u, int v) {
int res = 0, opu = 1, opv = 1;
while (top[u] != top[v])
if (T.dep[top[u]] > T.dep[top[v]]) jpls(u, opu, res);
else jpls(v, opv, res);
// 让 u、v$ 往上跳求最近公共祖先,同时计算普通重链、包含 $u$ 的重链、包含 $v$ 的重链,这三种重链的权值之和
if (T.dep[u] > T.dep[v]) swap(u, v);
if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprs(u, opu, res);
else while (dep[v] != dep[u]) jpls(v, opv, res);
// 让 u、v 跳到同一个高度
while (u != v) jprs(u, opu, res), jpls(v, opv, res);
// 一起往上跳
return res + val[u];
// 最后要加上最近公共祖先的权值
}
求路径最大值和这个差不多。
全部代码放出来。
#include <bits/stdc++.h>
void Freopen() {
freopen("", "r", stdin);
freopen("", "w", stdout);
}
using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;
struct Tree {
vector< int> G[N];
int siz[N], fa[N], son[N], dep[N];
void dfs( int u, int fu) {
fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;
for ( auto v : G[u]) {
if (v == fu) continue ;
dfs(v, u);
siz[u] += siz[v];
son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
}
}
} T;
struct ds {
int siz[N], sums[N], ls[N], rs[N], fa[N], nd[N], top[N], dep[N];
int val[N];
int build1( int l, int r, int fu) {
int L = l, R = r;
while (L < R) {
int mid = (L + R + 1) >> 1;
if (2 * (sums[mid] - sums[l - 1]) <= sums[r] - sums[l - 1]) L = mid;
else R = mid - 1;
}
int u = nd[L];
siz[u] = r - l + 1, fa[u] = fu, dep[u] = dep[fu] + 1;
if (l <= L - 1) ls[u] = build1(l, L - 1, u);
if (L + 1 <= r) rs[u] = build1(L + 1, r, u);
return u;
}
int build2( int u) {
int p = u, cnt = 0;
do {
for ( auto v : T.G[p]) {
if (v == T.fa[p] || v == T.son[p]) continue ;
fa[build2(v)] = p;
}
} while (p = T.son[p]);
p = u;
do {
nd[++ cnt] = p, top[p] = u;
sums[cnt] = sums[cnt - 1] + T.siz[p] - T.siz[T.son[p]];
} while (p = T.son[p]);
return build1(1, cnt, 0);
}
int mx[N], sum[N];
int chk( int u) {
return (u == ls[fa[u]] || u == rs[fa[u]]);
}
void jpls( int & u, int & op, int & res) {
if (op) res += sum[ls[u]] + val[u];
op = (u != ls[fa[u]]);
u = fa[u];
}
void jprs( int & u, int & op, int & res) {
if (op) res += sum[rs[u]] + val[u];
op = (u != rs[fa[u]]);
u = fa[u];
}
void jplm( int & u, int & op, int & res) {
if (op) res = max({res, mx[ls[u]], val[u]});
op = (u != ls[fa[u]]);
u = fa[u];
}
void jprm( int & u, int & op, int & res) {
if (op) res = max({res, mx[rs[u]], val[u]});
op = (u != rs[fa[u]]);
u = fa[u];
}
void upd( int u, int x) {
int p = u; val[u] = x;
do {
sum[p] = val[p] + sum[ls[p]] + sum[rs[p]];
mx[p] = max({val[p], mx[ls[p]], mx[rs[p]]});
} while (chk(p) && (p = fa[p]));
}
int asksum( int u, int v) {
int res = 0, opu = 1, opv = 1;
while (top[u] != top[v])
if (T.dep[top[u]] > T.dep[top[v]]) jpls(u, opu, res);
else jpls(v, opv, res);
if (T.dep[u] > T.dep[v]) swap(u, v);
if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprs(u, opu, res);
else while (dep[v] != dep[u]) jpls(v, opv, res);
while (u != v) jprs(u, opu, res), jpls(v, opv, res);
return res + val[u];
}
int askmax( int u, int v) {
int res = -inf, opu = 1, opv = 1;
while (top[u] != top[v])
if (T.dep[top[u]] > T.dep[top[v]]) jplm(u, opu, res);
else jplm(v, opv, res);
if (T.dep[u] > T.dep[v]) swap(u, v);
if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprm(u, opu, res);
else while (dep[v] != dep[u]) jplm(v, opv, res);
while (u != v) jprm(u, opu, res), jplm(v, opv, res);
return max(res, val[u]);
}
} T1;
int n, m;
signed main() {
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n;
for ( int i = 1; i < n; i ++) {
int u, v; cin >> u >> v;
T.G[u].push_back(v), T.G[v].push_back(u);
}
memset(T1.mx, 128, sizeof T1.mx);
T.dfs(1, 0);
T1.build2(1);
for ( int i = 1; i <= n; i ++) {
int x; cin >> x; T1.upd(i, x);
}
cin >> m;
while (m --) {
string op; int u, v;
cin >> op >> u >> v;
if (op == "QSUM") cout << T1.asksum(u, v) << '\n';
if (op == "CHANGE") T1.upd(u, v);
if (op == "QMAX") cout << T1.askmax(u, v) << '\n';
}
return 0;
}

浙公网安备 33010602011771号