树链剖分

树链剖分

树链剖分是一种将树结构化为线性结构的方法
配合线段树区间修改和区间查询的功能,实现树上的\(O(logn)\)级别的修改与查询

其实现方式是将树划分为若干条链,称为重链
而具体如何剖分呢?接下来要引入几个定义:

重儿子:对于一个父节点的若干个子节点,其中子树节点最多的子节点为重儿子
轻儿子:除重儿子以外的节点
重边:父节点和重儿子连成的边
重链:由多条重边连成的链

我们可以得出以下结论:

  1. 整棵树会被剖分成若干条重链
  2. 每条重链的顶端一定是轻儿子 (单独一个轻儿子也是一条重链,是特殊的重链)
  3. 任意一条路径被切分成不超过\(logn\)条重链,也是树链剖分能将时间复杂度控制在\(O(logn)\)的核心原因

我们按照上述原理,将树剖分为若干条重链存储进一个数组中,再用线段树维护区间的特性,将树上修改查询转化为区间的修改和查询;那么我们如何实现这个过程呢?


\(First\)
我们需要先开几个数组:
(这里推荐用全局变量,否则要么传递超多参数,要么用\(lambda\)表达式写函数,都非常麻烦;一般的出题人要考察树剖,大抵也不会是多实例测试,那样就太恶心了)

/*
        g[u]: 和u相连的所有节点(存图)
        w[u]: u的权值
       fa[u]: u的父节点
    depth[u]: u的深度
      son[u]: u的重儿子
     size[u]: 以u为根的子树节点数
      top[u]: u所在重链的顶点
      idx[u]: u剖分后的新编号
      New[u]: 新编号在树中对应节点的权值,作为线段树维护的序列
*/ 

std::vector<int> g[N], w(N);
std::vector<int> fa(N), depth(N), size(N), son(N);
std::vector<int> top(N), idx(N), New(N);
int cnt = 0;

\(Second\)
接下来,我们要预处理一下这些数组,通过两次\(dfs\)实现
第一次\(dfs\),我们传递两个参数,一个是当前节点\(now\),一个是当前节点的父节点\(last\)
在每一次搜索时,更新\(fa、depth、size、son\)数组

其中,\(fa\)数组和\(depth\)数组都很好办,显然\(fa[now] = last, depth[now] = depth[last] + 1\)
\(size\)数组和\(son\)数组相对麻烦一些,我们只能在计算完,也就是递归返回时才能正确更新
我们定义\(next\)\(now\)节点的子节点,那么\(size[now] = ∑size[next]\)
那么\(son[now]\),也就是具有最大\(size[next]\)的子节点了,代码如下:

/* 预处理fa、depth、size、son数组 */
void dfs1 (int now, int last) {
    fa[now] = last;               // 存储当前节点的父亲节点 
    depth[now] = depth[last] + 1; // 更新当前节点的深度,为父亲节点的深度 + 1
    size[now] = 1;                // 先初始化当前节点的子树节点数为1,后面再进行更新

    /* 接下来处理now节点的子节点 */
    for (auto next : g[now]) {
        if (next == last) continue; // 相当于判断next是否被访问过,访问过就直接跳过了

        dfs1(next, now);
        size[now] += size[next]; // 递归返回时再更新子树节点数,这也是不能用bfs替代的核心原因

        if (size[son[now]] < size[next]) son[now] = next; // 更新重儿子,递归返回时才可以更新
    }
}

下面需要处理\(top、idx、New\)数组,也用\(dfs\)实现,但传递的参数和上面不一样
这次要传递两个参数:当前节点\(now\),当前节点所在重链的顶端的轻儿子\(top\:dot\)
那么很显然,每次搜索时,\(top[now] = top\:dot\)

然后我们需要将这些点按顺序在\(idx\)数组中标记,实现方法是开一个变量\(cnt\),然后\(idx[now] =\) ++\(cnt\)
再按照新的编号,在\(New\)数组中对应的给带上权值,也就是\(New[cnt] = w[now]\)

在第一个\(dfs\)中,我们已经正确找到了每个结点的重儿子\(son[now]\)
在第二次\(dfs\)中,我们就要借助这个,先搜索重儿子,再搜索轻儿子,以保证\(dfs\)序的正确

/* 预处理top、idx、New数组 */
void dfs2 (int now, int top_dot) {
    top[now] = top_dot;  // 存储当前节点的顶端,也就是当前节点所在重链的顶端 
    idx[now] = ++cnt;    // 当前节点剖分后的新编号 
    New[cnt] = w[now];   // 新编号对应的原来的权值 

    /* 先处理重儿子,保证dfs序 */
    if (son[now]) dfs2(son[now], top_dot);

    /* 再处理轻儿子 */
    for (auto next : g[now]) {
        if (next == fa[now] || next == son[now]) continue; // 这里也是判断next是否被访问过
        dfs2(next, next);
    }
}

至此预处理过程全部结束,接下来可以开始区间维护的部分了


\(Third\)
区间维护用到的线段树,是针对\(New\)数组建立的区间和线段树
区间的更新和查询也都是在\(New\)数组上进行,我们先进行建树过程,代码如下:
(这里的线段树就是最最基础的区间和线段树的板子,直接搬过来就可以用;但不会线段树还请先缓一缓,补完区间和线段树的知识再来学树剖吧)

/* 区间和线段树模板 */
struct SegTree {
    #define lc u << 1
    #define rc u << 1 | 1

    struct node {
        int l, r;
        i64 val, tag;   
    } tree[4 * N];

    void push_up (int u) {
        tree[u].val = tree[lc].val + tree[rc].val;
    }

    void push_down (int u) {
        if (tree[u].tag) {
        // Left:    
            tree[lc].val += (tree[lc].r - tree[lc].l + 1) * tree[u].tag;
            tree[lc].tag += tree[u].tag;
        // Right:
            tree[rc].val += (tree[rc].r - tree[rc].l + 1) * tree[u].tag;
            tree[rc].tag += tree[u].tag;
        // Final
            tree[u].tag = 0;
        }
    }

    void build (int u, int l, int r) {
        tree[u].l = l;
        tree[u].r = r;
        tree[u].tag = 0;

        if (l == r) {
            tree[u].val = New[l]; // 对New数组进行建树
            return;
        }

        int mid = tree[u].l + tree[u].r >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        push_up(u);
    }

    void update (int u, int x, int y, i64 k) {
        if (x <= tree[u].l && tree[u].r <= y) {
            tree[u].val += (tree[u].r - tree[u].l + 1) * k;
            tree[u].tag += k;
            return;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        if (x <= mid) update(lc, x, y, k);
        if (y >  mid) update(rc, x, y, k);
        push_up(u);
    }

    i64 query (int u, int x, int y) {
        if (x <= tree[u].l && tree[u].r <= y) {
            return tree[u].val;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        i64 ans = 0;
        if (x <= mid) ans += query(lc, x, y);
        if (y >  mid) ans += query(rc, x, y);
        return ans;
    }
} T;

\(Fourth\)
完成了建树过程,接下来我们开始写区间修改和查询的代码
我们回忆一下刚才的剖分过程:\(size[now]\)代表了\(now\)节点的子树大小,而我们也是按照\(dfs\)序存进了\(New\)数组中
因此,如果要修改 / 查询某个节点\(u\)子树的所有节点之和,
也就是利用线段树在\(New\)序列上修改 / 查询 \([idx[u],idx[u] + size[u] - 1]\) 这段区间!
那么,修改 / 查询子树的代码就非常简单了,如下所示:

/* 
    这里更新u的子树,所以点u对应的idx[u]就是根
    而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/

void update_tree (int u, i64 k) {
    T.update(1, idx[u], idx[u] + size[u] - 1, k);
}

i64 query_tree (int u) {
    return T.query(1, idx[u], idx[u] + size[u] - 1);
}

但是还有另一种修改 / 查询,也是更常见的修改 / 查询方式
就是给定两个点\(u\)\(v\),修改点\(u\)到点\(v\)的简单路径上的所有点的权值,或是查询点\(u\)到点\(v\)的简单路径上的所有点的权值之和
这时我们就要利用先前剖分的重链的特性了:

首先我们要判断一下,\(top[u]\)\(top[v]\)是否相等
如果他们不相等,说明当前\(u\)\(v\)所在的两条重链不相交,而此时,我们要分别处理这两条重链
处理完之后,再让点\(u\)和点\(v\)分别跳到他们所在重链的\(top\)节点的父节点,也就是\(fa[top[u]]\)\(fa[top[v]]\)

但是为了防止重复的问题,我们先对点\(u\)和点\(v\)进行判断,谁深度大,处理谁
也就是 if (depth[top[u]] < depth[top[v]]) swap(u, v) 保证\(u\)是更深的节点
然后去处理\(u\)所在的重链,再让\(u\)跳到\(fa[top[u]]\)

当执行完所有的\(top[u] ≠ top[v]\)的情况后,也就是\(u\)\(v\)所在重链顶端是同一个节点
此时到了最后一步,我们不需要再分开处理了,直接一步处理完,代码如下:

void update_path (int u, int v, i64 k) {
/* 
    当top[u] != top[v],说明此时的顶端不是最近公共祖先 
    我们进行while循环的目的是在找最近公共祖先
*/
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
        T.update(1, idx[top[u]], idx[u], k); // 更新u节点所在的重链 
        u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的更新
    }

    /* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步更新操作 */
    if (depth[u] < depth[v]) std::swap(u, v);
    T.update(1, idx[v], idx[u], k);
}

i64 query_path (int u, int v) {
    i64 ans = 0;
/* 
    当top[u] != top[v],说明此时的顶端不是最近公共祖先 
    我们进行while循环的目的是在找最近公共祖先
*/
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
        ans += T.query(1, idx[top[u]], idx[u]); // 更新u节点所在的重链
        u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的查询
    }

    /* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步查询操作 */
    if (depth[u] < depth[v]) std::swap(u, v);
    ans += T.query(1, idx[v], idx[u]);
    return ans;
}

\(Final\)
最终代码如下:(可通过洛谷P3384,https://www.luogu.com.cn/problem/P3384)

#include <bits/stdc++.h>
#include <bits/extc++.h>

using i64 = int64_t;
using pii = std::pair<int, int>;
using namespace __gnu_pbds;
using ordered_set = tree<i64, null_type, std::less<i64>, rb_tree_tag, tree_order_statistics_node_update>;
using ordered_multiset = tree<i64, null_type, std::less_equal<i64>, rb_tree_tag, tree_order_statistics_node_update>;

constexpr int N = 1e5 + 10;
constexpr int inf = INT_MAX;
constexpr i64 INF = LLONG_MAX;
// constexpr int mod = 998244353;
constexpr int dx[] = {+0, -1, +0, +1, -1, +1, +1, -1};
constexpr int dy[] = {+1, +0, -1, +0, +1, +1, -1, -1};

/*
        w[u]: u的权值
       fa[u]: u的父节点
    depth[u]: u的深度
      son[u]: u的重儿子
     size[u]: 以u为根的子树节点数
      top[u]: u所在重链的顶点
      idx[u]: u剖分后的新编号
      New[u]: 新编号在树中对应节点的权值,作为线段树维护的序列
*/ 

std::vector<int> g[N], w(N);
std::vector<int> fa(N), depth(N), size(N), son(N);
std::vector<int> top(N), idx(N), New(N);
int cnt = 0;

/* 区间和线段树模板 */
struct SegTree {
    #define lc u << 1
    #define rc u << 1 | 1

    struct node {
        int l, r;
        i64 val, tag;   
    } tree[4 * N];

    void push_up (int u) {
        tree[u].val = tree[lc].val + tree[rc].val;
    }

    void push_down (int u) {
        if (tree[u].tag) {
        // Left:    
            tree[lc].val += (tree[lc].r - tree[lc].l + 1) * tree[u].tag;
            tree[lc].tag += tree[u].tag;
        // Right:
            tree[rc].val += (tree[rc].r - tree[rc].l + 1) * tree[u].tag;
            tree[rc].tag += tree[u].tag;
        // Final
            tree[u].tag = 0;
        }
    }

    void build (int u, int l, int r) {
        tree[u].l = l;
        tree[u].r = r;
        tree[u].tag = 0;

        if (l == r) {
            tree[u].val = New[l];
            return;
        }

        int mid = tree[u].l + tree[u].r >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        push_up(u);
    }

    void update (int u, int x, int y, i64 k) {
        if (x <= tree[u].l && tree[u].r <= y) {
            tree[u].val += (tree[u].r - tree[u].l + 1) * k;
            tree[u].tag += k;
            return;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        if (x <= mid) update(lc, x, y, k);
        if (y >  mid) update(rc, x, y, k);
        push_up(u);
    }

    i64 query (int u, int x, int y) {
        if (x <= tree[u].l && tree[u].r <= y) {
            return tree[u].val;
        }
        push_down(u);

        int mid = tree[u].l + tree[u].r >> 1;
        i64 ans = 0;
        if (x <= mid) ans += query(lc, x, y);
        if (y >  mid) ans += query(rc, x, y);
        return ans;
    }
} T;

/* 预处理fa、depth、size、son数组 */
void dfs1 (int now, int last) {
    fa[now] = last;               // 存储当前节点的父亲节点 
    depth[now] = depth[last] + 1; // 更新当前节点的深度,为父亲节点的深度 + 1
    size[now] = 1;                // 先初始化当前节点的子树节点数为1,后面再进行更新

    /* 接下来处理now节点的子节点 */
    for (auto next : g[now]) {
        if (next == last) continue; // 相当于判断next是否被访问过,访问过就直接跳过了

        dfs1(next, now);
        size[now] += size[next]; // 递归返回时再更新子树节点数,这也是不能用bfs替代的核心原因

        if (size[son[now]] < size[next]) son[now] = next; // 更新重儿子,递归返回时才可以更新
    }
}

/* 预处理top、idx、New数组 */
void dfs2 (int now, int top_dot) {
    top[now] = top_dot;  // 存储当前节点的顶端,也就是当前节点所在重链的顶端 
    idx[now] = ++cnt;    // 当前节点剖分后的新编号 
    New[cnt] = w[now];   // 新编号对应的原来的权值 

    /* 先处理重儿子,保证dfs序 */
    if (son[now]) dfs2(son[now], top_dot);

    /* 再处理轻儿子 */
    for (auto next : g[now]) {
        if (next == fa[now] || next == son[now]) continue; // 这里也是判断next是否被访问过
        dfs2(next, next);
    }
}

void update_path (int u, int v, i64 k) {
/* 
    当top[u] != top[v],说明此时的顶端不是最近公共祖先 
    我们进行while循环的目的是在找最近公共祖先
*/
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
        T.update(1, idx[top[u]], idx[u], k); // 更新u节点所在的重链 
        u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的更新
    }

    /* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步更新操作 */
    if (depth[u] < depth[v]) std::swap(u, v);
    T.update(1, idx[v], idx[u], k);
}

void update_tree (int u, i64 k) {
/* 
    这里更新u的子树,所以点u对应的idx[u]就是根
    而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/
    T.update(1, idx[u], idx[u] + size[u] - 1, k);
}

i64 query_path (int u, int v) {
    i64 ans = 0;
/* 
    当top[u] != top[v],说明此时的顶端不是最近公共祖先 
    我们进行while循环的目的是在找最近公共祖先
*/
    while (top[u] != top[v]) {
        if (depth[top[u]] < depth[top[v]]) std::swap(u, v); // 这里要保证u是更深的节点
        ans += T.query(1, idx[top[u]], idx[u]); // 更新u节点所在的重链
        u = fa[top[u]]; // 让u跳到当前重链顶端的父节点,完成这条重链的查询
    }

    /* 执行完while循环,此时u和v所在重链的顶端是相同的,直接执行最后一步查询操作 */
    if (depth[u] < depth[v]) std::swap(u, v);
    ans += T.query(1, idx[v], idx[u]);
    return ans;
}

i64 query_tree (int u) {
/* 
    这里查询u的子树,所以点u对应的idx[u]就是根
    而剖分为链后,idx[u] + size[u] - 1是该链的最后一个节点
*/
    return T.query(1, idx[u], idx[u] + size[u] - 1);
}

int32_t main () {   
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
    int n, m, root, mod;
    std::cin >> n >> m >> root >> mod;

    for (int i = 1; i <= n; i++) {
        std::cin >> w[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        std::cin >> u >> v;

        g[u].push_back(v);
        g[v].push_back(u);
    }

    dfs1(root, 0);
    dfs2(root, root);
    T.build(1, 1, n);

    while (m--) {
        i64 op, x, y, z;
        std::cin >> op;

        if (op == 1) {
            std::cin >> x >> y >> z;
            update_path(x, y, z);
        } else if (op == 2) {
            std::cin >> x >> y;
            std::cout << query_path(x, y) % mod << '\n';
        } else if (op == 3) {
            std::cin >> x >> y;
            update_tree(x, y);
        } else {
            std::cin >> x;
            std::cout << query_tree(x) % mod << '\n';
        }
    }

    return 0;
}
posted @ 2025-08-22 13:36  _彩云归  阅读(60)  评论(0)    收藏  举报