BZOJ1036 树的统计(树链剖分+线段树)

【题目描述】

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

【输入格式】

输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 

对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

【输出格式】

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

【样例输入】

4

1 2

2 3

4 1

4 2 1 3

12

QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4

【样例输出】

4

1

2

2

10

6

5

6

5

16

【题目分析】

这个题大意就是:给定一棵树,要求其支持单点修改,路径查询最大值和,路径求和。

这个就是典型树链剖分板题了,用线段树维护重链信息,对于单点修改就直接在线段树中修改,对于路径询问直接在线段树上进行区间查询即可。

【代码(已更新)】

#include<bits/stdc++.h>
using namespace std;
const int MAXN=3e4+10;

struct node{
    int y,next;
}edge[MAXN<<1];

int head[MAXN];
int dfn[MAXN],son[MAXN],depth[MAXN],fa[MAXN],siz[MAXN],a[MAXN];
int l,x,y,n,q,tot;
char s[10];
int rec[MAXN],top[MAXN];

void add(int x,int y)
{
    l++;
    edge[l].y=y;
    edge[l].next=head[x];
    head[x]=l;
}

void dfs1(int x,int f)
{
    fa[x]=f;
    son[x]=0;
    siz[x]=1;
    for(int i=head[x];i!=-1;i=edge[i].next)
    {
    	int v=edge[i].y;
        if(v!=f)
        {
            depth[v]=depth[x]+1;
            dfs1(v,x);
            siz[x]+=siz[v];
            if(siz[son[x]]<siz[v])
              son[x]=v;
        }
    }
}

void dfs2(int x,int tp)
{
    top[x]=tp;
    dfn[x]=++tot;
    rec[dfn[x]]=x;
    if(son[x]) 
	  dfs2(son[x],tp);
    for(int i=head[x];i!=-1;i=edge[i].next)
    {
    	int v=edge[i].y;
     	if(v!=fa[x]&&v!=son[x])
     	  dfs2(v,v);
 	}
}
struct point{
    int l,r,sum,maxx;
}tr[4*MAXN];

void push_up(int root)
{
    tr[root].sum=tr[root<<1].sum+tr[root<<1|1].sum;
    tr[root].maxx=max(tr[root<<1].maxx,tr[root<<1|1].maxx);
}

void build(int root,int l,int r)
{
    tr[root].l=l;tr[root].r=r;
    if(l==r)
	{
		tr[root].sum=tr[root].maxx=a[rec[l]]; 
		return ;
	}
    int mid=l+r>>1;
    build(root<<1,l,mid);
	build(root<<1|1,mid+1,r);
    push_up(root);
}

void update(int root,int x,int y)
{
    if(tr[root].l==x&&tr[root].r==x)
    {
        tr[root].sum=tr[root].maxx=y;
        return ;
    }
    int mid=tr[root].l+tr[root].r>>1;
    if(x<=mid) 
	  update(root<<1,x,y);
    if(x>mid) 
	  update(root<<1|1,x,y);
    push_up(root);
}

int querymax(int root,int l,int r)
{
    if(tr[root].l==l&&tr[root].r==r)
      return tr[root].maxx;
    int mid=tr[root].l+tr[root].r>>1;
    if(r<=mid) 
	  return querymax(root<<1,l,r);
    if(l>mid) 
	  return querymax(root<<1|1,l,r);
    if(l<=mid&&r>mid)
    {
        int s1=querymax(root<<1,l,mid);
        int s2=querymax(root<<1|1,mid+1,r);
        return max(s1,s2);
    }
}

int querysum(int root,int l,int r)
{
    if(tr[root].l==l&&tr[root].r==r)
      return tr[root].sum;
    int mid=tr[root].l+tr[root].r>>1;
    if(r<=mid) 
	  return querysum(root<<1,l,r);
    if(l>mid) 
	  return querysum(root<<1|1,l,r);
    if(l<=mid&&r>mid)
    {
        int s1=querysum(root<<1,l,mid);
        int s2=querysum(root<<1|1,mid+1,r);
        return s1+s2;
    }
}

int findmax(int x, int y)
{
    int f1=top[x],f2=top[y],ret=-0x3f3f3f3f;
    while(f1!=f2)
    {
        if(depth[f1]<depth[f2])
        {
			swap(f1,f2);
			swap(x,y);
		}
        ret=max(ret,querymax(1,dfn[f1],dfn[x]));
        x=fa[f1],f1=top[x];
    }
    if(x==y) 
	  return max(ret,querymax(1,dfn[x],dfn[x]));
    if(depth[x]>depth[y]) 
	  swap(x,y);
    return max(ret,querymax(1,dfn[x],dfn[y]));
}

int findsum(int x,int y)
{
    int f1=top[x],f2=top[y],ret=0;
    while(f1!=f2)
    {
        if(depth[f1]<depth[f2])
        { 
			swap(f1,f2);
			swap(x,y);
		}
        ret+=querysum(1,dfn[f1],dfn[x]);
        x=fa[f1],f1=top[x];
    }
    if(x==y) 
	  return ret+querysum(1,dfn[x],dfn[y]);
    if(depth[x]>depth[y]) 
	  swap(x,y);
    return ret+querysum(1,dfn[x],dfn[y]);
}

int main()
{
    scanf("%d",&n);
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    for(int i=1;i<=n;i++) 
	  scanf("%d",&a[i]);
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&q);
    while(q--)
    {
        scanf("%s%d%d",s,&x,&y);
        if(s[0]=='C')  
		  update(1,dfn[x],y); 
        if(s[0]=='Q'&&s[1]=='M') 
		  printf("%d\n",findmax(x,y));
        if(s[0]=='Q'&&s[1]=='S') 
		  printf("%d\n", findsum(x,y));
    }
    return 0;
}

 

posted @ 2018-10-07 12:01  Ishtar~  阅读(164)  评论(0编辑  收藏  举报