Atcoder ABC133F - Colorful Tree 题解 主席树 + LCA
题目链接:https://atcoder.jp/contests/abc133/tasks/abc133_f
题目大意:
有一棵树,顶点编号从 \(1\) 到 \(N\)。
这棵树中第 \(i\) 条边连接着顶点 \(a_i\) 和顶点 \(b_i\),其颜色和长度分别为 \(c_i\) 和 \(d_i\)。
这里每条边的颜色用介于 \(1\) 和 \(N-1\)(包括边界值)之间的整数表示。相同的整数代表相同的颜色,不同的整数代表不同的颜色。
回答以下 \(Q\) 个查询:
查询 \(j\) (\(1 \leq j \leq Q\)): 假设颜色为 \(x_j\) 的边的长度都改变为 \(y_j\),求顶点 \(u_j\) 和顶点 \(v_j\) 之间的距离。(边的长度的改变不会影响后续的查询。)
解题思路完全参考自 Minecraft万岁 大佬的博客:https://www.luogu.com.cn/article/aw4dp6vd
我写代码的时候碰到一个比较脑抽的问题是:习惯用 d 表示深度,但是这里 edge 里也有一个 d,调了半天,然后把深度改成 depth 表示了囧
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5, maxm = 2e6 + 5;
int rt[maxn], idx, ls[maxm], rs[maxm];
int tcnt[maxm], tsum[maxm];
void push_up(int u) {
tcnt[u] = tcnt[ls[u]] + tcnt[rs[u]];
tsum[u] = tsum[ls[u]] + tsum[rs[u]];
}
// 多了一条颜色为 c 长度为 d 的边
void add(int c, int d, int l, int r, int u, int uu) {
tcnt[u] = tcnt[uu];
tsum[u] = tsum[uu];
ls[u] = ls[uu];
rs[u] = rs[uu];
if (l == r) {
tcnt[u]++;
tsum[u] += d;
return;
}
int mid = (l + r) / 2;
if (c <= mid) {
ls[u] = ++idx;
add(c, d, l, mid, ls[u], ls[uu]);
}
else {
rs[u] = ++idx;
add(c, d, mid+1, r, rs[u], rs[uu]);
}
push_up(u);
}
pair<int, int> query(int c, int l, int r, int u) {
if (!u) return {0, 0};
if (l == r) {
return { tcnt[u], tsum[u] };
}
int mid = (l + r) / 2;
return (c <= mid) ? query(c, l, mid, ls[u]) : query(c, mid+1, r, rs[u]);
}
int fa[maxn][17], dis[maxn][17], dep[maxn];
int n, Q;
struct Edge { int v, c, d; };
vector<Edge> g[maxn];
void dfs(int u, int p, int depth) {
fa[u][0] = p;
dep[u] = depth;
for (auto e : g[u]) {
int v = e.v, c = e.c, d = e.d;
if (v == p) continue;
rt[v] = ++idx;
add(c, d, 1, n-1, rt[v], rt[u]);
dis[v][0] = d;
dfs(v, u, depth+1);
}
}
void check_dfs(int u, int p) {
for (auto e : g[u]) {
int v = e.v;
if (v != p)
assert(dep[v] == dep[u] + 1),
check_dfs(v, u);
}
}
int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 16; i >= 0; i--)
if (dep[ fa[x][i] ] >= dep[y])
x = fa[x][i];
if (x == y) return x;
for (int i = 16; i >= 0; i--)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
// 计算从节点 x 到它的祖先节点 z 的所有边的长度总和
int get_dis(int x, int z) {
int res = 0;
for (int i = 16; i >= 0; i--) {
if (dep[ fa[x][i] ] >= dep[z]) {
res += dis[x][i];
x = fa[x][i];
}
}
return res;
}
int cal(int c, int w, int x, int y) {
int z = lca(x, y);
int cnt = 0, sum = 0;
auto pi = query(c, 1, n-1, rt[x]);
cnt += pi.first;
sum += pi.second;
pi = query(c, 1, n-1, rt[y]);
cnt += pi.first;
sum += pi.second;
pi = query(c, 1, n-1, rt[z]);
cnt -= 2 * pi.first;
sum -= 2 * pi.second;
int dis1 = get_dis(x, z), dis2 = get_dis(y, z);
return dis1 + dis2 - sum + cnt * w;
}
int main() {
scanf("%d%d", &n, &Q);
for (int i = 1; i < n; i++) {
int a, b, c, d;
scanf("%d%d%d%d", &a, &b, &c, &d);
g[a].push_back({ b, c, d });
g[b].push_back({ a, c, d });
}
rt[1] = ++idx;
dfs(1, 0, 1);
check_dfs(1, 1);
for (int j = 1; j <= 16; j++) {
for (int i = 1; i <= n; i++) {
fa[i][j] = fa[ fa[i][j-1] ][j-1];
dis[i][j] = dis[i][j-1] + dis[ fa[i][j-1] ][j-1];
}
}
while (Q--) {
int c, w, x, y;
scanf("%d%d%d%d", &c, &w, &x, &y);
printf("%d\n", cal(c, w, x, y));
}
return 0;
}
浙公网安备 33010602011771号