树链剖分

思想

链可以看作一种特殊情况下的树。而当树退化成链之后问题就会变得非常简单。那么我们就可以考虑将一棵树变成若干条链来处理问题。基于这个思想,就有了基于树的路径剖分,也就是“树链剖分”。

一些定义

重儿子:该节点的子树中,节点个数最多的子树的根节点(也就是和该节点相连的点),即为该节点的重儿子
重边:连接该节点与它的重儿子的边
重链:由一系列重边相连得到的链
轻链:由一系列非重边相连得到的链

将树拆分为链采用轻重边路径剖分。对于每一个节点,找出它的重儿子,那么这棵树就自然而然的被拆成了许多重链与许多轻链。

既然要以链的形式来处理那么对于编号我们就需要保证同一条链上的要连续,我们可以利用dfs序的思想来给这棵树重新编号

通过图来理解一下

将这棵树进行剖分并重新编号后变成下面这棵树

这样我们就可以再用线段树或者树状数组来进行问题的处理了。

代码的实现

一:处理轻重边

void dfs1(int x, int f, int dep)
{
    deep[x] = dep; fa[x] = f; size[x] = 1;
    for(int i=h[x]; i; i=G[i].nex)
    {
        int v = G[i].to;
        if (v == f) continue;
        dfs1(v, x, dep+1);
        size[x] += size[v];
        if (son[x] == -1 || size[son[x]] < size[v]) son[x] = v;
    }
}

二:获得dfs序

void dfs2(int x, int tp)
{
    top[x] = tp; pos[x] = ++cnt; id[cnt] = x;
    if (son[x] == -1) return;
    dfs2(son[x], tp);
    for(int i=h[x]; i; i=G[i].nex)
    {
        int v = G[i].to;
        if (v == fa[x] || v == son[x]) continue;
        dfs2(v, v);
    }
}

另外对于不是一条链上的两个点,该如何处理呢?
在上述代码中我们有一个top数组,top[i]表示与i处于同一重链上的点中最深度最小的点。这样我们就可以通过i->top[i]->fa[top[i]]->...这样跳到另外一条重链上了。
对于两个点x,y每次让深度最大的一个点先跳,让其跳到fa[top[x(y)]]上,这样最终两个点总会跳到一条重链上的。而且由于重链上的点的序数是连续的,中间过程我们就可以通过线段树或者树状数组来进行各种操作即可。

例题

如题,已知一棵包含 NN 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作 1: 格式: 1 x y z 表示将树从 xx 到 yy 结点最短路径上所有节点的值都加上 zz。
操作 2: 格式: 2 x y 表示求树从 xx 到 yy 结点最短路径上所有节点的值之和。
操作 3: 格式: 3 x z 表示将以 xx 为根节点的子树内所有节点值都加上 zz。
操作 4: 格式: 4 x 表示求以 xx 为根节点的子树内所有节点值之和

代码

#include <bits/stdc++.h>
using namespace std;
#define LL long long
struct Tree{
    int l, r, mid, sum, lazy;
    Tree *lc, *rc;
    Tree()
    {
        lc = NULL; rc = NULL;
        lazy = 0; sum = 0;
    }
};
struct Edge{
    int from, to, nex;
};
const int N = 500005;
LL n, m, r, mod;
int size[N], deep[N], top[N], son[N], tot, h[N], A[N], id[N], pos[N], fa[N], cnt;
Tree *root;
Edge G[N*5];
void add(int u, int v)
{
    tot++; G[tot]=(Edge){u, v, h[u]}; h[u]=tot;
    tot++; G[tot]=(Edge){v, u, h[v]}; h[v]=tot;
}
void dfs1(int x, int f, int dep)
{
    deep[x] = dep; fa[x] = f; size[x] = 1;
    for(int i=h[x]; i; i=G[i].nex)
    {
        int v = G[i].to;
        if (v == f) continue;
        dfs1(v, x, dep+1);
        size[x] += size[v];
        if (son[x] == -1 || size[son[x]] < size[v]) son[x] = v;
    }
}
void dfs2(int x, int tp)
{
    top[x] = tp; pos[x] = ++cnt; id[cnt] = x;
    if (son[x] == -1) return;
    dfs2(son[x], tp);
    for(int i=h[x]; i; i=G[i].nex)
    {
        int v = G[i].to;
        if (v == fa[x] || v == son[x]) continue;
        dfs2(v, v);
    }
}
void build(Tree *p, int l, int r)
{
    p->l = l; p->r = r; p->mid = (l+r)>>1;
    if (l == r)
    {
        p->sum = A[id[l]];
        return;
    }
    p->lc = new Tree; p->rc = new Tree;
    build(p->lc, l, p->mid); build(p->rc, p->mid+1, r);
    p->sum = (p->lc->sum + p->rc->sum) % mod;
}
void down(Tree *p)
{
    if (p->lazy)
    {
        p->lc->lazy = (p->lc->lazy + p->lazy) % mod;
        p->lc->sum = (p->lc->sum + p->lazy * (p->lc->r - p->lc->l + 1) % mod) % mod;
        p->rc->lazy = (p->rc->lazy + p->lazy) % mod;
        p->rc->sum = (p->rc->sum + p->lazy * (p->rc->r - p->rc->l + 1) % mod) % mod;
        p->lazy = 0;
    }
}
void update(Tree *p, int l, int r, int k)
{
    if (l <= p->l && r >= p->r) (p->lazy += k%mod) %= mod, (p->sum += k*(p->r-p->l+1)%mod) %= mod;
    else
    {
        down(p);
        if (l <= p->mid) update(p->lc, l, r, k);
        if (r > p->mid) update(p->rc, l, r, k);
        p->sum = (p->lc->sum + p->rc->sum) % mod;
    }
}
void add(int u, int v, int k)
{
    while(top[u] != top[v])
    {
        if (deep[top[u]] < deep[top[v]]) swap(u, v);
        update(root, pos[top[u]], pos[u], k);
        u = fa[top[u]];
    }
    if(deep[u] < deep[v]) swap(u, v);
    update(root, pos[v], pos[u], k);
}
LL query(Tree *p, int l, int r)
{
    if (l <= p->l && r >= p->r) return p->sum;
    LL re = 0;
    down(p);
    if(l <= p->mid) re += query(p->lc, l, r), re %= mod;
    if(r > p->mid) re += query(p->rc, l, r), re %= mod;
    return re;
}
LL query_1(int u, int v)
{
    LL ans = 0;
    while(top[u] != top[v])
    {
        if(deep[top[u]] < deep[top[v]]) swap(u, v);
        ans += query(root, pos[top[u]], pos[u]);
        ans %= mod;
        u = fa[top[u]];
    }
    if(deep[u] < deep[v]) swap(u, v);
    ans += query(root, pos[v], pos[u]); ans %= mod;
    return ans;
}
int main()
{
    root = new Tree;
    memset(son, -1, sizeof son);
    scanf("%lld%lld%lld%lld", &n, &m, &r, &mod);
    for(int i=1; i<=n; ++i) scanf("%d", &A[i]);
    for(int i=1; i<n; ++i)
    {
        int x, y; scanf("%d%d", &x, &y); add(x, y);
    }
    dfs1(r, -1, 1); dfs2(r, r);
    build(root, 1, n);
    for(int i=1; i<=m; ++i)
    {
        int a, b, c, d; scanf("%d", &a);
        switch(a)
        {
            case 1:
            {
                scanf("%d%d%d", &b, &c, &d);
                add(b, c, d);
            } break;
            case 2:
            {
                scanf("%d%d", &b, &c);
                printf("%lld\n", query_1(b, c));
            } break;
            case 3:
            {
                scanf("%d%d", &b, &c);
                update(root, pos[b], pos[b]+size[b]-1, c);
            } break;
            case 4:
            {
                scanf("%d", &b);
                printf("%lld\n", query(root, pos[b], pos[b]+size[b]-1));
            } break;
        }
    }
    return 0;
}
posted @ 2020-03-11 16:31  諾-Oier  阅读(113)  评论(0)    收藏  举报