BZOJ 1036 [ZJOI2008]树的统计Count

以前动态树写过这个题,今天尝试树链剖分解决~

模板题,就声明一点,线段树维护的是点权

 

View Code
  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <cstdlib>
  5 #include <algorithm>
  6 
  7 #define N 50000
  8 #define M 100000
  9 #define INF 1e9
 10 
 11 using namespace std;
 12 
 13 int head[N],to[M],next[M];
 14 int sz[N],top[N],bh[N],fa[N],son[N],dep[N];
 15 int sum[N<<2],mx[N<<2],val[N],dat[N];
 16 int q[N];
 17 int n,cnt,tot,qu;
 18 
 19 inline void add(int u,int v)
 20 {
 21     to[cnt]=v; next[cnt]=head[u]; head[u]=cnt++;
 22 }
 23 
 24 inline void prep()
 25 {
 26     int h=1,t=2,sta;
 27     q[1]=1; dep[1]=1;
 28     while(h<t)
 29     {
 30         sta=q[h++]; sz[sta]=1;
 31         for(int i=head[sta];~i;i=next[i])
 32             if(to[i]!=fa[sta])
 33             {
 34                 fa[to[i]]=sta;
 35                 dep[to[i]]=dep[sta]+1;
 36                 q[t++]=to[i];
 37             }
 38     }
 39     for(int j=t-1;j>=1;j--)
 40     {
 41         sta=q[j];
 42         for(int i=head[sta];~i;i=next[i])
 43             if(to[i]!=fa[sta])
 44             {
 45                 sz[sta]+=sz[to[i]];
 46                 if(sz[to[i]]>sz[son[sta]]) son[sta]=to[i]; 
 47             }
 48     }
 49     for(int i=1;i<t;i++)
 50     {
 51         sta=q[i];
 52         if(son[fa[sta]]==sta) top[sta]=top[fa[sta]];//不是重链顶部 
 53         else top[sta]=sta;
 54     }
 55 }
 56 
 57 inline void rewrite()
 58 {
 59     for(int i=1;i<=n;i++)
 60         if(top[i]==i)
 61             for(int j=i;j;j=son[j])//每条重链的编号是连续的 
 62             {
 63                 bh[j]=++tot;
 64                 dat[tot]=val[j];
 65             }
 66 }
 67 
 68 inline void pushup(int x)
 69 {
 70     sum[x]=sum[x<<1]+sum[x<<1|1];
 71     mx[x]=max(mx[x<<1],mx[x<<1|1]);
 72 }
 73 
 74 inline void build(int u,int L,int R)
 75 {
 76     if(L==R) {sum[u]=mx[u]=dat[L];return;}
 77     int MID=(L+R)>>1;
 78     build(u<<1,L,MID); build(u<<1|1,MID+1,R);
 79     pushup(u);
 80 }
 81 
 82 inline void read()
 83 {
 84     memset(head,-1,sizeof head); cnt=0;
 85     scanf("%d",&n);
 86     for(int i=1,a,b;i<n;i++)
 87     {
 88         scanf("%d%d",&a,&b);
 89         add(a,b); add(b,a);
 90     }
 91     for(int i=1;i<=n;i++) scanf("%d",&val[i]);
 92     prep();
 93     rewrite();
 94     build(1,1,tot);
 95 }
 96 
 97 inline int querysum(int u,int L,int R,int l,int r)
 98 {
 99     if(l<=L&&R<=r) return sum[u];
100     int MID=(L+R)>>1,res=0;
101     if(l<=MID) res+=querysum(u<<1,L,MID,l,r);
102     if(MID<r) res+=querysum(u<<1|1,MID+1,R,l,r);
103     return res;
104 }
105 
106 inline int getsum(int x,int y)
107 {
108     int res=0;
109     while(top[x]!=top[y])
110     {
111         if(dep[top[x]]<dep[top[y]]) swap(x,y);
112         res+=querysum(1,1,tot,bh[top[x]],bh[x]);
113         x=fa[top[x]];
114     }
115     if(bh[x]>bh[y]) swap(x,y);
116     res+=querysum(1,1,tot,bh[x],bh[y]);
117     return res;
118 }
119 
120 inline int querymax(int u,int L,int R,int l,int r)
121 {
122     if(l<=L&&R<=r) return mx[u];
123     int MID=(L+R)>>1,res=-INF;
124     if(l<=MID) res=max(res,querymax(u<<1,L,MID,l,r));
125     if(MID<r) res=max(res,querymax(u<<1|1,MID+1,R,l,r));
126     return res;
127 }
128 
129 inline int getmax(int x,int y)
130 {
131     int res=-INF;
132     while(top[x]!=top[y])
133     {
134         if(dep[top[x]]<dep[top[y]]) swap(x,y);
135         res=max(res,querymax(1,1,tot,bh[top[x]],bh[x]));
136         x=fa[top[x]];
137     }
138     if(bh[x]>bh[y]) swap(x,y);
139     
140     res=max(res,querymax(1,1,tot,bh[x],bh[y]));
141     return res; 
142 }
143 
144 inline void updata(int u,int L,int R,int pos,int sp)
145 {
146     if(L==R) {mx[u]=sum[u]=sp;return;}
147     int MID=(L+R)>>1;
148     if(pos<=MID) updata(u<<1,L,MID,pos,sp);
149     else updata(u<<1|1,MID+1,R,pos,sp);
150     pushup(u);
151 }
152 
153 inline void go()
154 {
155     scanf("%d",&qu);
156     char str[10];int a,b;
157     while(qu--)
158     {
159         scanf("%s%d%d",str,&a,&b);
160         if(str[1]=='S') printf("%d\n",getsum(a,b));
161         else if(str[1]=='M') printf("%d\n",getmax(a,b));
162         else updata(1,1,tot,bh[a],b);
163     }
164 }
165 
166 int main()
167 {
168     read();
169     go();
170     return 0;
171 } 

 

 

posted @ 2013-01-18 23:11  proverbs  阅读(827)  评论(0编辑  收藏  举报