比较基础的一道树链剖分的题 大概还是得说说思路

  树链剖分是将树剖成很多条链,比较常见的剖法是按儿子的size来剖分,剖分完后对于这课树的询问用线段树维护——比如求路径和的话——随着他们各自的链向上走,直至他们在同一条链上为止。比较像lca的方法,只不过这里是按链为单位,而且隔壁的SymenYang说可以用树链剖分做lca。。吓哭

  然后说说惨痛的调题经历:边表一定要开够啊! 不是n-1 而是2*(n-1)啊! 然后写变量时原始值和映射值要搞清楚啊! 不要搞错了! 还有就是下次求最小值一定看清下界是多少! 树的统计是-30000 ~ 30000 ,我果断naive 的写了一个初值为0!!! wa 0 就是这么痛苦! 还是too Young too Simple !

code :

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <algorithm>
  5 using namespace std;
  6 
  7 const int maxn = 50001;
  8 
  9 struct edge{
 10     int t; edge* next;
 11 }e[maxn*3], *head[maxn];int ne = 0;
 12 
 13 void addedge(int f, int t){
 14     e[ne].t = t; e[ne].next = head[f]; head[f] = e + ne ++;
 15 }
 16 
 17 int n; 
 18 int size[maxn],fa[maxn],dep[maxn],w[maxn],un[maxn],map[maxn];
 19 
 20 struct node{
 21     int smax, sum;
 22     node *l, *r;
 23 }tr[maxn * 3], *root; int trne = 0;
 24 
 25 node* build(int l, int r){
 26     node* x = tr + trne ++;
 27     if(l != r) {
 28         int mid = (l + r) >> 1;
 29         x-> l = build(l, mid);
 30         x-> r = build(mid + 1, r);
 31     }
 32     return x;
 33 }
 34 
 35 void update(node* x){
 36     if(x-> l) {
 37         x-> sum = x-> l-> sum + x-> r->sum;
 38         x-> smax = max(x-> l-> smax, x-> r-> smax);
 39     }
 40 }
 41 
 42 void insert(node* x, int l, int r, int pos, int v) {
 43     if(l == r) { x-> sum = v, x-> smax = v;}
 44     else{
 45         int mid = (l + r) >> 1;
 46         if(pos <= mid) insert(x-> l, l, mid, pos, v);
 47         else insert(x-> r, mid + 1, r, pos, v);
 48         update(x);
 49     }
 50 }
 51 
 52 int ask(node* x, int l, int r, int ls, int rs, int flag) {
 53     if(l == ls && r == rs) {
 54         if(flag == 0) return x-> smax;
 55         else return x-> sum;
 56     }
 57     else {
 58         int mid = (l + r) >> 1;
 59         if(rs <= mid) return ask(x-> l, l, mid, ls, rs, flag);
 60         else if(ls >= mid + 1) return ask(x-> r, mid + 1, r, ls, rs, flag);
 61         else {
 62             if(flag == 0) 
 63                 return max(ask(x->l, l, mid, ls, mid, flag), ask(x-> r, mid + 1, r, mid + 1, rs, flag));
 64             else 
 65                 return ask(x-> l, l, mid, ls, mid, flag) + ask(x-> r, mid + 1, r, mid + 1, rs, flag);
 66         }
 67     }
 68 }
 69 
 70 int cnt = 0;
 71 
 72 void size_cal(int x, int pre) {
 73     if(pre == -1) dep[x] = 1, fa[x] = x;
 74     else dep[x] = dep[pre] + 1, fa[x] = pre;
 75     
 76     size[x] = 1;
 77     for(edge* p = head[x]; p; p = p-> next) 
 78         if(dep[p-> t] == -1)size_cal(p-> t, x), size[x] += size[p-> t];
 79 }
 80 
 81 void divide(int x, int pre){
 82     if(pre == -1) un[x] = x;
 83     else un[x] = un[pre];
 84     map[x] = ++ cnt; insert(root, 1, n, map[x], w[x]);
 85     int tmax = -1, ts = -1;
 86     for(edge* p = head[x]; p; p = p-> next) {
 87         if(dep[p-> t] > dep[x] && size[p-> t] > tmax) tmax = size[p-> t], ts = p-> t;
 88     }
 89     if(ts != -1) {
 90         divide(ts, x);
 91         for(edge* p = head[x]; p; p = p-> next) {
 92             if(dep[p-> t] > dep[x] && p-> t != ts) divide(p-> t, -1);
 93         }
 94     }
 95 }
 96 
 97 void read() {
 98     memset(dep,-1,sizeof(dep));
 99     scanf("%d", &n);
100     root = build(1, n);
101     for(int i = 1; i <= n - 1; i++) {
102         int f, t;
103         scanf("%d%d", &f, &t);
104         addedge(f, t), addedge(t, f);
105     }
106     for(int i = 1; i <= n; ++ i) {
107         scanf("%d", &w[i]);
108     }
109     size_cal(1, -1);divide(1, -1);
110 }
111 
112 int sovmax(int a, int b) {
113     int ans = -30001; int ls, rs;
114     while(un[a] != un[b]) {
115         if(dep[un[a]] > dep[un[b]]) {
116             ls = map[a]; rs = map[un[a]];
117             if(ls > rs) swap(ls, rs);
118             ans = max(ans, ask(root, 1, n, ls, rs, 0));
119             a = fa[un[a]];
120         }
121         else {
122             ls = map[b]; rs = map[un[b]];
123             if(ls > rs) swap(ls, rs);
124             ans = max(ans, ask(root, 1, n, ls, rs, 0));
125             b = fa[un[b]];
126             }
127     }
128     ls = map[a], rs = map[b];
129     if(ls > rs) swap(ls,rs);
130     ans = max(ans, ask(root, 1, n, ls, rs, 0));
131     return ans;
132 }
133 
134 int sovsum(int a,int b) {
135     int ans = 0; int ls, rs;
136     while(un[a] != un[b]) {
137         if(dep[un[a]] > dep[un[b]]) {
138             ls = map[a], rs = map[un[a]];
139             if(ls > rs) swap(ls, rs);
140             ans += ask(root, 1, n, ls, rs, 1);
141             a = fa[un[a]];
142         }
143         else {
144             ls = map[b]; rs = map[un[b]];
145             if(ls > rs) swap(ls, rs);
146             ans += ask(root, 1, n, ls, rs, 1);
147             b = fa[un[b]];
148         }
149     }
150     ls = map[a], rs = map[b];
151     if(ls > rs) swap(ls, rs);
152     ans += ask(root, 1, n, ls, rs, 1);
153     return ans;
154 }
155 
156 void sov() {
157     int m;
158     scanf("%d", &m);
159     while(m --) {
160         char s[10]; int ls, rs;
161         scanf("%s %d%d", s + 1, &ls, &rs);
162         if(s[2] == 'M') printf("%d\n", sovmax(ls, rs));
163         if(s[2] == 'S') printf("%d\n", sovsum(ls, rs));
164         if(s[2] == 'H') insert(root, 1, n, map[ls], rs);
165     }
166 }
167 
168 int main(){
169     read();sov(); 
170     return 0;
171 }
View Code