树链剖分原理和实现
树链剖分原理和实现
理解
树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。
首先就是一些必须知道的概念:
- 重结点:子树结点数目最多的结点;
- 轻节点:父亲节点中除了重结点以外的结点;
- 重边:父亲结点和重结点连成的边;
- 轻边:父亲节点和轻节点连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;

比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,2-11、1-11就是重链,其他就是轻链,用红点标记的就是该结点所在链的起点,也就是我们👇提到的top结点,还有每条边的值其实是进行dfs时的执行序号。
算法中定义了以下的数组用来存储上边提到的概念:
| 名称 | 解释 | 
|---|---|
| siz[u] | 保存以u为根的子树节点个数 | 
| top[u] | 保存当前节点所在链的顶端节点 | 
| son[u] | 保存重儿子 | 
| dep[u] | 保存结点u的深度值 | 
| faz[u] | 保存结点u的父亲节点 | 
| tid[u] | 保存树中每个节点剖分以后的新编号(DFS的执行顺序) | 
| rnk[u] | 保存当前节点在树中的位置 | 
除此之外,还包括两种性质:
- 如果(u, v)是一条轻边,那么size(v) < size(u)/2;
- 从根结点到任意结点的路所经过的轻重链的个数必定都小与O(logn);
首先定义以下数组:
 xxxxxxxxxxconst int MAXN = (100000 << 2) + 10;//Heavy-light Decomposition STARTS FORM HEREint siz[MAXN];//number of sonint top[MAXN];//top of the heavy linkint son[MAXN];//heavy son of the nodeint dep[MAXN];//depth of the nodeint faz[MAXN];//father of the nodeint tid[MAXN];//ID -> DFSIDint rnk[MAXN];//DFSID -> ID
算法大致需要进行两次的DFS,第一次DFS可以得到当前节点的父亲结点(faz数组)、当前结点的深度值(dep数组)、当前结点的子结点数量(size数组)、当前结点的重结点(son数组)
 xxxxxxxxxxvoid dfs1(int u, int father, int depth) {    /*     * u: 当前结点     * father: 父亲结点     * depth: 深度     */    // 更新dep、faz、siz数组    dep[u] = depth;    faz[u] = father;    siz[u] = 1;    // 遍历所有和当前结点连接的结点    for (int i = head[u]; i; i = edg[i].next) {        int v = edg[i].to;        // 如果连接的结点是当前结点的父亲结点,则不处理        if (v != faz[u]) {            dfs1(v, u, depth + 1);            // 收敛的时候将当前结点的siz加上子结点的siz            siz[u] += siz[v];            // 如果没有设置过重结点son或者子结点v的siz大于之前记录的重结点son,则进行更新            if (son[u] == -1 || siz[v] > siz[son[u]]) {                son[u] = v;            }        }    }}
第二次DFS的时候则可以将各个重结点连接成重链,轻节点连接成轻链,并且将重链(其实就是一段区间)用数据结构(一般是树状数组或线段树)来进行维护,并且为每个节点进行编号,其实就是DFS在执行时的顺序(tid数组),以及当前节点所在链的起点(top数组),还有当前节点在树中的位置(rank数组)。
 xxxxxxxxxxvoid dfs2(int u, int t) {    /**     * u:当前结点     * t:起始的重结点     */    top[u] = t;  // 设置当前结点的起点为t    tid[u] = cnt;  // 设置当前结点的dfs执行序号    rnk[cnt] = u;  // 设置dfs序号对应成当前结点    cnt++;    // 如果当前结点没有处在重链上,则不处理    if (son[u] == -1) {        return;    }    // 将这条重链上的所有的结点都设置成起始的重结点    dfs2(son[u], t);    // 遍历所有和当前结点连接的结点    for (int i = head[u]; i; i = edg[i].next) {        int v = edg[i].to;        // 如果连接结点不是当前结点的重子结点并且也不是u的父亲结点,则将其的top设置成自己,进一步递归        if (v != son[u] && v != faz[u]){            dfs2(v, v);        }    }}
而修改和查询操作原理是类似的,以查询操作为例,其实就是个LCA,不过这里使用了top来进行加速,因为top可以直接跳转到该重链的起始结点,轻链没有起始结点之说,他们的top就是自己。需要注意的是,每次循环只能跳一次,并且让结点深的那个来跳到top的位置,避免两个一起跳从而插肩而过。
 xxxxxxxxxxINT64 query_path(int x, int y) {    /**     * x:结点x     * y:结点y     * 查询结点x到结点y的路径和     */    INT64 ans = 0;    int fx = top[x], fy = top[y];    // 直到x和y两个结点所在链的起始结点相等才表明找到了LCA    while (fx != fy) {        if (dep[fx] >= dep[fy]) {            // 已经计算了从x到其链中起始结点的路径和            ans += query(1, tid[fx], tid[x]);            // 将x设置成起始结点的父亲结点,走轻边,继续循环            x = faz[fx];        } else {            ans += query(1, tid[fy], tid[y]);            y = faz[fy];        }        fx = top[x], fy = top[y];    }    // 即便找到了LCA,但是前面也只是分别计算了从一开始到最终停止的位置和路径和    // 如果两个结点不一样,表明仍然需要计算两个结点到LCA的路径和    if (x != y) {        if (tid[x] < tid[y]) {            ans += query(1, tid[x], tid[y]);        } else {            ans += query(1, tid[y], tid[x]);        }    } else ans += query(1, tid[x], tid[y]);    return ans;}void update_path(int x, int y, int z) {    /**     * x:结点x     * y:结点y     * z:需要加上的值     * 更新结点x到结点y的值     */    int fx = top[x], fy = top[y];    while(fx != fy) {        if (dep[fx] > dep[fy]) {            update(1, tid[fx],tid[x], z);            x = faz[fx];        } else {            update(1, tid[fy], tid[y], z);            y = faz[fy];        }        fx = top[x], fy = top[y];    }    if (x != y)        if (tid[x] < tid[y]) update(1, tid[x], tid[y], z);        else update(1, tid[y], tid[x], z);    else update(1, tid[x], tid[y], z);}
实战
以这道题目为例,可以看出算法大致有两种操作,分别是求任意两个节点所连接的路径和、极值,又或者是以任意一个节点作为跟节点来求与子结点的路径和、极值,而求区间和、区间极值正是线段树所擅长的。
首先要构建线段树:
 x
void build(int i, int l, int r) {    /**     * i:当前结点的位置,i << 1表示左结点,+1表示右结点     * l:区间左索引     * r:区间右索引     */    tree[i].left = l;    tree[i].right = r;    // 设置树对应结点的值    if (l == r) {        tree[i].val = val[rnk[l]];    } else {        // 将数组按照二分的形式来拆分        int mid = (l + r) >> 1;        build(i << 1, l, mid);        build((i << 1) | 1, mid + 1, r);        tree[i].val = tree[i << 1].val + tree[(i << 1) + 1].val;    }}void pushdown(int i) {    /**     * 更新操作     */    int lc = i << 1;    int rc = (i << 1) + 1;    tree[lc].val += (tree[lc].right - tree[lc].left + 1) * tag[i];    tree[rc].val += (tree[rc].right - tree[rc].left + 1) * tag[i];    tag[lc] += tag[i];    tag[rc] += tag[i];    tag[i] = 0;}void update(int i, int x, int y, INT64 k) {    /**     * i:起始位置     * x:区间左索引     * y:区间右索引     * k:加上的值     * 更新满足区间x-y的所有值加上k     */    int lc = i << 1, rc = (i << 1) | 1;    if (tree[i].left > y || tree[i].right < x) return;    if (x <= tree[i].left && tree[i].right <= y) {        tree[i].val += (tree[i].right - tree[i].left + 1) * k;        tag[i] += k;    } else {        if (tag[i]) pushdown(i);        update(lc, x, y, k);        update(rc, x, y, k);        tree[i].val = tree[lc].val + tree[rc].val;    }}INT64 query(int i, int x, int y) {    /**     * 查询操作     */    int lc = i << 1, rc = (i << 1) + 1;    if (x <= tree[i].left && tree[i].right <= y)        return tree[i].val;    if (tree[i].left > y || tree[i].right < x)        return 0;    if (tag[i]) pushdown(i);    return query(lc, x, y) + query(rc, x, y);}
 
关注公众号:数据结构与算法那些事儿,每天一篇数据结构与算法
 
                     
                    
                 
                    
                
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号