树链剖分

下面给出能够完成下列操作的一份树剖代码

  • 1 x y z 表示将树上从节点x到节点y的最短路径上所有节点的值都加上z

  • 2 x y 表示求树上从节点x到节点y的最短路径上所有节点的值之和

  • 3 x z 表示将以x为根节点的子树内所有节点值都加上z

  • 4 x 表示求以x为根节点的子树内所有节点值之和

/*
*TODO
*---- Galaxy
*/
#include <cstdio>
#include <cstring>

typedef long long LL;

// Upper bound on the number of vertices; used to size the per-vertex arrays
// and the segment tree.
const int MAXN = 1e6 + 10;
// Capacity of the edge list. Every input edge is stored twice (u->v and v->u,
// as main() calls Add in both directions), so this must be at least 2 * (n - 1).
const int MAXM = 2e5 + 10;
// rep(i, s, t): iterate int i over the closed range [s, t].
#define rep(i, s, t) for(int i = s; i <= t; ++i)
// erep(i, u): iterate over the indices of the edges leaving u; Begin/Next form a
// singly linked list per vertex terminated by -1.
#define erep(i, u) for(int i = Begin[u]; i ^ (-1); i = Next[i])

// Exchanges the values of x and y through a temporary. Works for any copyable T
// and leaves the value unchanged when x and y refer to the same object.
template<class T> void swap(T &x, T &y) {
    T tmp = x;
    x = y;
    y = tmp;
}

// Reads the next integer from stdin and returns it as LL.
// Non-digit characters before the number are skipped; the number is negative
// only when the character immediately preceding its first digit is '-'.
// If input ends before any digit is found, returns x * f (0 with the defaults)
// instead of spinning forever on EOF.
inline LL read(LL x = 0, int f = 1) {
    int c = getchar();   // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        f = (c == '-') ? -1 : 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}

// n: vertex count, m: operation count, Root: root vertex,
// a[i]: initial value of vertex i (1-based, filled by main()).
int n, m, Root, a[MAXN];
// Modulus applied to every stored segment-tree sum and every printed answer.
LL Mod;

namespace Galaxy {
    int e, to[MAXM], Begin[MAXN], Next[MAXN];
    int _, Id[MAXN], End[MAXN];
    int sz[MAXN], dep[MAXN], hson[MAXN], fa[MAXN], top[MAXN];

#define FILL(a, b) memset(a, b, sizeof a)
    void init() {
        e = _ = 0;
        FILL(Begin, -1);
    }

    void Add(int x, int y) {
        to[e] = y;
        Next[e] = Begin[x];
        Begin[x] = e++;
    }

    void dfs(int u) {
        int v;
        sz[u] = 1;
        erep(i, u)
            if((v=to[i]) ^ fa[u]) {
                fa[v] = u;
                dep[v] = dep[u] + 1;
                dfs(v);
                if(sz[v] > sz[hson[u]]) hson[u] = v;
                sz[u] += sz[v];
            }
    }

    //XXX
    void DFS(int u, int Top) {
        top[u] = Top; Id[u] = ++_;
        erep(i, u)
            if(to[i] == hson[u])
                DFS(to[i], Top);

        int v;
        erep(i, u)
            if(!top[v=to[i]])
                DFS(v, v);
        End[u] = _;
    }

#define l(h) (h<<1)
#define r(h) (h<<1|1)
    LL sum[MAXN << 2], add[MAXN << 2];

    void push_up(int h, int L, int R) {
        sum[h] = (sum[l(h)] + sum[r(h)]) % Mod;
        sum[h] = (sum[h] + add[h] * (R-L+1)) % Mod;
    }

    void update(int h, int u, int v, int L, int R, LL Tmp) {
        if(u <= L && R <= v) add[h] = (add[h] + Tmp) % Mod;
        else {
            int M = (L + R) >> 1;
            if(u <= M) update(l(h), u, v, L, M, Tmp);
            if(v > M) update(r(h), u, v, M+1, R, Tmp);
        }

        push_up(h, L, R);
    }

    LL query(int h, int u, int v, int L, int R, LL _add) {
        if(L >= u && R <= v)
            return (sum[h] + _add * (R-L+1) % Mod) % Mod;

        int M = (L + R) >> 1;
        LL ret = 0;
        if(u <= M) ret = (ret + query(l(h), u, v, L, M, _add+add[h])) % Mod;
        if(v > M) ret = (ret + query(r(h), u, v, M+1, R, _add+add[h])) % Mod;

        return ret % Mod;
    }

    void TC(int u, int v, int w) {
        for(; top[u] ^ top[v]; u = fa[top[u]]) {
            if(dep[top[u]] < dep[top[v]]) swap(u, v);
            update(1, Id[top[u]], Id[u], 1, n, w);
        }
        if(Id[u] > Id[v]) swap(u, v);
        update(1, Id[u], Id[v], 1, n, w);
    }

    LL TQ(int u, int v) {
        LL ret = 0;

        for(; top[u] ^ top[v]; u = fa[top[u]]) {
            if(dep[top[u]] < dep[top[v]]) swap(u, v);
            ret = (ret + query(1, Id[top[u]], Id[u], 1, n, 0)) % Mod;
        }
        if(dep[u] > dep[v]) swap(u, v);
        ret = (ret + query(1, Id[u], Id[v], 1, n, 0)) % Mod;

        return ret % Mod;
    }
};
using namespace Galaxy;

// Reads n, m, Root, Mod, the n initial vertex values and the n-1 tree edges,
// builds the decomposition, then processes m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the sum over the path x..y (mod Mod)
//   3 x z   : add z to every vertex in the subtree of x
//   4 x     : print the sum over the subtree of x (mod Mod)
int main() {
#ifndef ONLINE_JUDGE
    freopen("input.in", "r", stdin);
    freopen("res.out", "w", stdout);
#endif
    init();

    n = read(), m = read(), Root = read(), Mod = read();
    // Normalize into [0, Mod) so that negative inputs never become negative residues.
    rep(i, 1, n) a[i] = (read() % Mod + Mod) % Mod;
    rep(i, 1, n-1) {
        int u = read(), v = read();
        Add(u, v), Add(v, u);
    }

    dfs(Root);
    DFS(Root, Root);
    rep(i, 1, n) update(1, Id[i], Id[i], 1, n, a[i]);

    rep(i, 1, m) {
        int type = read(), u, v;
        // Increments are kept as LL (read() returns LL) and reduced into [0, Mod),
        // so they are neither truncated to int nor stored as negative values.
        LL w;
        if(type == 1) {
            u = read(), v = read(), w = (read() % Mod + Mod) % Mod;
            TC(u, v, w);
        }else if(type == 2) {
            u = read(), v = read();
            printf("%lld\n", TQ(u, v));
        }else if(type == 3) {
            u = read(), w = (read() % Mod + Mod) % Mod;
            // The subtree of u is the contiguous DFS range [Id[u], End[u]].
            update(1, Id[u], End[u], 1, n, w);
        }else if(type == 4) {
            u = read();
            printf("%lld\n", query(1, Id[u], End[u], 1, n, 0));
        }
    }
    return 0;
}
posted @ 2017-02-24 18:52  pbvrvnq  阅读(116)  评论(0)  编辑  收藏  举报