【数据结构】树链剖分详细讲解

 “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

树链剖分是把一棵树分割成若干条链,以便于维护信息的一种方法,其中最常用的是重链剖分(Heavy Path Decomposition,重路径分解),所以一般提到树链剖分或树剖都是指重链剖分。除此之外还有长链剖分和实链剖分等,本文暂不介绍。

首先我们需要明确概念:

  • 重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
  • 轻儿子:父亲节点中除了重儿子以外的儿子;
  • 重边:父亲结点和重儿子连成的边;
  • 轻边:父亲节点和轻儿子连成的边;
  • 重链:由多条重边连接而成的路径;
  • 轻链:由多条轻边连接而成的路径;

我们定义树上一个节点的子节点中子树最大的一个为它的重子节点,其余的为轻子节点。一个节点连向其重子节点的边称为重边,连向轻子节点的边则为轻边。如果把根节点看作轻的,那么从每个轻节点出发,不断向下走重边,都对应了一条链,于是我们把树剖分成了 \(l\) 条链,其中 \(l\) 是轻节点的数量。

最近因为画图工具出了点问题,所以转载了Pecco学长的示意图(下面求LCA的方法的部分内容也来自Pecco学长)

剖分后的树(重链)有如下性质:

  1. 对于节点数为 \(n\) 的树,从任意节点向上走到根节点,经过的轻边数量不会超过 \(log\ n\)

    这是因为当我们向下经过一条 轻边 时,所在子树的大小至少会除以二。所以说,对于树上的任意一条路径,把它拆分成从 \(lca\) 分别向两边往下走,分别最多走 \(O(\log n)\) 次,树上的每条路径都可以被拆分成不超过 \(O(\log n)\) 条重链。

  2. 树上每个节点都属于且仅属于一条重链

重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)。所有的重链将整棵树 完全剖分

尽管树链部分看起来很难实现(的确有点繁琐),但我们可以用两个 DFS 来实现树链(树剖)。

相关伪码(来自 OI wiki)

第一个 DFS 记录每个结点的父节点(father)、深度(deep)、子树大小(size)、重子节点(hson)。

\[\begin{array}{l} \text{TREE-BUILD }(u,dep) \\ \begin{array}{ll} 1 & u.hson\gets 0 \\ 2 & u.hson.size\gets 0 \\ 3 & u.deep\gets dep \\ 4 & u.size\gets 1 \\ 5 & \textbf{for }\text{each }u\text{'s son }v \\ 6 & \qquad u.size\gets u.size + \text{TREE-BUILD }(v,dep+1) \\ 7 & \qquad v.father\gets u \\ 8 & \qquad \textbf{if }v.size> u.hson.size \\ 9 & \qquad \qquad u.hson\gets v \\ 10 & \textbf{return } u.size \end{array} \end{array} \]

第二个 DFS 记录所在链的链顶(top,应初始化为结点本身)、重边优先遍历时的 DFS 序(dfn)、DFS 序对应的节点编号(rank)。

\[\begin{array}{l} \text{TREE-DECOMPOSITION }(u,top) \\ \begin{array}{ll} 1 & u.top\gets top \\ 2 & tot\gets tot+1\\ 3 & u.dfn\gets tot \\ 4 & rank(tot)\gets u \\ 5 & \textbf{if }u.hson\text{ is not }0 \\ 6 & \qquad \text{TREE-DECOMPOSITION }(u.hson,top) \\ 7 & \qquad \textbf{for }\text{each }u\text{'s son }v \\ 8 & \qquad \qquad \textbf{if }v\text{ is not }u.hson \\ 9 & \qquad \qquad \qquad \text{TREE-DECOMPOSITION }(v,v) \end{array} \end{array} \]

以下为代码实现。

我们先给出一些定义:

  • \(fa(x)\) 表示节点 \(x\) 在树上的父亲(也就是父节点)。
  • \(dep(x)\) 表示节点 \(x\) 在树上的深度。
  • \(siz(x)\) 表示节点 \(x\) 的子树的节点个数。
  • \(son(x)\) 表示节点 \(x\)重儿子
  • \(top(x)\) 表示节点 \(x\) 所在 重链 的顶部节点(深度最小)。
  • \(dfn(x)\) 表示节点 \(x\)DFS 序 ,也是其在线段树中的编号。
  • \(rnk(x)\) 表示 DFS 序所对应的节点编号,有 \(rnk(dfn(x))=x\)

我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 \(fa(x)\) , \(dep(x)\) , \(siz(x)\) , \(son(x)\) ,第二次 DFS 求出 \(top(x)\) , \(dfn(x)\) , \(rnk(x)\)

// 当然树链写法不止一种,这个是我学习Oi wiki上知识点记录的模板代码
void dfs1(int o) {
    son[o] = -1, siz[o] = 1;
    for (int j = h[o]; j; j = nxt[j])
        if (!dep[p[j]]) {
            dep[p[j]] = dep[o] + 1;
            fa[p[j]] = o;
            dfs1(p[j]);
            siz[o] += siz[p[j]];
            if (son[o] == -1 || siz[p[j]] > siz[son[o]])
                son[o] = p[j];
        }
}
void dfs2(int o, int t) {
    top[o] = t;
    dfn[o] = ++cnt;
    rnk[cnt] = o;
    if (son[o] == -1)
        return;
    dfs2(son[o], t);  // 优先对重儿子进行 DFS,可以保证同一条重链上的点 DFS 序连续
    for (int j = h[o]; j; j = nxt[j])
        if (p[j] != son[o] && p[j] != fa[o])
            dfs2(p[j], p[j]);
}
// 写法2:来自Peocco学长,代码仅作学习使用
void dfs1(int p, int d = 1){
    int Siz = 1,ma = 0;
    dep[p] = d;
    for(auto q : edges[p]){ // for循环写法和auto是C++11标准,竞赛可用
        dfs1(q,d + 1);
        fa[q] = p;
        Siz += sz[q];
        if(sz[q] > ma)
            hson[p] = q, ma = sz[q];// hson = 重儿子
    }
    sz[p] = Siz; 
}
// 需要先把根节点的top初始化为自身
void dfs2(int p){
    for(auto q : edges[p]){
        if(!top[q]){
            if(q == hson[p])
                top[q] = top[p];
           	else 
                top[q] = q;
            dfs2(q);
        }
    }
}

以上这样便完成了剖分。

学习到这里想想开头的那句话:

 “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

如果不能一下想不到线段树解决不了的问题的话不如看看这道题 ↓

Hdu 3966 Aragorn's Story

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966

题意:给一棵树,并给定各个点权的值,然后有3种操作:
  I C1 C2 K: 把C1与C2的路径上的所有点权值加上K
  D C1 C2 K:把C1与C2的路径上的所有点权值减去K
  Q C:查询节点编号为C的权值

  分析:典型的树链剖分题目,先进行剖分,然后用线段树去维护即可

// Author : RioTian
// Time : 20/11/30
#include <bits/stdc++.h>
using namespace std;

#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1

typedef long long ll;
typedef int lld;

stack<int> ss;
const int maxn = 2e5 + 10;
const int inf = ~0u >> 2;  // 1073741823
int M[maxn << 2];
int add[maxn << 2];

struct node {
    int s, t, w, next;
} edges[maxn << 1];

int E, n;
int Size[maxn], fa[maxn], heavy[maxn], head[maxn], vis[maxn];
int dep[maxn], rev[maxn], num[maxn], cost[maxn], w[maxn];
int Seg_size;

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void add_edge(int s, int t, int w) {
    edges[E].w = w;
    edges[E].s = s;
    edges[E].t = t;
    edges[E].next = head[s];
    head[s] = E++;
}

void dfs(int u, int f) {  //起点,父节点
    int mx = -1, e = -1;
    Size[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].next) {
        int v = edges[i].t;
        if (v == f)
            continue;
        edges[i].w = edges[i ^ 1].w = w[v];
        dep[v] = dep[u] + 1;
        rev[v] = i ^ 1;
        dfs(v, u);
        Size[u] += Size[v];
        if (Size[v] > mx)
            mx = Size[v], e = i;
    }
    heavy[u] = e;
    if (e != -1)
        fa[edges[e].t] = u;
}

inline void pushup(int rt) {
    M[rt] = M[rt << 1] + M[rt << 1 | 1];
}

void pushdown(int rt, int m) {
    if (add[rt]) {
        add[rt << 1] += add[rt];
        add[rt << 1 | 1] += add[rt];
        M[rt << 1] += add[rt] * (m - (m >> 1));
        M[rt << 1 | 1] += add[rt] * (m >> 1);
        add[rt] = 0;
    }
}

void built(int l, int r, int rt) {
    M[rt] = add[rt] = 0;
    if (l == r)
        return;
    int m = (r + l) >> 1;
    built(lson), built(rson);
}

void update(int L, int R, int val, int l, int r, int rt) {
    if (L <= l && r <= R) {
        M[rt] += val;
        add[rt] += val;
        return;
    }
    pushdown(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (L <= m)
        update(L, R, val, lson);
    if (R > m)
        update(L, R, val, rson);
    pushup(rt);
}

lld query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R)
        return M[rt];
    pushdown(rt, r - l + 1);
    int m = (l + r) >> 1;
    lld ret = 0;
    if (L <= m)
        ret += query(L, R, lson);
    if (R > m)
        ret += query(L, R, rson);
    return ret;
}

void prepare() {
    int i;
    built(1, n, 1);
    memset(num, -1, sizeof(num));
    dep[0] = 0;
    Seg_size = 0;
    for (i = 0; i < n; i++)
        fa[i] = i;
    dfs(0, 0);
    for (i = 0; i < n; i++) {
        if (heavy[i] == -1) {
            int pos = i;
            while (pos && edges[heavy[edges[rev[pos]].t]].t == pos) {
                int t = rev[pos];
                num[t] = num[t ^ 1] = ++Seg_size;
                // printf("pos=%d  val=%d t=%d\n", Seg_size, edge[t].w, t);
                update(Seg_size, Seg_size, edges[t].w, 1, n, 1);
                pos = edges[t].t;
            }
        }
    }
}

int lca(int u, int v) {
    while (1) {
        int a = find(u), b = find(v);
        if (a == b)
            return dep[u] < dep[v] ? u : v;  // a,b在同一条重链上
        else if (dep[a] >= dep[b])
            u = edges[rev[a]].t;
        else
            v = edges[rev[b]].t;
    }
}

void CH(int u, int lca, int val) {
    while (u != lca) {
        int r = rev[u];  // printf("r=%d\n",r);
        if (num[r] == -1)
            edges[r].w += val, u = edges[r].t;
        else {
            int p = fa[u];
            if (dep[p] < dep[lca])
                p = lca;
            int l = num[r];
            r = num[heavy[p]];
            update(l, r, val, 1, n, 1);
            u = p;
        }
    }
}

void change(int u, int v, int val) {
    int p = lca(u, v);
    // printf("p=%d\n",p);
    CH(u, p, val);
    CH(v, p, val);
    if (p) {
        int r = rev[p];
        if (num[r] == -1) {
            edges[r ^ 1].w += val;  //在此处发现了我代码的重大bug
            edges[r].w += val;
        } else
            update(num[r], num[r], val, 1, n, 1);
    }  //根节点,特判
    else
        w[p] += val;
}

lld solve(int u) {
    if (!u)
        return w[u];  //根节点,特判
    else {
        int r = rev[u];
        if (num[r] == -1)
            return edges[r].w;
        else
            return query(num[r], num[r], 1, n, 1);
    }
}

int main() {
    // freopen("in.txt", "r", stdin);
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int t, i, a, b, c, m, ca = 1, p;
    while (cin >> n >> m >> p) {
        memset(head, -1, sizeof(head));
        E = 0;
        for (int i = 0; i < n; ++i)
            cin >> w[i];
        for (int i = 0; i < m; ++i) {
            cin >> a >> b;
            a--, b--;
            add_edge(a, b, 0), add_edge(b, a, 0);
        }
        prepare();  // 预处理
        string op;
        while (p--) {
            cin >> op;
            if (op[0] == 'I') {  //区间添加
                cin >> a >> b >> c;
                a--, b--;
                change(a, b, c);
            } else if (op[0] == 'D') {  //区间减少
                cin >> a >> b >> c;
                a--, b--;
                change(a, b, -c);
            } else {  //查询
                cin >> a;
                a--;
                cout << solve(a) << endl;
            }
        }
    }
    return 0;
}

由于数据很大,建议使用快读,而不是像我一样用 cin(差了近500ms了)

折叠代码是千千dalao的解法:

Code
//千千dalao解法
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 50010;
struct Edge {
    int to;
    int next;
} edge[maxn << 1];
int head[maxn], tot;  //链式前向星存储
int top[maxn];        // v所在重链的顶端节点
int fa[maxn];         //父亲节点
int deep[maxn];       //节点深度
int num[maxn];        //以v为根的子树节点数
int p[maxn];          // v与其父亲节点的连边在线段树中的位置
int fp[maxn];         //与p[]数组相反
int son[maxn];        //重儿子
int pos;
int w[maxn];
int ad[maxn << 2];  //树状数组
int n;              //节点数目
void init() {
    memset(head, -1, sizeof(head));
    memset(son, -1, sizeof(son));
    tot = 0;
    pos = 1;  //因为使用树状数组,所以我们pos初始值从1开始
}
void addedge(int u, int v) {
    edge[tot].to = v;
    edge[tot].next = head[u];
    head[u] = tot++;
}
//第一遍dfs,求出 fa,deep,num,son (u为当前节点,pre为其父节点,d为深度)
void dfs1(int u, int pre, int d) {
    deep[u] = d;
    fa[u] = pre;
    num[u] = 1;
    //遍历u的邻接点
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != pre) {
            dfs1(v, u, d + 1);
            num[u] += num[v];
            if (son[u] == -1 || num[v] > num[son[u]])  //寻找重儿子
                son[u] = v;
        }
    }
}
//第二遍dfs,求出 top,p
void dfs2(int u, int sp) {
    top[u] = sp;
    p[u] = pos++;
    fp[p[u]] = u;
    if (son[u] != -1)  //如果当前点存在重儿子,继续延伸形成重链
        dfs2(son[u], sp);
    else
        return;
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != son[u] && v != fa[u])  //遍历所有轻儿子新建重链
            dfs2(v, v);
    }
}
int lowbit(int x) {
    return x & -x;
}
//查询
int query(int i) {
    int s = 0;
    while (i > 0) {
        s += ad[i];
        i -= lowbit(i);
    }
    return s;
}
//增加
void add(int i, int val) {
    while (i <= n) {
        ad[i] += val;
        i += lowbit(i);
    }
}
void update(int u, int v, int val) {
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        if (deep[f1] < deep[f2]) {
            swap(f1, f2);
            swap(u, v);
        }
        //因为区间减法成立,所以我们把对某个区间[f1,u]
        //的更新拆分为 [0,f1] 和 [0,u] 的操作
        add(p[f1], val);
        add(p[u] + 1, -val);
        u = fa[f1];
        f1 = top[u];
    }
    if (deep[u] > deep[v])
        swap(u, v);
    add(p[u], val);
    add(p[v] + 1, -val);
}
int main() {
    ios::sync_with_stdio(false);
    int m, ps;
    while (cin >> n >> m >> ps) {
        int a, b, c;
        for (int i = 1; i <= n; i++)
            cin >> w[i];
        init();
        for (int i = 0; i < m; i++) {
            cin >> a >> b;
            addedge(a, b);
            addedge(b, a);
        }
        dfs1(1, 0, 0);
        dfs2(1, 1);
        memset(ad, 0, sizeof(ad));
        for (int i = 1; i <= n; i++) {
            add(p[i], w[i]);
            add(p[i] + 1, -w[i]);
        }
        for (int i = 0; i < ps; i++) {
            char op;
            cin >> op;
            if (op == 'Q') {
                cin >> a;
                cout << query(p[a]) << endl;
            } else {
                cin >> a >> b >> c;
                if (op == 'D')
                    c = -c;
                update(a, b, c);
            }
        }
    }
    return 0;
}

利用树链求LCA

这个部分参考了Peocco学长,十分感谢

在这道经典题中,求了LCA,但为什么树剖就可以求LCA呢?

树剖可以单次 \(O(log\ n)\)! 地求LCA,且常数较小。假如我们要求两个节点的LCA,如果它们在同一条链上,那直接输出深度较小的那个节点就可以了。

否则,LCA要么在链头深度较小的那条链上,要么就是两个链头的父节点的LCA,但绝不可能在链头深度较大的那条链上[1]。所以我们可以直接把链头深度较大的节点用其链头的父节点代替,然后继续求它与另一者的LCA。

由于在链上我们可以 \(O(1)\) 地跳转,每条链间由轻边连接,而经过轻边的次数又不超过 [公式] ,所以我们实现了 \(O(log\ n)\) 的LCA查询。

int lca(int a, int b) {
    while (top[a] != top[b]) {
        if (dep[top[a]] > dep[top[b]])
            a = fa[top[a]];
        else
            b = fa[top[b]];
    }
    return (dep[a] > dep[b] ? b : a);
}

结合数据结构

在进行了树链剖分后,我们便可以配合线段树等数据结构维护树上的信息,这需要我们改一下第二次 DFS 的代码,我们用dfsn数组记录每个点的dfs序,用madfsn数组记录每棵子树的最大dfs序:(这里有点像连通图的知识了)

// 需要先把根节点的top初始化为自身
int cnt;
void dfs2(int p) {
    madfsn[p] = dfsn[p] = ++cnt;
    if (hson[p] != 0) {
        top[hson[p]] = top[p];
        dfs2(hson[p]);
        madfsn[p] = max(madfsn[p], madfsn[hson[p]]);
    }
    for (auto q : edges[p])
        if (!top[q]) {
            top[q] = q;
            dfs2(q);
            madfsn[p] = max(madfsn[p], madfsn[q]);
        }
}

注意到,每棵子树的dfs序都是连续的,且根节点dfs序最小;而且,如果我们优先遍历重子节点,那么同一条链上的节点的dfs序也是连续的,且链头节点dfs序最小

连通树(雾)

所以就可以用线段树等数据结构维护区间信息(以点权的和为例),例如路径修改(类似于求LCA的过程):

void update_path(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            update(dfsn[top[x]], dfsn[x], z);
            x = fa[top[x]];
        } else {
            update(dfsn[top[y]], dfsn[y], z);
            y = fa[top[y]];
        }
    }
    if (dep[x] > dep[y])
        update(dfsn[y], dfsn[x], z);
    else
        update(dfsn[x], dfsn[y], z);
}

路径查询:

int query_path(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            ans += query(dfsn[top[x]], dfsn[x]);
            x = fa[top[x]];
        } else {
            ans += query(dfsn[top[y]], dfsn[y]);
            y = fa[top[y]];
        }
    }
    if (dep[x] > dep[y])
        ans += query(dfsn[y], dfsn[x]);
    else
        ans += query(dfsn[x], dfsn[y]);
    return ans;
}

子树修改(更新):

void update_subtree(int x, int z){
    update(dfsn[x], madfsn[x], z);
}

子树查询:

int query_subtree(int x){
    return query(dfsn[x], madfsn[x]);
}

需要注意,建线段树的时候不是按节点编号建,而是按dfs序建,类似这样:

for (int i = 1; i <= n; ++i)
    B[i] = read();
// ...
for (int i = 1; i <= n; ++i)
    A[dfsn[i]] = B[i];
build();

当然,不仅可以用线段树维护,有些题也可以使用珂朵莉树等数据结构(要求数据不卡珂朵莉树,如这道)。此外,如果需要维护的是边权而不是点权,把每条边的边权下放到深度较深的那个节点处即可,但是查询、修改的时候要注意略过最后一个点。

写在最后:

OI wiki上有一些推荐做的列题,但每个都需要比较多的时间+耐心去完成,所以这里推荐几个必做的题:

SPOJ QTREE – Query on a tree (树链剖分):千千dalao的题解报告

HDU 3966 Aragorn’s Story (树链剖分):建议先看一遍我的解法再独立完成。

参考

洛谷日报:https://zhuanlan.zhihu.com/p/41082337

OI wiki:https://oi-wiki.org/graph/hld/

Pecco学长:https://www.zhihu.com/people/one-seventh

千千:https://www.dreamwings.cn/hdu3966/4798.html


  1. 设top[a]的深度≤top[b]的深度,且c=lca(a,b)在b所在的链上;那么c是a和b的祖先且c的深度≥top[b]的深度,那么c的深度≥top[a]的深度。c是a的祖先,top[a]也是a的祖先,c的深度大于等于top[a],那c必然在连接top[a]和a的这条链上,与前提矛盾 ↩︎

posted @ 2020-11-30 21:01  Koshkaaa  阅读(782)  评论(2编辑  收藏  举报