树链剖分模板题

题目描述

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入格式

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

输出格式

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

输入输出样例

输入 #1
5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
输出 #1
2
21

 

#include <bits/stdc++.h>

using namespace std;
const int maxn=1e5+100;
struct edge
{
    int v,next;
} e[maxn*2];
struct Tree
{
    int sum,l,r,lazy;
} tree[maxn*4];

int top[maxn],tim,dfn[maxn],son[maxn],w[maxn],head[maxn],a[maxn],siz[maxn],deep[maxn],fa[maxn],mod,t;

void add(int u,int v)
{
    t++;
    e[t].v=v;
    e[t].next=head[u];
    head[u]=t;
}

void pushup(int rt)
{
    tree[rt].sum=(tree[rt<<1].sum+tree[rt<<1|1].sum)%mod;
}

void pushdown(int rt,int l,int r)
{
    if (tree[rt].lazy==0)
        return;
    int mid=(l+r)>>1;
    tree[rt<<1].lazy+=tree[rt].lazy;
    tree[rt<<1|1].lazy+=tree[rt].lazy;
    tree[rt<<1].sum=(tree[rt<<1].sum+tree[rt].lazy*(mid-l+1))%mod;
    tree[rt<<1|1].sum=(tree[rt<<1|1].sum+tree[rt].lazy*(r-mid))%mod;
    tree[rt].lazy=0;
}

void build(int rt,int l,int r)
{
    tree[rt].l=l;
    tree[rt].r=r;
    tree[rt].sum=0;
    if (l==r)
    {
        tree[rt].sum=w[l];
        return;
    }
    int mid=(l+r)>>1;
    build(rt<<1,l,mid);
    build (rt<<1|1,mid+1,r);
    pushup(rt);
}

void update(int rt,int l,int r,int z)
{
    if (l<=tree[rt].l&&tree[rt].r<=r)
    {
        tree[rt].lazy+=z;
        tree[rt].sum+=z*(tree[rt].r-tree[rt].l+1);
        return;
    }
    pushdown(rt,tree[rt].l,tree[rt].r);
    int mid=(tree[rt].l+tree[rt].r)>>1;
    if (l<=mid)
    {
        update(rt<<1,l,r,z);
    }
    if (r>mid)
    {
        update(rt<<1|1,l,r,z);
    }
    pushup(rt);
}
void dfs1(int u,int f)
{
    deep[u]=deep[f]+1;
    siz[u]=1;
    fa[u]=f;
    int maxsiz=-1;
    for (int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].v;
        if (v==f)
        {
            continue;
        }
        dfs1(v,u);
        siz[u]+=siz[v];
        if (siz[v]>maxsiz)
        {
            maxsiz=siz[v];
            son[u]=v;
        }
    }
}

int query(int rt,int l,int r)
{
    if (l<=tree[rt].l&&tree[rt].r<=r)
    {
        return tree[rt].sum%mod;
    }
    pushdown(rt,tree[rt].l,tree[rt].r);
    int ret=0;
    int mid=(tree[rt].l+tree[rt].r)>>1;
    if (l<=mid)
        ret=(ret+query(rt<<1,l,r))%mod;
    if (r>mid)
        ret=(ret+query(rt<<1|1,l,r))%mod;
    return ret;
}

void dfs2(int u,int Top)
{
    dfn[u]=++tim;
    w[tim]=a[u];
    top[u]=Top;
    if (!son[u])
    {
        return;
    }
    dfs2(son[u],Top);
    for (int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].v;
        if (v==fa[u]||v==son[u])
        {
            continue;
        }
        dfs2(v,v);
    }
}

void update1(int x,int y,int z)
{
    z=z%mod;
    while (top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]])
        {
            swap(x,y);
        }
        update(1,dfn[top[x]],dfn[x],z);
        x=fa[top[x]];
    }
    if (deep[x]>deep[y])
    {
        swap(x,y);
    }
    update(1,dfn[x],dfn[y],z);
}

int query1(int x,int y)
{
    int ret=0;
    while (top[x]!=top[y])
    {
        if (deep[top[x]]<deep[top[y]])
        {
            swap(x,y);
        }
        ret+=query(1,dfn[top[x]],dfn[x]);
        x=fa[top[x]];
    }
    if (deep[x]>deep[y])
    {
        swap(x,y);
    }
    ret+=query(1,dfn[x],dfn[y]);
    return ret%mod;
}

int main()
{
    // freopen("1.txt","w",stdout);
    int n,m,r;
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    for (int i=1; i<=n; ++i)
    {
        scanf("%d",&a[i]);
    }
    for (int i=1,u,v; i<n; i++)
    {
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);

    while (m--)
    {
        int op;
        scanf("%d",&op);
        if (op==1)
        {
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            update1(x,y,z);
        }
        if (op==2)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",query1(x,y));
        }
        if (op==3)
        {
            int x,z;
            scanf("%d%d",&x,&z);
            update(1,dfn[x],dfn[x]+siz[x]-1,z);
        }
        if (op==4)
        {
            int x;
            scanf("%d",&x);
            printf("%d\n",query(1,dfn[x],dfn[x]+siz[x]-1));
        }
    }
    return 0;
}

  

posted @ 2019-08-15 09:31  Snow_in_winer  阅读(150)  评论(0编辑  收藏  举报