树剖裸题——BZOJ1036 树的统计

  1 #include<cstring>
  2 #include<cmath>
  3 #include<algorithm>
  4 #include<cstdio>
  5 #define foru(i,x,y) for(int i=x;i<=y;i++)
  6 #define clr(a) memset(a,0,sizeof(a))
  7 using namespace std;
  8 const int N=100010;
  9 struct edge{int to,nxt;}e[N*2];
 10 struct node{int s,m;}t[10*N];
 11 int d[N],id[N],head[N],f[N],siz[N],son[N],top[N];
 12 //f[v]        节点v的父节点编号
 13 //id[v]        节点v的父边在线段树中的编号
 14 //siz[v]    以节点v为根的数中的节点数
 15 //son[v]    节点v的子节点中siz[]最大的节点编号 
 16 //top[v]    节点v所在重链的顶端节点编号 
 17 //d[v]        节点v的深度 
 18 int ne,cnt,n;
 19 
 20 void add(int a,int b){
 21     e[++ne]=(edge){b,head[a]};head[a]=ne;
 22 }
 23 void dfs(int k,int fa,int dep){//统计f[] siz[] son[] d[]
 24     //printf("test\n");
 25     f[k]=fa;d[k]=dep;siz[k]=1;son[k]=0;
 26     for(int i=head[k];i;i=e[i].nxt){
 27         int v=e[i].to;
 28         if(v==fa)continue;
 29         dfs(v,k,dep+1);
 30         siz[k]+=siz[v];
 31         if(siz[v]>siz[son[k]])son[k]=v;
 32     }
 33 }
 34 
 35 void build(int k,int tp){
 36     id[k]=++cnt; top[k]=tp;//按序将边加入线段树
 37     if(son[k])build(son[k],tp);//重儿子的top[]从重链顶端继承
 38     for(int i=head[k];i;i=e[i].nxt){
 39         int v=e[i].to;
 40         if(v!=son[k]&&v!=f[k])
 41             build(v,v);//轻儿子top[]为自身
 42     }
 43 }
 44 
 45 #define mid ((L+R)>>1)
 46 #define ls (k<<1)//写位运算一定要开-Wall,否则一定要记得加括号 
 47 #define rs ls+1
 48 
 49 void update(int k,int L,int R,int p,int x){
 50     if(p>R||p<L)return;
 51     if(L==R){t[k].s=t[k].m=x;return;}
 52     update(ls,L,mid,p,x); update(rs,mid+1,R,p,x);
 53     t[k].m=max(t[ls].m,t[rs].m);
 54     t[k].s=t[ls].s+t[rs].s;
 55 }
 56 
 57 int querym(int k,int L,int R,int l,int r){
 58     if(l>R||r<L)return -1e9;
 59     if(l<=L&&R<=r)return t[k].m;
 60     return max(querym(ls,L,mid,l,r),querym(rs,mid+1,R,l,r));
 61 }
 62 
 63 int querys(int k,int L,int R,int l,int r){
 64     if(l>R||r<L)return 0;
 65     if(l<=L&&R<=r)return t[k].s;
 66     return querys(ls,L,mid,l,r)+querys(rs,mid+1,R,l,r);
 67 }
 68 
 69 int findm(int x,int y){
 70     int ans=-1e9*2;
 71     while(top[x]!=top[y]){//类似LCA,每次将较低的节点上跳,并统计路径上的最大值 
 72         if(d[top[x]]<d[top[y]])swap(x,y);
 73         ans=max(ans,querym(1,1,cnt,id[top[x]],id[x]));
 74         x=f[top[x]];
 75     }
 76     if(d[x]>d[y])swap(x,y);//当两点处于同一条链上的时候,进行最后一次统计 
 77     ans=max(ans,querym(1,1,cnt,id[x],id[y]));
 78     return ans;
 79 }
 80 
 81 int finds(int x,int y){
 82     int ans=0;
 83     while(top[x]!=top[y]){
 84         if(d[top[x]]<d[top[y]])swap(x,y);
 85         ans+=querys(1,1,cnt,id[top[x]],id[x]);
 86         x=f[top[x]];
 87     }
 88     if(d[x]>d[y])swap(x,y);
 89     ans+=querys(1,1,cnt,id[x],id[y]); 
 90     return ans;
 91 }
 92 char ch[20];
 93 int main(){
 94     int u,v,x,y;
 95     scanf("%d",&n);
 96     foru(i,1,n-1){
 97         scanf("%d%d",&u,&v);
 98         add(u,v);add(v,u);
 99     }
100     dfs(1,0,1);
101     build(1,1);
102     foru(i,1,n){
103         scanf("%d",&u);
104         update(1,1,cnt,id[i],u);
105     }
106     scanf("%d",&u);
107     while(u--){
108         scanf("%s%d%d",ch,&x,&y);
109         if(ch[0]=='C')update(1,1,cnt,id[x],y);
110         else{
111             if(ch[1]=='M')printf("%d\n",findm(x,y));
112             else printf("%d\n",finds(x,y));
113         }
114     }
115     return 0;
116 }

 

posted @ 2017-04-07 23:27  羊毛羊  阅读(439)  评论(0编辑  收藏  举报