【解题思路】

  直接上树剖套线段树/BIT即可。复杂度$O(n+q\log_2^2 n)$(线段树)或$O(n+q\log_2^3 n)$(BIT):每次路径询问最多经过$O(\log n)$条重链,线段树单次区间询问为$O(\log n)$,而BIT求区间最值单次为$O(\log^2 n)$,因此多一个$\log$。

【参考代码】

  树剖套BIT。(这个树剖好naive啊QAQ)

  #include <algorithm>
#include <cctype>
#include <cstdio>
// Inclusive counting loops over an int variable VAR:
// REP counts up from FROM to TO, PER counts down from FROM to TO.
// FROM is evaluated once; TO is re-evaluated before every iteration.
#define REP(VAR,FROM,TO) for(int VAR=(FROM);VAR<=(TO);VAR++)
#define PER(VAR,FROM,TO) for(int VAR=(FROM);VAR>=(TO);VAR--)
// Pop elements until empty(); works for any object with empty()/pop().
#define ClearStack(S) while(!S.empty()){S.pop();}
#define ClearQueue(Q) while(!Q.empty()){Q.pop();}
// Assign VALUE to ARR[FROM..TO]. The loop variable is named `i`, so VALUE
// and the bounds see that `i` if they mention it.
#define ClearArray(ARR,FROM,TO,VALUE) REP(i,FROM,TO){ARR[i]=VALUE;}
// Signed maxima: 2^15-1, 2^31-1 and 2^63-1.
#define maxint 32767
#define maxlongint 2147483647
#define maxint64 9223372036854775807ll
 /// Writes one space character to stdout.
inline void space()
{
    fputc(' ',stdout);
}
 /// Writes one newline character to stdout.
inline void enter()
{
    fputc('\n',stdout);
}
 /// Returns true iff the character `ptr` is '\n'.
/// (Despite its name, `ptr` is a character, not a pointer.)
inline bool eoln(char ptr)
{
    const char newline='\n';
    return newline==ptr;
}
 /// Returns true iff the character `ptr` is the NUL byte '\0'.
/// NOTE: this does not recognise the EOF value returned by getchar().
inline bool eof(char ptr)
{
    if(ptr!='\0')
        return false;
    return true;
}
 /// Reads the next decimal integer from stdin.
/// Characters before the first digit or sign are skipped, and a leading '+'
/// or '-' is consumed. Returns 0 if input ends before any digit is seen.
inline int getint()
{
    // Kept as int (not char) so EOF is distinguishable from a real byte and
    // every value passed to isdigit is in its valid domain.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)&&ch!='+'&&ch!='-')
        ch=getchar();
    bool impositive=ch=='-';
    if(ch=='+'||ch=='-')
        ch=getchar();
    int result=0;
    for(;isdigit(ch);ch=getchar())   // isdigit(EOF) is false
        result=result*10+(ch-'0');
    return impositive?-result:result;
}
 /// Reads the next token from stdin. A token is a maximal run of characters
/// that are neither whitespace nor '\0'.
/// Returns a new[]-allocated 256-byte buffer that the caller must delete[].
/// Tokens longer than 255 characters are truncated (the excess is read and
/// dropped). If input ends before a token starts, "" is returned.
inline char *getstr()
{
    const int capacity=256;
    char *result=new char[capacity];
    int length=0;
    // int (not char) so that EOF can be detected and isspace gets a valid value.
    int ch=getchar();
    // Skip separators: whitespace (which includes '\n') and '\0' bytes.
    while(ch!=EOF&&(isspace(ch)||ch=='\0'))
        ch=getchar();
    for(;ch!=EOF&&!isspace(ch)&&ch!='\0';ch=getchar())
        if(length<capacity-1)   // keep room for the terminator
            result[length++]=char(ch);
    result[length]='\0';
    return result;
}
 /// Prints the integer n in decimal to stdout, with a leading '-' when
/// negative, and returns the number of characters written.
template<typename integer> inline int write(integer n)
{
    // Digits are produced least-significant first. Each remainder is taken
    // from the (possibly negative) value itself; C++11 division truncates
    // toward zero, so it lies in [-9, 9]. n is never negated, so the most
    // negative value of a signed type cannot overflow.
    char digits[48];   // enough for any integer type up to 128 bits
    int count=0;
    integer now=n;
    do
    {
        int digit=int(now%10);
        if(digit<0)
            digit=-digit;
        digits[count++]=char('0'+digit);
        now/=10;
    }
    while(now!=0);
    bool impositive=n<0;
    if(impositive)
        putchar('-');
    for(int i=count-1;i>=0;i--)
        putchar(digits[i]);
    return count+(impositive?1:0);
}
 /// Raises `target` to `pattern` when `pattern` is strictly greater.
/// Returns true iff `target` was changed.
template<typename T> inline bool getmax(T &target,T pattern)
{
    const bool larger=pattern>target;
    if(larger)
        target=pattern;
    return larger;
}
 /// Lowers `target` to `pattern` when `pattern` is strictly smaller.
/// Returns true iff `target` was changed.
template<typename T> inline bool getmin(T &target,T pattern)
{
    const bool smaller=pattern<target;
    if(smaller)
        target=pattern;
    return smaller;
}
 /// Fenwick tree (binary indexed tree) over positions 1..size.
///
/// Supports point assignment (change) and point increment (increase).
/// Range sum costs O(log n); range max / range min cost O(log^2 n) per call.
/// Node i of the tree summarises positions (i-lowbit(i), i].
/// After clear(), every position holds T(0).
template<typename T> class BIT
{
    private:
        int size;
        T minusInf;   // query_max result for an empty range (set by clear)
        T plusInf;    // query_min result for an empty range (set by clear)
        T *sav;       // sav[i]: current value at position i
        T *savSum;    // savSum[i]: sum of sav over (i-lowbit(i), i]
        T *savMax;    // savMax[i]: max of sav over (i-lowbit(i), i]
        T *savMin;    // savMin[i]: min of sav over (i-lowbit(i), i]

        static int lowbit(int now)
        {
            return now&-now;
        }

        // Frees every array and returns to the empty (size 0) state.
        void release()
        {
            delete []sav;
            delete []savSum;
            delete []savMax;
            delete []savMin;
            sav=savSum=savMax=savMin=0;
            size=0;
        }

        // Pushes an update of position `point` into every node covering it.
        // sav[point] must already hold the new value; `delta` is how much it
        // changed. Nodes are visited in increasing index order, so the child
        // node that contains `point` is rebuilt before its parent.
        void refresh(int point,T delta)
        {
            for(int i=point;i<=size;i+=lowbit(i))
            {
                savSum[i]+=delta;
                savMax[i]=savMin[i]=sav[i];
                // Node i covers sav[i] plus the nodes i-1, i-2, i-4, ...
                for(int j=1;j<lowbit(i);j<<=1)
                {
                    savMax[i]=std::max(savMax[i],savMax[i-j]);
                    savMin[i]=std::min(savMin[i],savMin[i-j]);
                }
            }
        }

        // Copying would make two objects own (and delete) the same arrays,
        // so copy construction/assignment is declared private and undefined.
        BIT(const BIT &);
        BIT &operator=(const BIT &);

    public:
        BIT():size(0),minusInf(),plusInf(),sav(0),savSum(0),savMax(0),savMin(0)
        {
        }

        ~BIT()
        {
            release();
        }

        /// Resizes the tree to `length` positions, all holding T(0).
        /// nINF / INF are what query_max / query_min return for empty ranges.
        void clear(int length,T nINF=-2147483647,T INF=2147483647)
        {
            release();
            sav=new T[length+1];
            savSum=new T[length+1];
            savMax=new T[length+1];
            savMin=new T[length+1];
            size=length;
            minusInf=nINF;
            plusInf=INF;
            // Max/min nodes must agree with the stored zeros. Filling them
            // with nINF/INF instead would hide never-updated positions once
            // a covering node is rebuilt from such a child.
            std::fill(sav,sav+length+1,T(0));
            std::fill(savSum,savSum+length+1,T(0));
            std::fill(savMax,savMax+length+1,T(0));
            std::fill(savMin,savMin+length+1,T(0));
        }

        /// Adds delta to the value at position point (1 <= point <= size).
        void increase(int point,T delta)
        {
            sav[point]+=delta;
            refresh(point,delta);
        }

        /// Sets the value at position point (1 <= point <= size) to val.
        void change(int point,T val)
        {
            T delta=val-sav[point];
            sav[point]=val;
            refresh(point,delta);
        }

        /// Sum of positions 1..length (T(0) when length <= 0).
        T pre_sum(int length)
        {
            T result=T(0);
            for(int i=length;i>0;i-=lowbit(i))
                result+=savSum[i];
            return result;
        }

        /// Sum of positions left..right (requires right <= size).
        T query_sum(int left,int right)
        {
            return pre_sum(right)-pre_sum(left-1);
        }

        /// Maximum of positions left..right (requires right <= size).
        /// Returns the nINF given to clear() when the range is empty.
        T query_max(int left,int right)
        {
            if(left<1)
                left=1;
            if(left>right)
                return minusInf;
            T result=sav[right];
            for(int r=right-1;r>=left;)
                if(r-lowbit(r)+1>=left)   // node r lies entirely inside [left, r]
                {
                    result=std::max(result,savMax[r]);
                    r-=lowbit(r);
                }
                else
                {
                    result=std::max(result,sav[r]);
                    r--;
                }
            return result;
        }

        /// Minimum of positions left..right (requires right <= size).
        /// Returns the INF given to clear() when the range is empty.
        T query_min(int left,int right)
        {
            if(left<1)
                left=1;
            if(left>right)
                return plusInf;
            T result=sav[right];
            for(int r=right-1;r>=left;)
                if(r-lowbit(r)+1>=left)   // node r lies entirely inside [left, r]
                {
                    result=std::min(result,savMin[r]);
                    r-=lowbit(r);
                }
                else
                {
                    result=std::min(result,sav[r]);
                    r--;
                }
            return result;
        }
};
168 //=============================Header Template===============================
169 #include <cstring>
170 #include <vector>
171 using namespace std;
172 typedef vector<int> vecint;
173 BIT<int> bit[30010];
174 int depth[30010],w[30010],father[30010],weight[30010],wfa[30010],wson[30010],group[30010],position[30010];
175 vecint lines[30010];
176 int DFS(int now)
177 {
178     weight[now]=1;
179     for(vecint::iterator i=lines[now].begin();i!=lines[now].end();i++)
180     {
181         int p=*i;
182         if(!weight[p])
183         {
184             father[p]=now;
185             weight[now]+=DFS(p);
186         }
187     }
188     return weight[now];
189 }
190 void DFSdep(int now,int deep)
191 {
192     if(!now)
193         return;
194     depth[now]=deep;
195     for(vecint::iterator i=lines[now].begin();i!=lines[now].end();i++)
196     {
197         int p=*i;
198         if(p!=father[now]&&p!=wson[now])
199             DFSdep(p,deep+1);
200     }
201     DFSdep(wson[now],deep);
202 }
203 int main()
204 {
205     int n=getint();
206     memset(lines,0,sizeof(lines));
207     REP(i,2,n)
208     {
209         int u=getint(),v=getint();
210         lines[u].push_back(v);
211         lines[v].push_back(u);
212     }
213     REP(i,1,n)
214         w[i]=getint();
215     memset(weight,0,sizeof(weight));
216     father[0]=father[1]=0;
217     DFS(1);
218     memset(wson,0,sizeof(wson));
219     REP(i,1,n)
220     {
221         int maxer=0;
222         for(vecint::iterator j=lines[i].begin();j!=lines[i].end();j++)
223         {
224             int p=*j;
225             if(p!=father[i]&&getmax(maxer,weight[p]))
226                 wson[i]=p;
227         }
228     }
229     int cnt=0;
230     DFSdep(1,1);
231     REP(i,1,n)
232         if(wson[father[i]]!=i)
233         {
234             cnt++;
235             int place=0;
236             for(int j=i;j;j=wson[j])
237             {
238                 wfa[j]=i;
239                 group[j]=cnt;
240                 position[j]=++place;
241             }
242             bit[cnt].clear(place);
243             for(int j=i;j;j=wson[j])
244                 bit[cnt].change(position[j],w[j]);
245         }
246     int q=getint();
247     while(q--)
248     {
249         char *opt=getstr();
250         int u=getint(),v=getint();
251         if(!strcmp(opt,"CHANGE"))
252             bit[group[u]].change(position[u],v);
253         if(!strcmp(opt,"QSUM"))
254         {
255             int ans=0;
256             for(;depth[u]>depth[v];u=father[wfa[u]])
257             {
258                 int left=position[u],right=position[wfa[u]];
259                 if(left>right)
260                     swap(left,right);
261                 ans+=bit[group[u]].query_sum(left,right);
262             }
263             for(;depth[u]<depth[v];v=father[wfa[v]])
264             {
265                 int left=position[v],right=position[wfa[v]];
266                 if(left>right)
267                     swap(left,right);
268                 ans+=bit[group[v]].query_sum(left,right);
269             }
270             for(;wfa[u]!=wfa[v];u=father[wfa[u]],v=father[wfa[v]])
271             {
272                 int left=position[u],right=position[wfa[u]];
273                 if(left>right)
274                     swap(left,right);
275                 ans+=bit[group[u]].query_sum(left,right);
276                 left=position[v];right=position[wfa[v]];
277                 if(left>right)
278                     swap(left,right);
279                 ans+=bit[group[v]].query_sum(left,right);
280             }
281             int left=position[u],right=position[v];
282             if(left>right)
283                 swap(left,right);
284             ans+=bit[group[u]].query_sum(left,right);
285             write(ans);
286             enter();
287         }
288         if(!strcmp(opt,"QMAX"))
289         {
290             int ans=-maxlongint;
291             for(;depth[u]>depth[v];u=father[wfa[u]])
292             {
293                 int left=position[u],right=position[wfa[u]];
294                 if(left>right)
295                     swap(left,right);
296                 getmax(ans,bit[group[u]].query_max(left,right));
297             }
298             for(;depth[u]<depth[v];v=father[wfa[v]])
299             {
300                 int left=position[v],right=position[wfa[v]];
301                 if(left>right)
302                     swap(left,right);
303                 getmax(ans,bit[group[v]].query_max(left,right));
304             }
305             for(;wfa[u]!=wfa[v];u=father[wfa[u]],v=father[wfa[v]])
306             {
307                 int left=position[u],right=position[wfa[u]];
308                 if(left>right)
309                     swap(left,right);
310                 getmax(ans,bit[group[u]].query_max(left,right));
311                 left=position[v];right=position[wfa[v]];
312                 if(left>right)
313                     swap(left,right);
314                 getmax(ans,bit[group[v]].query_max(left,right));
315             }
316             int left=position[u],right=position[v];
317             if(left>right)
318                 swap(left,right);
319             getmax(ans,bit[group[u]].query_max(left,right));
320             write(ans);
321             enter();
322         }
323     }
324     return 0;
325 }
View Code