题解:Luogu P9433 [NAPC-#1] Stage5 - Conveyors

题目概要

题目给定一棵带权树,其中有 \(k\) 个关键节点。对于每个查询 \((s, t)\),要求求出从 \(s\)\(t\) 的最短路径长度,且该路径必须经过所有关键节点至少一次。

解题思路

  1. 预处理
    利用 DFS 遍历整棵树,计算每个节点:

    • 到根的距离 \(dis_{\cdot}\)
    • 节点深度 \(dep_{\cdot}\)
    • 利用倍增方法构造 LCA 数组。
  2. 利用 LCA 求两点间距离
    对于树中任意两点 \(x\)\(y\),它们的最短路径必定经过它们的最近公共祖先 \(\operatorname{LCA}(x,y)\)。设:

    • \(dis_{x}\) 为从根到 \(x\) 的距离,
    • \(dis_{y}\) 为从根到 \(y\) 的距离,
    • \(dis_{ \operatorname{LCA}(x,y)}\) 为从根到最近公共祖先的距离。

    则从 \(x\)\(y\) 的距离为:

    \[dist(x,y)=dis_{x}+dis_{y}-2\times dis_{ \operatorname{LCA}(x,y)} \]

    解释
    \(x\)\(\operatorname{LCA}(x,y)\) 的距离为 \(dis_{x}-dis_{ \operatorname{LCA}(x,y)}\),从 \(y\)\(\operatorname{LCA}(x,y)\) 的距离为 \(dis_{y}-dis_{ \operatorname{LCA}(x,y)}\)。两部分相加后,再去掉公共部分 \(2\times dis_{ \operatorname{LCA}(x,y)}\),就得到了正确的两点间距离。

  3. 查询求解
    在预处理过程中还可能累加了关键节点之间的路径总和(记为 \(tot\))。对于每次查询,通过一定的函数(如 \(f(s)\))找到查询点附近的关键节点,再结合上面的距离公式计算出最终答案。

代码

当然我坚信你们最爱看这个

#include <bits/stdc++.h>

#define N 500010
#define ll long long
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid ((l + r) >> 1)

using namespace std;

int n, k, q;
int nxt[N], head[N], to[N], w[N], e = 0;
int siz[N]; 
int dis[N]; 
bool kk[N];
bool iff[N]; 
int dep[N];
int lca[N][31];
int lg[N];
int r;
int tot;

inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return x * f;
}

inline void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

inline void writeln(int x) {
    write(x);
    putchar('\n');
}

inline void add_edge(int u, int v, int c) {
    to[++e] = v;
    w[e] = c;
    nxt[e] = head[u];
    head[u] = e;
}

inline void dfs(int p, int pre) {
    dep[p] = dep[pre] + 1;
    lca[p][0] = pre;
    if(kk[p]) {
        siz[p] = 1;
    }
    for(int i = 1; i <= lg[dep[p]]; ++i) {
        lca[p][i] = lca[lca[p][i - 1]][i - 1];
    }
    for(int i = head[p]; i; i = nxt[i]) {
        if(to[i] == pre) {
            continue;
        }
        dis[to[i]] = dis[p] + w[i];
        dfs(to[i], p);
        siz[p] += siz[to[i]];
        if(siz[to[i]] > 0) {
            tot += w[i];
        }
    }
}

inline int f(int p) {
    if(iff[p]) return p;
    for(int i = lg[dep[p]]; i >= 0; --i) {
        if(lca[p][i] && !iff[lca[p][i]])
            p = lca[p][i];
    }
    return lca[p][0];
}

inline int fa(int x, int y) {
    if(dep[x] > dep[y]) {
        swap(x, y);
    }
    for(int i = lg[dep[y]]; i >= 0; --i) {
        if(dep[lca[y][i]] >= dep[x]) {
            y = lca[y][i];
        }
    }
    if(x == y) {
        return x;
    }
    for(int i = lg[dep[x]]; i >= 0; --i) {
        if(lca[x][i] != lca[y][i]) {
            x = lca[x][i];
            y = lca[y][i];
        }
    }
    return lca[x][0];
}

inline int dist(int x, int y) {
    return dis[x] + dis[y] - 2 * dis[fa(x, y)];
}

inline void tag() {
    for(int i = 1; i <= n; ++i) {
        if(siz[i]) {
            iff[i] = 1;
        }
    }
}

int main() {
    n = read();
    q = read();
    k = read();
    
    for(int i = 1; i <= n; ++i) {
        lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
    }
    
    for(int i = 1; i <= n - 1; ++i) {
        int u, v, c;
        u = read();
        v = read();
        c = read();
        add_edge(u, v, c);
        add_edge(v, u, c);
    }
    
    for(int i = 1; i <= k; ++i) {
        int p = read();
        kk[p] = 1;
        r = p;
    }
    
    dep[0] = 0;
    dfs(r, 0);
    tag();
    
    while(q--) {
        int s, t;
        s = read();
        t = read();
        int x = f(s), y = f(t);
        int ans = dist(s, x) + dist(y, t) + 2 * tot - dist(x, y);
        writeln(ans);
    }
    
    return 0;
}
posted @ 2025-10-14 23:40  amlhdsan  阅读(4)  评论(0)    收藏  举报