全局二叉平衡树

作用

能够在 \(O(\log n)\) 的复杂度内修改一条链、查询一条链、求最近公共祖先、子树修改、子树查询……

性质

  1. 全局平衡二叉树由若干个二叉树通过轻边连接,每个二叉树维护了原树的一条重链,其中序遍历就是这条重链深度单调递增的顺序。
  2. 每个节点只出现在一颗二叉树中。
  3. 轻边连接的一对点,认父不认子,即只能从子结点访问到父节点。
  4. 树高 \(O(\log n)\) 级别。

建树

首先,先求出每个节点的子树大小、重儿子、父亲节点、深度等基础信息。

因为这些都是原树上的信息,所以封装了一下。

struct Tree {
    vector< int> G[N];
    int siz[N], fa[N], son[N], dep[N];

    void dfs( int u, int fu) {
        fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;

        for ( auto v : G[u]) {
            if (v == fu) continue ;
            dfs(v, u);
            siz[u] += siz[v];
            son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
        }
    }
} T;

要建全局平衡二叉树时,先找到一条重链,然后把这条重链建成一颗二叉树的形式。

要建立二叉树,先把这条重链中的点存下来,然后每次找这些点中的带权中点为根,再递归进左、右儿子继续处理。这里带的权为原树上轻儿子的子树大小之和 \(+1\),也就是自己的子树大小减去重儿子的子树大小。

int build1( int l, int r, int fu) {
    int L = l, R = r;

    while (L < R) {
        int mid = (L + R + 1) >> 1;

        if (2 * (sums[mid] - sums[l - 1]) <= sums[r] - sums[l - 1]) L = mid;
        else R = mid - 1;
    }
    // 二分计算带权中点

    int u = nd[L];
    siz[u] = r - l + 1, fa[u] = fu, dep[u] = dep[fu] + 1;
    if (l <= L - 1) ls[u] = build1(l, L - 1, u);
    if (L + 1 <= r) rs[u] = build1(L + 1, r, u);
    // 这里的 siz、fa、dep、ls、rs 都是二叉树上的信息

    return u;
}

int build2( int u) {    
    int p = u, cnt = 0;

    do {
        for ( auto v : T.G[p]) {
            if (v == T.fa[p] || v == T.son[p]) continue ;
            fa[build2(v)] = p;
        }
    } while (p = T.son[p]);

    p = u;

    do {
        nd[++ cnt] = p, top[p] = u;
        // 这里 top 表示点 p 在原树重链的链顶
        sums[cnt] = sums[cnt - 1] + T.siz[p] - T.siz[T.son[p]];
    } while (p = T.son[p]);

    return build1(1, cnt, 0);
}

这里建树的复杂度是 \(O(n\log n)\) 的,给一条重链建二叉树带了一个 \(\log\)

建出来的全局平衡二叉树的树高是 \(O(\log n)\) 级别。考虑一个点一直跳自己父亲,如果跳过一条轻边,由重链剖分的性质可得跳轻边最多 \(O(\log n)\) 条;如果跳过一条重边,因为建二叉树时根节点是带权中点,那么跳一次重边算上轻儿子的大小至少翻倍,所以重边也只会跳 \(O(\log n)\) 次。

一些操作

【模板】最近公共祖先(LCA)

求最近公共祖先和数剖差不多,对于两个点 \(u,v\),让它们在全局平衡二叉树上面跳,直到 \(top_u=top_v\),此时深度小的点就是最近公共祖先。

int lca( int u, int v) {
    while (top[u] != top[v]) {
        if (T.dep[top[u]] < T.dep[top[v]]) v = fa[v];
        else u = fa[u];
    }

    return (T.dep[u] < T.dep[v] ? u : v);
}

[ZJOI2008] 树的统计

这里要单点赋值、路径求最大值、路径求和。

考虑对每个点 \(u\) 维护三个东西:\(val_u\) 表示点 \(u\) 的值、\(sum_u\) 表示点 \(u\) 在二叉树上的子树权值和、\(mx_u\) 表示点 \(u\) 在二叉树上的子树最大值。

维护这三个东西是简单的,假设修改了 \(u\),那么 \(u\) 到二叉树的根上的所有点都会改,直接跳父亲即可。

int chk( int u) {
    return (u == ls[fa[u]] || u == rs[fa[u]]);
}
// 这个用来判断 u 是否是一颗二叉树的根,是根返回 0

void upd( int u, int x) {
    int p = u; val[u] = x;

    do {
        sum[p] = val[p] + sum[ls[p]] + sum[rs[p]];
        mx[p] = max({val[p], mx[ls[p]], mx[rs[p]]});
    } while (chk(p) && (p = fa[p]));
}
// 单点赋值

现在考虑路径的询问,以查询路径和为例子。

对于路径的两个端点 \(u\)\(v\),它们两的路径上肯定由若干条重链构成。

那么路径权值和也就转化成了每条重链权值和之和,所以考虑对单独每条重链进行计算,最后把每条重链的答案加起来就是最终答案。

容易发现,一条重链的权值和就是这条重链所建出的二叉树的根的 \(sum\)

当然,有三条特殊的重链,它们的权值和就和上述不一样了:包含 \(u\) 的重链、包含 \(v\) 的重链、包含 \(u\)\(v\) 最近公共祖先的重链。

对于包含 \(u\) 的重链来讲,只要深度小于等于 \(u\) 的点的权值之和。

要维护这个,需要用到上面的一条性质:二叉树的中序遍历就是这条重链深度单调递增的顺序,由这个可以知道只需要求二叉树上在 \(u\) 左边(包含 \(u\))的所有点的权值之和。

求这个就很简单了,让 \(u\) 往上跳,如果 \(u\) 不是从左儿子跳上去的就给答案加上 \(sum_{ls_u}+val_u\)

对于包含 \(v\) 的重链是一样的。

最后只需要考虑包含 \(u\)\(v\) 最近公共祖先的重链了。

肯定要先求一个最近公共祖先出来,像上面说的求最近公共祖先一样,让 \(u\)\(v\) 往上跳,直到 \(top_u=top_v\),这边就假设 \(u\) 是最近公共祖先。

那么要求的就是在二叉树上在 \(u\)\(v\) 两点中间的所有点(包含 \(u、v\))的权值之和。

考虑在二叉树上求出 \(u\)\(v\) 两点的最近公共祖先,那么然后让 \(u\)\(v\) 往上跳,跳到最近公共祖先就停止。

每次跳的 \(v\) 时候,给答案加上 \(sum_{ls_v}+val_v\)(这里还是一样,如果 \(v\) 是从左儿子跳上去的就不加);跳 \(u\) 的时候相反,给答案加上 \(sum_{rs_u}+val_u\)(如果 \(u\) 是从右儿子跳上去的就不加)。

最后让答案加上二叉树上的最近公共祖先的权值。

代码实现比较巧妙,对于普通重链、包含 \(u\) 的重链、包含 \(v\) 的重链,这三种重链的权值之和,可以求 \(u\)\(v\) 它两原树上的最近公共祖先时就统计完。

\(u\)\(v\) 二叉树上的最近公共祖先时,可以先让 \(u\)\(v\) 在二叉树上跳到同一高度,最后在一起往上跳,\(u=v\) 时就跳到最近公共祖先了。

void jpls( int & u, int & op, int & res) {
    if (op) res += sum[ls[u]] + val[u];
    op = (u != ls[fa[u]]);
    u = fa[u];
}
// 取左儿子权值的跳

void jprs( int & u, int & op, int & res) {
    if (op) res += sum[rs[u]] + val[u];
    op = (u != rs[fa[u]]);
    u = fa[u];
}
// 取右儿子权值的跳

int asksum( int u, int v) {
    int res = 0, opu = 1, opv = 1;

    while (top[u] != top[v])
        if (T.dep[top[u]] > T.dep[top[v]]) jpls(u, opu, res);
        else jpls(v, opv, res);
    // 让 u、v$ 往上跳求最近公共祖先,同时计算普通重链、包含 $u$ 的重链、包含 $v$ 的重链,这三种重链的权值之和

    if (T.dep[u] > T.dep[v]) swap(u, v);

    if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprs(u, opu, res);
    else while (dep[v] != dep[u]) jpls(v, opv, res);
    // 让 u、v 跳到同一个高度

    while (u != v) jprs(u, opu, res), jpls(v, opv, res);
    // 一起往上跳

    return res + val[u];
    // 最后要加上最近公共祖先的权值
}

求路径最大值和这个差不多。

全部代码放出来。

#include <bits/stdc++.h>

void Freopen() {
    freopen("", "r", stdin);
    freopen("", "w", stdout);
}

using namespace std;
const int N = 5e5 + 10, M = 2e5 + 10, inf = 1e9, mod = 998244353;

struct Tree {
    vector< int> G[N];
    int siz[N], fa[N], son[N], dep[N];

    void dfs( int u, int fu) {
        fa[u] = fu, siz[u] = 1, dep[u] = dep[fu] + 1;

        for ( auto v : G[u]) {
            if (v == fu) continue ;
            dfs(v, u);
            siz[u] += siz[v];
            son[u] = (siz[v] > siz[son[u]] ? v : son[u]);
        }
    }
} T;

struct ds {
    int siz[N], sums[N], ls[N], rs[N], fa[N], nd[N], top[N], dep[N];
    int val[N];

    int build1( int l, int r, int fu) {
        int L = l, R = r;

        while (L < R) {
            int mid = (L + R + 1) >> 1;

            if (2 * (sums[mid] - sums[l - 1]) <= sums[r] - sums[l - 1]) L = mid;
            else R = mid - 1;
        }

        int u = nd[L];
        siz[u] = r - l + 1, fa[u] = fu, dep[u] = dep[fu] + 1;
        if (l <= L - 1) ls[u] = build1(l, L - 1, u);
        if (L + 1 <= r) rs[u] = build1(L + 1, r, u);

        return u;
    }

    int build2( int u) {    
        int p = u, cnt = 0;

        do {
            for ( auto v : T.G[p]) {
                if (v == T.fa[p] || v == T.son[p]) continue ;
                fa[build2(v)] = p;
            }
        } while (p = T.son[p]);

        p = u;

        do {
            nd[++ cnt] = p, top[p] = u;
            sums[cnt] = sums[cnt - 1] + T.siz[p] - T.siz[T.son[p]];
        } while (p = T.son[p]);
    
        return build1(1, cnt, 0);
    }

    int mx[N], sum[N];

    int chk( int u) {
        return (u == ls[fa[u]] || u == rs[fa[u]]);
    }

    void jpls( int & u, int & op, int & res) {
        if (op) res += sum[ls[u]] + val[u];
        op = (u != ls[fa[u]]);
        u = fa[u];
    }

    void jprs( int & u, int & op, int & res) {
        if (op) res += sum[rs[u]] + val[u];
        op = (u != rs[fa[u]]);
        u = fa[u];
    }

    void jplm( int & u, int & op, int & res) {
        if (op) res = max({res, mx[ls[u]], val[u]});
        op = (u != ls[fa[u]]);
        u = fa[u];
    }

    void jprm( int & u, int & op, int & res) {
        if (op) res = max({res, mx[rs[u]], val[u]});
        op = (u != rs[fa[u]]);
        u = fa[u];
    }

    void upd( int u, int x) {
        int p = u; val[u] = x;

        do {
            sum[p] = val[p] + sum[ls[p]] + sum[rs[p]];
            mx[p] = max({val[p], mx[ls[p]], mx[rs[p]]});
        } while (chk(p) && (p = fa[p]));
    }

    int asksum( int u, int v) {
        int res = 0, opu = 1, opv = 1;

        while (top[u] != top[v])
            if (T.dep[top[u]] > T.dep[top[v]]) jpls(u, opu, res);
            else jpls(v, opv, res);

        if (T.dep[u] > T.dep[v]) swap(u, v);

        if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprs(u, opu, res);
        else while (dep[v] != dep[u]) jpls(v, opv, res);

        while (u != v) jprs(u, opu, res), jpls(v, opv, res);

        return res + val[u];
    }

    int askmax( int u, int v) {
        int res = -inf, opu = 1, opv = 1;

        while (top[u] != top[v])
            if (T.dep[top[u]] > T.dep[top[v]]) jplm(u, opu, res);
            else jplm(v, opv, res);

        if (T.dep[u] > T.dep[v]) swap(u, v);

        if (dep[u] > dep[v]) while (dep[u] != dep[v]) jprm(u, opu, res);
        else while (dep[v] != dep[u]) jplm(v, opv, res);

        while (u != v) jprm(u, opu, res), jplm(v, opv, res);

        return max(res, val[u]);
    }
} T1;

int n, m;

signed main() {
    ios :: sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cin >> n;

    for ( int i = 1; i < n; i ++) {
        int u, v; cin >> u >> v;
        T.G[u].push_back(v), T.G[v].push_back(u);
    }

    memset(T1.mx, 128, sizeof T1.mx);
    T.dfs(1, 0); 
    T1.build2(1);

    for ( int i = 1; i <= n; i ++) {
        int x; cin >> x; T1.upd(i, x);
    }

    cin >> m;

    while (m --) {
        string op; int u, v;
        cin >> op >> u >> v;

        if (op == "QSUM") cout << T1.asksum(u, v) << '\n';
        if (op == "CHANGE") T1.upd(u, v);
        if (op == "QMAX") cout << T1.askmax(u, v) << '\n';
    }

    return 0;
}
posted @ 2025-08-16 16:01  咚咚的锵  阅读(58)  评论(0)    收藏  举报