[Bzoj1036][ZJOI2008]树的统计Count(树链剖分)

题目链接:https://www.lydsy.com/JudgeOnline/problem.php?id=1036

树链剖分的板子题（单点修改 + 路径最大值 / 路径和查询），在bzoj上做到就当复习啦。

  1 #include<bits/stdc++.h>
  2 #define lson l,mid,i<<1
  3 #define rson mid+1,r,i<<1|1
  4 using namespace std;
  5 typedef long long ll;
  6 const int maxn = 200010;
  7 const int INF = 2e9;
  8 struct node {
  9     int s, e, next;
 10 }edge[maxn * 2];
 11 int n, m;
 12 int son[maxn], top[maxn], tid[maxn], fat[maxn], siz[maxn], dep[maxn], rak[maxn];
 13 int head[maxn], len, dfx;
 14 //siz保存以i为根的子树节点个数,top保存i节点所在链的顶端节点,son保存i节点的重儿子,fat保存i节点的父亲节点\
 15 dep保存i节点的深度(根为1),,tid保存i节点dfs后的新编号,rak保存新编号i对应的节点(rak[i]=j,tid[j]=i)。
 16 void init() {
 17     memset(head, -1, sizeof(head));
 18     len = 0, dfx = 0;
 19 }
 20 void add(int s, int e) {//邻接表存值
 21     edge[len].s = s;
 22     edge[len].e = e;
 23     edge[len].next = head[s];
 24     head[s] = len++;
 25 }
 26 void dfs1(int x, int fa, int d) {
 27     dep[x] = d, siz[x] = 1, fat[x] = fa, son[x] = -1;
 28     for (int i = head[x]; i != -1; i = edge[i].next) {
 29         int y = edge[i].e;
 30         if (y == fa)
 31             continue;
 32         dfs1(y, x, d + 1);
 33         siz[x] += siz[y];
 34         if (son[x] == -1 || siz[y] > siz[son[x]])
 35             son[x] = y;
 36     }
 37 }
 38 void dfs2(int x, int c) {
 39     top[x] = c;
 40     tid[x] = ++dfx;
 41     rak[dfx] = x;
 42     if (son[x] == -1)
 43         return;
 44     dfs2(son[x], c);
 45     for (int i = head[x]; i != -1; i = edge[i].next) {
 46         int y = edge[i].e;
 47         if (y != fat[x] && y != son[x])
 48             dfs2(y, y);
 49     }
 50 }
 51 ll a[maxn];
 52 ll sum[maxn * 4];
 53 ll Max[maxn * 4];
 54 void up(int i) {
 55     sum[i] = sum[i << 1] + sum[i << 1 | 1];
 56     Max[i] = max(Max[i << 1], Max[i << 1 | 1]);
 57 }
 58 void build(int l, int r, int i) {
 59     if (l == r) {
 60         sum[i] = a[rak[l]];
 61         Max[i] = a[rak[l]];
 62         return;
 63     }
 64     int mid = l + r >> 1;
 65     build(lson);
 66     build(rson);
 67     up(i);
 68 }
 69 void update(int t, int k, int l, int r, int i) {
 70     if (l == r) {
 71         sum[i] = k;
 72         Max[i] = k;
 73         return;
 74     }
 75     int mid = l + r >> 1;
 76     if (t <= mid)
 77         update(t, k, lson);
 78     else
 79         update(t, k, rson);
 80     up(i);
 81 }
 82 ll queryM(int L, int R, int l, int r, int i) {
 83     if (L <= l && r <= R)
 84         return Max[i];
 85     int mid = l + r >> 1;
 86     ll MAX = -INF;
 87     if (L <= mid)
 88         MAX = max(MAX, queryM(L, R, lson));
 89     if (R > mid)
 90         MAX = max(MAX, queryM(L, R, rson));
 91     return MAX;
 92 }
 93 ll queryS(int L, int R, int l, int r, int i) {
 94     if (L <= l && r <= R)
 95         return sum[i];
 96     int mid = l + r >> 1;
 97     ll ans = 0;
 98     if (L <= mid)
 99         ans += queryS(L, R, lson);
100     if (R > mid)
101         ans += queryS(L, R, rson);
102     return ans;
103 }
104 ll solve(int x, int y, int flg) {
105     ll ans;
106     if (flg)
107         ans = -INF;
108     else
109         ans = 0;
110     while (top[x] != top[y]) {
111         if (dep[top[x]] < dep[top[y]])
112             swap(x, y);
113         if (flg)
114             ans = max(ans, queryM(tid[top[x]], tid[x], 1, n, 1));
115         else
116             ans += queryS(tid[top[x]], tid[x], 1, n, 1);
117         x = fat[top[x]];
118     }
119     if (dep[x] < dep[y])
120         swap(x, y);
121     if (flg)
122         ans = max(ans, queryM(tid[y], tid[x], 1, n, 1));
123     else
124         ans += queryS(tid[y], tid[x], 1, n, 1);
125     return ans;
126 }
127 int main() {
128     while (scanf("%d", &n) != EOF) {
129         init();
130         int x, y;
131         for (int i = 0; i < n - 1; i++) {
132             scanf("%d%d", &x, &y);
133             add(x, y);
134             add(y, x);
135         }
136         for (int i = 1; i <= n; i++)
137             scanf("%lld", &a[i]);
138         dfs1(1, 0, 0);
139         dfs2(1, 1);
140         build(1, n, 1);
141         scanf("%d", &m);
142         while (m--) {
143             char s[10];
144             scanf("%s", s);
145             if (strcmp(s, "QMAX") == 0) {
146                 scanf("%d%d", &x, &y);
147                 printf("%lld\n", solve(x, y, 1));
148             }
149             else if (strcmp(s, "QSUM") == 0) {
150                 scanf("%d%d", &x, &y);
151                 printf("%lld\n", solve(x, y, 0));
152             }
153             else {
154                 scanf("%d%d", &x, &y);
155                 update(tid[x], y, 1, n, 1);
156             }
157         }
158     }
159 }

 

posted @ 2019-07-01 19:16  祈梦生  阅读(155)  评论(0)  编辑  收藏  举报