Atcoder ABC133F - Colorful Tree 题解 主席树 + LCA

题目链接:https://atcoder.jp/contests/abc133/tasks/abc133_f

题目大意:

有一棵树,顶点编号从 \(1\)\(N\)

这棵树中第 \(i\) 条边连接着顶点 \(a_i\) 和顶点 \(b_i\),其颜色和长度分别为 \(c_i\)\(d_i\)

这里每条边的颜色用介于 \(1\)\(N-1\)(包括边界值)之间的整数表示。相同的整数代表相同的颜色,不同的整数代表不同的颜色。

回答以下 \(Q\) 个查询:

查询 \(j\) (\(1 \leq j \leq Q\)): 假设颜色为 \(x_j\) 的边的长度都改变为 \(y_j\),求顶点 \(u_j\) 和顶点 \(v_j\) 之间的距离。(边的长度的改变不会影响后续的查询。)

解题思路完全参考自 Minecraft万岁 大佬的博客:https://www.luogu.com.cn/article/aw4dp6vd

我写代码的时候碰到一个比较脑抽的问题是:习惯用 d 表示深度,但是这里 edge 里也有一个 d,调了半天,然后把深度改成 depth 表示了囧

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5, maxm = 2e6 + 5;

int rt[maxn], idx, ls[maxm], rs[maxm];

int tcnt[maxm], tsum[maxm];

void push_up(int u) {
    tcnt[u] = tcnt[ls[u]] + tcnt[rs[u]];
    tsum[u] = tsum[ls[u]] + tsum[rs[u]];
}

// 多了一条颜色为 c 长度为 d 的边
void add(int c, int d, int l, int r, int u, int uu) {
    tcnt[u] = tcnt[uu];
    tsum[u] = tsum[uu];
    ls[u] = ls[uu];
    rs[u] = rs[uu];
    if (l == r) {
        tcnt[u]++;
        tsum[u] += d;
        return;
    }
    int mid = (l + r) / 2;
    if (c <= mid) {
        ls[u] = ++idx;
        add(c, d, l, mid, ls[u], ls[uu]);
    }
    else {
        rs[u] = ++idx;
        add(c, d, mid+1, r, rs[u], rs[uu]);
    }
    push_up(u);
}

pair<int, int> query(int c, int l, int r, int u) {
    if (!u) return {0, 0};
    if (l == r) {
        return { tcnt[u], tsum[u] };
    }
    int mid = (l + r) / 2;
    return (c <= mid) ? query(c, l, mid, ls[u]) : query(c, mid+1, r, rs[u]);
}

int fa[maxn][17], dis[maxn][17], dep[maxn];
int n, Q;
struct Edge { int v, c, d; };
vector<Edge> g[maxn];

void dfs(int u, int p, int depth) {
    fa[u][0] = p;
    dep[u] = depth;
    for (auto e : g[u]) {
        int v = e.v, c = e.c, d = e.d;
        if (v == p) continue;
        rt[v] = ++idx;
        add(c, d, 1, n-1, rt[v], rt[u]);
        dis[v][0] = d;
        dfs(v, u, depth+1);
    }
}

void check_dfs(int u, int p) {
    for (auto e : g[u]) {
        int v = e.v;
        if (v != p)
            assert(dep[v] == dep[u] + 1),
            check_dfs(v, u);
    }
}

int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 16; i >= 0; i--)
        if (dep[ fa[x][i] ] >= dep[y])
            x = fa[x][i];
    if (x == y) return x;
    for (int i = 16; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

// 计算从节点 x 到它的祖先节点 z 的所有边的长度总和
int get_dis(int x, int z) {
    int res = 0;
    for (int i = 16; i >= 0; i--) {
        if (dep[ fa[x][i] ] >= dep[z]) {
            res += dis[x][i];
            x = fa[x][i];
        }
    }
    return res;
}

int cal(int c, int w, int x, int y) {
    int z = lca(x, y);
    int cnt = 0, sum = 0;
    auto pi = query(c, 1, n-1, rt[x]);
    cnt += pi.first;
    sum += pi.second;
    pi = query(c, 1, n-1, rt[y]);
    cnt += pi.first;
    sum += pi.second;
    pi = query(c, 1, n-1, rt[z]);
    cnt -= 2 * pi.first;
    sum -= 2 * pi.second;
    int dis1 = get_dis(x, z), dis2 = get_dis(y, z);
    return dis1 + dis2 - sum + cnt * w;
}

int main() {
    scanf("%d%d", &n, &Q);
    for (int i = 1; i < n; i++) {
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        g[a].push_back({ b, c, d });
        g[b].push_back({ a, c, d });
    }
    rt[1] = ++idx;
    dfs(1, 0, 1);
    check_dfs(1, 1);
    for (int j = 1; j <= 16; j++) {
        for (int i = 1; i <= n; i++) {
            fa[i][j] = fa[ fa[i][j-1] ][j-1];
            dis[i][j] = dis[i][j-1] + dis[ fa[i][j-1] ][j-1];
        }
    }
    while (Q--) {
        int c, w, x, y;
        scanf("%d%d%d%d", &c, &w, &x, &y);
        printf("%d\n", cal(c, w, x, y));
    }
    return 0;
}
posted @ 2025-03-29 00:20  quanjun  阅读(13)  评论(0)    收藏  举报