BZOJ1036[ZJOI2008]树的统计——树链剖分+线段树

题目描述

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

输入

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

输出

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

样例输入

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

样例输出

4
1
2
2
10
6
5
6
5
16
 
  单点修改、路径求最大值、路径求和,直接上树链剖分,但要注意求最大值时因为可能有负数,所以最小值要设成-INF。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int n,m;
int tot;
int num;
int x,y;
char ch[30];
int f[30010];
int d[30010];
int s[30010];
int to[60010];
int mx[240010];
int son[30010];
int top[30010];
int size[30010];
int head[30010];
int next[30010];
int sum[240010];
void add(int x,int y)
{
    tot++;
    next[tot]=head[x];
    head[x]=tot;
    to[tot]=y;
}
void dfs(int x,int fa)
{
    size[x]=1;
    f[x]=fa;
    d[x]=d[fa]+1;
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=fa)
        {
            dfs(to[i],x);
            size[x]+=size[to[i]];
            if(size[to[i]]>size[son[x]])
            {
                son[x]=to[i];
            }
        }
    }
}
void dfs2(int x,int tp)
{
    s[x]=++num;
    top[x]=tp;
    if(son[x])
    {
        dfs2(son[x],tp);
    }
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=f[x]&&to[i]!=son[x])
        {
            dfs2(to[i],to[i]);
        }
    }
}
void updata(int rt)
{
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
    mx[rt]=max(mx[rt<<1],mx[rt<<1|1]);
}
void change(int rt,int l,int r,int k,int v)
{
    if(l==r)
    {
        sum[rt]=v;
        mx[rt]=v;
        return ;
    }
    int mid=(l+r)>>1;
    if(k<=mid)
    {
        change(rt<<1,l,mid,k,v);
    }
    else
    {
        change(rt<<1|1,mid+1,r,k,v);
    }
    updata(rt);
}
int querysum(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        return sum[rt];
    }
    int mid=(l+r)>>1;
    int res=0;
    if(L<=mid)
    {
        res+=querysum(rt<<1,l,mid,L,R);
    }
    if(R>mid)
    {
        res+=querysum(rt<<1|1,mid+1,r,L,R);
    }
    return res;
}
int querymax(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        return mx[rt];
    }
    int mid=(l+r)>>1;
    if(R<=mid)
    {
        return querymax(rt<<1,l,mid,L,R);
    }
    else if(L>mid)
    {
        return querymax(rt<<1|1,mid+1,r,L,R);
    }
    return max(querymax(rt<<1,l,mid,L,R),querymax(rt<<1|1,mid+1,r,L,R));
}
int asksum(int x,int y)
{
    int res=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
        {
            swap(x,y);
        }
        res+=querysum(1,1,n,s[top[x]],s[x]);
        x=f[top[x]];
    }
    if(d[x]>d[y])
    {
        swap(x,y);
    }
    res+=querysum(1,1,n,s[x],s[y]);
    return res;
}
int askmax(int x,int y)
{
    int res=-2147483647;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
        {
            swap(x,y);
        }
        res=max(res,querymax(1,1,n,s[top[x]],s[x]));
        x=f[top[x]];
    }
    if(d[x]>d[y])
    {
        swap(x,y);
    }
    res=max(res,querymax(1,1,n,s[x],s[y]));
    return res;
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1,1);
    dfs2(1,1);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&x);
        change(1,1,n,s[i],x);
    }
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        scanf("%s",ch);
        scanf("%d%d",&x,&y);
        if(ch[1]=='H')
        {
            
            change(1,1,n,s[x],y);
        }
        else if(ch[1]=='M')
        {
            printf("%d\n",askmax(x,y));
        }
        else
        {
            printf("%d\n",asksum(x,y));
        }
    }
}
posted @ 2018-08-30 16:33  The_Virtuoso  阅读(273)  评论(0编辑  收藏  举报