树上路径(树链剖分)
来源:https://ac.nowcoder.com/acm/contest/22131/C
时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述
给出一个n个点的树,1号节点为根节点,每个点有一个权值
你需要支持以下操作
1.将以u为根的子树内节点(包括u)的权值加val
你需要支持以下操作
1.将以u为根的子树内节点(包括u)的权值加val
2.将(u, v)路径上的节点权值加val
3.询问(u, v)路径上节点的权值两两相乘的和
输入描述:
第一行两个整数n, m,表示树的节点个数以及操作个数
接下来一行n个数,表示每个节点的权值
接下来n - 1行,每行两个整数(u, v),表示(u, v)之间有边
接下来m行
开始有一个数opt,表示操作类型
若opt = 1,接下来两个整数表示u, val
若opt = 2,接下来三个整数表示(u, v), val
若opt = 3,接下来两个整数表示(u, v)
含义均如题所示
输出描述:
对于每个第三种操作,输出一个数表示答案,对10^9+7取模
示例1
输入
3 8
5 3 1
1 2
1 3
3 1 2
3 1 3
3 2 3
1 1 2
2 1 3 2
3 1 2
3 1 3
3 2 3
输出
15
5
23
45
45
115
$\begin{pmatrix}
1 & 2 & 3 & ... & n\\
p1 & p2 & p3 & ... & pn\\
\end{pmatrix}
\\22$
第一个和第二个操作板子就能解决, 第三个需要一些推导
有
$(x_1+x_2+x_3+...+x_n) ^ 2 = (x_1^2+x_2^2+x_3^2+...+x_i^2+...+x_n^2+...+x_1x_2+x_1x_3+...+x_ix_j+...+x_{n-1}x_n)$
$(\sum_{i = l}^r x_i)^2 - \sum_{i = l}^r x_i^2 = \sum_{i = l}^r\sum_{j = i \wedge j \neq i}^r 2x_ix_j$
后面的$\sum_{i = l}^r\sum_{j = i \wedge j \neq i}^r 2x_ix_j$即为我们想要的答案
于是我们可以维护区间元素的平方和跟一般和, sum1数组表示一般和, sum2数组表示平方和
但我们在维护的过程中会发现, 平方和似乎不是那么好维护, 因为题目里给定的操作还要给每个元素加上一个数值。
这时候就要思考lazy标记的作用了。
$\sum_{l}^r(x+b)^2 \\= \sum_{l}^r(x^2+2bx+b^2) \\= \sum_{l}^rx^2 + 2b\sum_{l}^rx + \sum_{l}^rb^2 \\= \sum_{l}^rx^2 +2 b\sum_{l}^rx + (r - l + 1)b^2$
至此, 结果就显而易见了
但是要注意模mod的时候不能使用除法, 而是要用逆元替代(千万要注意)
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 #define IOS ios::sync_with_stdio(false), cin.tie(0), cout.tie(0) 7 using ll = long long; 8 9 constexpr int MAXN = 1e5 + 3, mod = 1e9 + 7; 10 11 int h[MAXN], e[MAXN << 1], ne[MAXN << 1], w[MAXN << 2], wt[MAXN << 2], Size[MAXN], mx[MAXN], idx, root; 12 int fa[MAXN], deep[MAXN], son[MAXN], id[MAXN], top[MAXN], n, cnt; 13 ll sum1[MAXN << 2], sum2[MAXN << 2], lazy[MAXN << 2], nid, inv2; 14 ll qpow(ll x, int n) 15 { 16 ll ans = 1; 17 while (n) 18 { 19 if (n & 1) 20 { 21 ans = ans * x % mod; 22 } 23 x = x * x % mod; 24 n >>= 1; 25 } 26 return ans; 27 } 28 void addedge(int u, int v, int c = 0) 29 { 30 ++idx; 31 e[idx] = v; 32 ne[idx] = h[u]; 33 h[u] = idx; 34 } 35 36 void pushup(int rt) 37 { 38 sum1[rt] = (sum1[rt << 1] + sum1[rt << 1 | 1]) % mod; 39 sum2[rt] = (sum2[rt << 1] + sum2[rt << 1 | 1]) % mod; 40 } 41 void unionLazy(int rt) 42 { 43 lazy[rt << 1] += lazy[rt]; 44 lazy[rt << 1 | 1] += lazy[rt]; 45 } 46 void calLazy(int rt, int len) 47 { 48 (sum2[rt << 1] += 2 * lazy[rt] * sum1[rt << 1] + (len - (len >> 1)) * lazy[rt] * lazy[rt]) %= mod; 49 (sum1[rt << 1] += (len - (len >> 1)) * lazy[rt]) %= mod; 50 51 (sum2[rt << 1 | 1] += 2 * lazy[rt] % mod * sum1[rt << 1 | 1] % mod + (len >> 1) * lazy[rt] * lazy[rt]) %= mod; 52 (sum1[rt << 1 | 1] += (len >> 1) * lazy[rt]) %= mod; 53 } 54 55 void pushdown(int rt, int len) 56 { 57 if (lazy[rt] == 0) 58 return; 59 60 calLazy(rt, len); 61 unionLazy(rt); 62 63 lazy[rt] = 0; 64 } 65 66 void build(int l, int r, int rt) 67 { 68 lazy[rt] = 0; 69 if (l == r) 70 { 71 sum1[rt] = wt[l]; 72 sum2[rt] = wt[l] * wt[l]; 73 return; 74 } 75 int mid = l + r >> 1; 76 build(l, mid, rt << 1); 77 build(mid + 1, r, rt << 1 | 1); 78 pushup(rt); 79 } 80 81 ll query(int a, int b, int op, int l, int r, int rt) 82 { 83 if (a <= l && r <= b) 84 return op == 1 ? sum1[rt] : sum2[rt]; 85 pushdown(rt, r - l + 1); 86 int mid = l + r >> 1; 87 ll ans = 0; 88 if (a <= mid) 89 { 90 ans = (ans + query(a, b, op, l, mid, rt << 1)) % mod; 91 } 92 if (b > mid) 93 { 94 ans = (ans + query(a, b, op, mid + 1, r, rt << 1 | 1)) % mod; 95 } 96 return ans; 97 } 98 99 void update(int a, int b, ll c, int l, int r, int rt) 100 { 101 102 if (a <= l && r <= b) 103 { 104 (sum2[rt] += 2 * c * sum1[rt] + (r - l + 1) * c * c) %= mod; 105 (sum1[rt] += (r - l + 1) * c) %= mod; 106 lazy[rt] += c; 107 return; 108 } 109 110 pushdown(rt, r - l + 1); 111 int mid = l + r >> 1; 112 if (a <= mid) 113 update(a, b, c, l, mid, rt << 1); 114 if (b > mid) 115 update(a, b, c, mid + 1, r, rt << 1 | 1); 116 pushup(rt); 117 } 118 119 ll pathquery(int x, int y) 120 { 121 ll ans1 = 0, ans2 = 0; 122 while (top[x] != top[y]) 123 { 124 if (deep[top[x]] < deep[top[y]]) 125 swap(x, y); 126 ans1 = (ans1 + query(id[top[x]], id[x], 1, 1, n, 1)) % mod; 127 ans2 = (ans2 + query(id[top[x]], id[x], 2, 1, n, 1)) % mod; 128 x = fa[top[x]]; 129 } 130 131 if (deep[x] > deep[y]) 132 swap(x, y); 133 ans1 = (ans1 + query(id[x], id[y], 1, 1, n, 1)) % mod; 134 ans2 = (ans2 + query(id[x], id[y], 2, 1, n, 1)) % mod; 135 return (ans1 * ans1 % mod - ans2 + mod) * inv2 % mod; 136 } 137 138 void lcaadd(int x, int y, ll c) 139 { 140 while (top[x] != top[y]) 141 { 142 if (deep[top[x]] < deep[top[y]]) 143 swap(x, y); 144 update(id[top[x]], id[x], c, 1, n, 1); 145 146 x = fa[top[x]]; 147 } 148 if (deep[x] > deep[y]) 149 swap(x, y); 150 update(id[x], id[y], c, 1, n, 1); 151 } 152 153 void sonadd(int x, ll c) 154 { 155 update(id[x], id[x] + Size[x] - 1, c, 1, n, 1); 156 } 157 158 ll sonquery(int x, int op) 159 { 160 return query(id[x], id[x] + Size[x] - 1, op, 1, n, 1) % mod; 161 } 162 163 void dfs1(int x, int f, int dep) 164 { 165 deep[x] = dep; 166 fa[x] = f; 167 Size[x] = 1; 168 int maxson = -1; 169 for (int i = h[x]; i; i = ne[i]) 170 { 171 int y = e[i]; 172 if (y != f) 173 { 174 dfs1(y, x, dep + 1); 175 176 Size[x] += Size[y]; 177 if (maxson < Size[y]) 178 { 179 son[x] = y; 180 maxson = Size[y]; 181 } 182 } 183 } 184 } 185 186 void dfs2(int x, int topf) 187 { 188 id[x] = ++cnt; 189 wt[cnt] = w[x]; //必须根据访问顺序另开一个数组保存, 否则初始化答案会出错 190 top[x] = topf; 191 if (!son[x]) 192 return; 193 dfs2(son[x], topf); 194 for (int i = h[x]; i; i = ne[i]) 195 { 196 int y = e[i]; 197 if (y != fa[x] && y != son[x]) 198 dfs2(y, y); 199 } 200 } 201 202 int main() 203 { 204 IOS; 205 206 inv2 = qpow(2, mod - 2); 207 int u, v, c, k, m, rt = 1; 208 int x, y, z, op, val; 209 210 cin >> n >> m; 211 for (int i = 1; i <= n; ++i) 212 cin >> w[i]; 213 214 for (int i = 1; i < n; ++i) 215 { 216 cin >> u >> v; 217 addedge(u, v); 218 addedge(v, u); 219 } 220 dfs1(rt, 0, 1); 221 dfs2(rt, rt); 222 build(1, n, 1); 223 224 while (m--) 225 { 226 cin >> op >> x >> y; 227 if (op == 1) 228 { 229 sonadd(x, y); 230 } 231 else if (op == 2) 232 { 233 cin >> val; 234 lcaadd(x, y, val); 235 } 236 else 237 { 238 cout << pathquery(x, y) << '\n'; 239 } 240 } 241 242 return 0; 243 }

浙公网安备 33010602011771号