bzoj3589 动态树

3589: 动态树

Time Limit: 30 Sec  Memory Limit: 1024 MB
Submit: 635  Solved: 230
[Submit][Status][Discuss]

Description

 

别忘了这是一棵动态树, 每时每刻都是动态的. 小明要求你在这棵树上维护两种事件
事件0:
这棵树长出了一些果子, 即某个子树中的每个节点都会长出K个果子.
事件1:
小明希望你求出几条树枝上的果子数. 一条树枝其实就是一个从某个节点到根的路径的一段. 每次小明会选定一些树枝, 让你求出在这些树枝上的节点的果子数的和. 注意, 树枝之间可能会重合, 这时重合的部分的节点的果子只要算一次.
 

Input

第一行一个整数n(1<=n<=200,000), 即节点数.
接下来n-1行, 每行两个数字u, v. 表示果子u和果子v之间有一条直接的边. 节点从1开始编号.
在接下来一个整数nQ(1<=nQ<=200,000), 表示事件.
最后nQ行, 每行开头要么是0, 要么是1.
如果是0, 表示这个事件是事件0. 这行接下来的2个整数u, delta表示以u为根的子树中的每个节点长出了delta个果子.
如果是1, 表示这个事件是事件1. 这行接下来一个整数K(1<=K<=5), 表示这次询问涉及K个树枝. 接下来K对整数u_k, v_k, 每个树枝从节点u_k到节点v_k. 由于果子数可能非常多, 请输出这个数模2^31的结果.

Output

对于每个事件1, 输出询问的果子数.

Sample Input

5
1 2
2 3
2 4
1 5
3
0 1 1
0 2 3
1 2 3 1 1 4

Sample Output

13

HINT

 

 1 <= n <= 200,000, 1 <= nQ <= 200,000, K = 5.


生成每个树枝的过程是这样的:先在树中随机找一个节点, 然后在这个节点到根的路径上随机选一个节点, 这两个节点就作为树枝的两端.

 

Source

分析:这标题......
         其实这道题很容易就能想到解法:线段树+树链剖分+dfs序.对于第一个操作无非就是在线段树上区间加嘛,但是第二个操作就有点麻烦,因为统计的是并,不能一条一条地统计链的答案.解决这种问题有两种方法:1.利用容斥原理,很容易想到. 2.将需要统计的部分标记出来,然后统计有标记的地方. 利用第二种方法.先对所有需要操作的链打上标记.统计完答案后将所有标记清零就可以了,挺好写.
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 400010;

int n,q,head[maxn],to[maxn],nextt[maxn],tot = 1;
int pos[maxn],ppos[maxn],deep[maxn],top[maxn],sizee[maxn],son[maxn],fa[maxn];
int sum[maxn << 2],tag[maxn << 2],add[maxn << 2],val[maxn << 2],cnt,endd[maxn],L[maxn << 2],R[maxn << 2];

void addd(int x,int y)
{
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

void dfs(int u,int faa,int dep)
{
    deep[u] = dep;
    sizee[u] = 1;
    for (int i = head[u];i;i = nextt[i])
    {
        int v = to[i];
        if (v == faa)
            continue;
        fa[v] = u;
        dfs(v,u,dep + 1);
        sizee[u] += sizee[v];
        if (sizee[v] > sizee[son[u]])
            son[u] = v;
    }
}

void dfs2(int u,int topp)
{
    pos[u] = ++cnt;
    ppos[cnt] = u;
    top[u] = topp;
    if (son[u])
        dfs2(son[u],topp);
    for (int i = head[u];i;i = nextt[i])
    {
        int v = to[i];
        if (v == fa[u] || v == son[u])
            continue;
        dfs2(v,v);
    }
    endd[u] = cnt;
}

void pushup(int o)
{
    sum[o] = sum[o * 2] + sum[o * 2 + 1];
    val[o] = val[o * 2] + val[o * 2 + 1];
}

void build(int o,int l,int r)
{
    L[o] = l;
    R[o] = r;
    tag[o] = -1;
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    build(o * 2,l,mid);
    build(o * 2 + 1,mid + 1,r);
    pushup(o);
}

void pushdown(int o)
{
    if (add[o])
    {
        add[o * 2] += add[o];
        add[o * 2 + 1] += add[o];
        sum[o * 2] += add[o] * (R[o * 2] - L[o * 2] + 1);
        sum[o * 2 + 1] += add[o] * (R[o * 2 + 1] - L[o * 2 + 1] + 1);
        add[o] = 0;
    }
    if (tag[o] != -1)
    {
        val[o * 2] = sum[o * 2] * tag[o]; //这里其实是新开了一个数组将答案传到根节点上去
        val[o * 2 + 1] = sum[o * 2 + 1] * tag[o];
        tag[o * 2] = tag[o * 2 + 1] = tag[o];
        tag[o] = -1;
    }
}

void update(int o,int l,int r,int x,int y,int v)  //区间加
{
    if (x <= l && r <= y)
    {
        sum[o] += v * (r - l + 1);
        add[o] += v;
        return;
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if (x <= mid)
        update(o * 2,l,mid,x,y,v);
    if (y > mid)
        update(o * 2 + 1,mid + 1,r,x,y,v);
    pushup(o);
}

void update2(int o,int l,int r,int x,int y,int v)
{
    if (x <= l && r <= y)
    {
        val[o] = sum[o] * v;
        tag[o] = v; //打标记,v=1代表统计,v=0代表不统计,v=-1代表没标记.
        return;
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if (x <= mid)
        update2(o * 2,l,mid,x,y,v);
    if (y > mid)
        update2(o * 2 + 1,mid + 1,r,x,y,v);
    pushup(o);
}

void solve(int x,int y)
{
    while (top[x] != top[y])
    {
        if (deep[top[x]] < deep[top[y]])
            swap(x,y);
        update2(1,1,n,pos[top[x]],pos[x],1);
        x = fa[top[x]];
    }
    if (deep[x] < deep[y])
        swap(x,y);
    update2(1,1,n,pos[y],pos[x],1);
}

int main()
{
    scanf("%d",&n);
    for (int i = 1; i < n; i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        addd(x,y);
        addd(y,x);
    }
    dfs(1,0,1);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&q);
    for (int i = 1; i <= q; i++)
    {
        int id,x,y;
        scanf("%d",&id);
        if (id == 0)
        {
            scanf("%d%d",&x,&y);
            update(1,1,n,pos[x],endd[x],y);
        }
        else
        {
            int k;
            scanf("%d",&k);
            while (k--)
            {
                scanf("%d%d",&x,&y);
                solve(x,y);
            }
            printf("%d\n",val[1] & 0x7fffffff);
            update2(1,1,n,1,n,0);
        }
    }

    return 0;
}

 

posted @ 2018-01-19 17:43  zbtrs  阅读(161)  评论(0编辑  收藏  举报