树上前缀和
求前缀和,我们一般默认\(w[i]\) 表示该节点到根节点的权值和,采用自顶向下更新。
点权前缀
x到y路径上的和为:
\[sum_x + sum_y - sum_{lca}- sum_{fa_{lca}}
\]
边权前缀
x到y的路径上的和为:
\[sum_x + sum_y - 2 \times sum_{lca}
\]
树上差分学习
树上差分可以理解为对树上某一段路径进行差分操作,这里的路径可以类比一维数组的区间进行理解。 例如在对树上的一些路径进行频繁操作,并且询问某条边或某个点在经过操作后的值的时候,就可以运用树上差分思想了。
树上差分通常会结合 树基础 和 最近共公共祖先进行考察。 树差分又分为 点差分 和 边差分, 在实现上会稍有不同。
点差分
\[d_s \leftarrow d_s + 1 \\
d_{lca} \leftarrow d_{lca} - 1 \\
d_t \leftarrow d_t + 1 \\
d_{f(lca)} \leftarrow d_{f(lca)} - 1
\]
注意树上差分+前缀我们需要自底向上更新,\(w[i]\) 表示以\(i\)节点为根的权值和
例题:Max Flow P
#include <bits/stdc++.h>
using i64 = long long;
const int N = 5e5 + 10, M = N;
int n, k;
int h[N], e[M], ne[M], w[M], idx;
int depth[N], fa[N][17], sum[N], par[N], ans;
void add(int a, int b) {
ne[idx] = h[a], h[a] = idx, e[idx ++] = b;
}
void dfs(int u, int father, int dep) {
depth[u] = dep;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == father) continue;
fa[v][0] = u;
for (int k = 1; k <= 16; k ++) {
fa[v][k] = fa[fa[v][k - 1]][k - 1];
}
dfs(v, u, dep + 1);
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) std::swap(a, b);
for (int k = 16; k >= 0; k --) {
if (depth[fa[a][k]] >= depth[b]) {
a = fa[a][k];
}
}
if (a == b) return a;
for (int k = 16; k >= 0; k --) {
if (fa[a][k] != fa[b][k]) {
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int dfs2(int u, int fa) {
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == fa) continue;
int s = dfs2(v, u);
sum[u] += s;
}
ans = std::max(ans, sum[u]);
return sum[u];
}
int main() {
memset(h, -1, sizeof h);
std::cin >> n >> k;
for (int i = 1; i < n; i ++) {
int x, y;
std::cin >> x >> y;
add(x, y);
add(y, x);
}
depth[0] = 0;
dfs(1, -1, 1);
while (k --) {
int x, y;
std::cin >> x >> y;
int anc = lca(x, y);
sum[x] += 1;
sum[y] += 1;
sum[anc] -= 1;
if (fa[anc][0])
sum[fa[anc][0]] -= 1;
}
dfs2(1, -1);
std::cout << ans;
}
边差分
\[d_s \leftarrow d_s + 1 \\
d_t \leftarrow d_t + 1 \\
d_{lca} \leftarrow d_{lca} - 2
\]
对于边差分,由于树是拓扑结构,我们用向下的一条边所指向的点来表达边,这样很简单就能推出上述公式。
例题: 闇の連鎖
#include <bits/stdc++.h>
using i64 = long long;
const int N = 1e5 + 10, M = 2 * N;
int n, m;
int h[N], e[M], ne[M], w[M], idx;
int fa[N][18], depth[N], sum[M];
i64 ans;
void add(int a, int b) {
ne[idx] = h[a], h[a] = idx, e[idx ++] = b;
}
void dfs(int u, int father, int dep) {
depth[u] = dep;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == father) continue;
depth[v] = depth[u] + 1;
fa[v][0] = u;
for (int k = 1; k <= 17; k ++) {
fa[v][k] = fa[fa[v][k - 1]][k - 1];
}
dfs(v, u, dep + 1);
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) std::swap(a, b);
for (int k = 17; k >= 0; k --) {
if (depth[fa[a][k]] >= depth[b]) {
a = fa[a][k];
}
}
if (a == b) return b;
for (int k = 17; k >= 0; k --) {
if (fa[a][k] != fa[b][k]) {
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int dfs2(int u, int fa) {
int res = sum[u];
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == fa) continue;
int s = dfs2(v, u);
if (s == 0) ans += m;
else if (s == 1) ans ++;
res += s;
}
return res;
}
int main() {
memset(h, -1, sizeof h);
std::cin >> n >> m;
for (int i = 1; i <= n - 1; i ++) {
int x, y;
std::cin >> x >> y;
add(x, y);
add(y, x);
}
depth[0] = 0;
dfs(1, -1, 1);
for (int i = 1; i <= m; i ++) {
int x, y;
std::cin >> x >> y;
sum[x] += 1;
sum[y] += 1;
sum[lca(x, y)] -= 2;
}
dfs2(1, -1);
std::cout << ans;
}
例题2:松鼠的新家
#include <bits/stdc++.h>
using i64 = long long;
const int N = 3e5 + 10, M = N * 2;
int n, m;
int h[N], e[M], ne[M], idx;
int fa[N][19], depth[N], sum[M];
int a[N];
i64 ans;
void add(int a, int b) {
ne[idx] = h[a], h[a] = idx, e[idx ++] = b;
}
void dfs(int u, int father, int dep) {
depth[u] = dep;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == father) continue;
depth[v] = depth[u] + 1;
fa[v][0] = u;
for (int k = 1; k <= 18; k ++) {
fa[v][k] = fa[fa[v][k - 1]][k - 1];
}
dfs(v, u, dep + 1);
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) std::swap(a, b);
for (int k = 18; k >= 0; k --) {
if (depth[fa[a][k]] >= depth[b]) {
a = fa[a][k];
}
}
if (a == b) return b;
for (int k = 18; k >= 0; k --) {
if (fa[a][k] != fa[b][k]) {
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int dfs2(int u, int fa) {
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == fa) continue;
sum[u] += dfs2(v, u);
}
return sum[u];
}
int main() {
memset(h, -1, sizeof h);
std::cin >> n;
for (int i = 1; i <= n; i ++) {
std::cin >> a[i];
}
for (int i = 1; i <= n - 1; i ++) {
int x, y;
std::cin >> x >> y;
add(x, y);
add(y, x);
}
depth[0] = 0;
dfs(1, -1, 1);
for (int i = 2; i <= n; i ++) {
int x = a[i - 1], y = a[i];
sum[x] += 1;
sum[y] += 1;
int anc = lca(x, y);
sum[anc] -= 1;
sum[fa[anc][0]] -= 1;
}
dfs2(1, -1);
sum[a[1]] ++;
for (int i = 1; i <= n; i ++) {
std::cout << sum[i] - 1 << "\n";
}
}