小红的树上路径查询(hard)

小红的树上路径查询(hard)

题目描述

本题和 $hard$ 难度的区别是,询问的次数有多次!

小红拿到了一棵树,她有多次询问,每次询问输入一条简单路径 $x,y$,她想知道树上所有节点到该路径的最短路之和是多少,你能帮帮她吗?

定义节点到路径的最短路为:节点到路径上所有点的最短路中,值最小的那个。特殊的,如果节点在路径上,则最短路为 $0$。

简单路径:从树上的一个节点出发,沿着树的边走,不重复地经过树上的节点,到达另一个节点的路径。

输入描述:

第一行输入两个正整数 $n,q$,代表节点数量和询问次数。

接下来的 $n−1$ 行,每行输入两个正整数 $u,v$,代表节点 $u$ 和节点 $v$ 有一条边连接。

接下来的 $q$ 行,每行输入两个正整数 $x,y$,代表一次询问。

$1 \leq n,q \leq 10^5$

$1 \leq u,v,x,y \leq n$

输出描述:

输出 $q$ 行,每行输出一个整数,代表询问的答案。

示例1

输入

4 2
1 2
1 3
1 4
2 3
2 1

输出

1
2

 

解题思路

  不会,直接参考的题解。

  对于 $x$ 与 $y$ 构成的 $xy$ 路径外一点 $v$,假设 $v$ 到 $xy$ 路径上最近的点是 $u$,那么 $v$ 到 $x$ 或 $v$ 到 $y$ 的路径必定包含点 $u$(假设以 $v$ 为根进行 bfs,那么第一次遍历到 $xy$ 路径上的点就是 $u$,继而从 $u$ 扩展到 $xy$ 路径上的其他点)。这点其实还是很难想到的。

  然后求 $v$ 分别到 $x$ 和 $y$ 距离之和。有

\begin{align*}
&d(v,x) + d(v,y) \\
= &d(v,u) + d(u,x) + d(v,u) + d(u,y) \\
= &2d(v,u) + d(u,x) + d(u,y) \\
= &2d(v,u) + d(x,y)
\end{align*}

  即有 $d(v,x) + d(v,y) = 2d(v,u) + d(x,y) \Rightarrow d(v,u) = \frac{d(v,x) + d(v,y) - d(x,y)}{2}$。这条式子有什么用呢?实际上我们关心的是所有点到 $u$(即该点到 $xy$ 路径最近的点)的距离的和,并不需要求出具体的 $u$。同时可以发现如果 $v$ 也是 $xy$ 路径上的点上式同样成立。因此所有点关于 $d(v,u)$ 的和就是

\begin{align*}
&\sum\limits_{v=1}^{n}{d(v,u)} \\
=&\frac{1}{2}\sum\limits_{v=1}^{n}{d(v,x) + d(v,y) - d(x,y)} \\
=&\frac{1}{2}\left(\sum\limits_{v=1}^{n}{d(v,x)} + \sum\limits_{v=1}^{n}{d(v,y)} - n \cdot d(x,y)\right) \\
\end{align*}

  其中 $\sum\limits_{v=1}^{n}{d(v,x)}$ 和 $\sum\limits_{v=1}^{n}{d(v,y)}$ 分别是 $x$ 和 $y$ 到所有点的距离总和,这个可以用换根 dp 求得。$d(x,y)$ 可以分别求出 $x$ 和 $y$ 到最近公共祖先的距离再求和。

  下面简单讲一下如何求所有点到 $u$ 的距离总和。固定 $1$ 为根,定义 $f(u)$ 表示子树 $u$ 中所有点到 $u$ 的距离总和,$g(u)$ 表示从 $u$ 往上走的所有点(其实就是除了子树 $u$ 外的点)到 $u$ 的距离总和。那么状态转移方程就是

$$f(u) = \sum\limits_{v \in \text{son}(u)}{f(v) + \text{sz}_v}$$

$$g(u) = g(p_u) + f(p_u) - (f(u) + \text{sz}_u) + (n - \text{sz}_u)$$

  其中 $\text{sz}_u$ 表示子树 $u$ 的大小,$p_u$ 表示 $u$ 的父节点。那么所有点到 $u$ 的距离总和就是 $f(u)+g(u)$。

  AC 代码如下,时间复杂度为 $O((n+q)\log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 5, M = N * 2;

int n, m;
int h[N], e[M], ne[M], idx;
LL sz[N], f[N], g[N];
int fa[N][17], d[N];

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs1(int u, int p) {
    sz[u] = 1;
    d[u] = d[p] + 1;
    fa[u][0] = p;
    for (int i = 1; i <= 16; i++) {
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    }
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        f[u] += f[v] + sz[v];
    }
}

void dfs2(int u, int p) {
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        g[v] = g[u] + f[u] - (f[v] + sz[v]) + n - sz[v];
        dfs2(v, u);
    }
}

int lca(int a, int b) {
    if (d[a] < d[b]) swap(a, b);
    for (int i = 16; i >= 0; i--) {
        if (d[fa[a][i]] >= d[b]) a = fa[a][i];
    }
    if (a == b) return a;
    for (int i = 16; i >= 0; i--) {
        if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
    }
    return fa[a][0];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    memset(h, -1, sizeof(h));
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    while (m--) {
        int x, y;
        cin >> x >> y;
        cout << (f[x] + g[x] + f[y] + g[y] - n * (d[x] + d[y] - 2 * d[lca(x, y)])) / 2 << '\n';
    }
    
    return 0;
}

 

参考资料

  牛客周赛64题解:https://ac.nowcoder.com/discuss/1421788

posted @ 2024-10-23 20:50  onlyblues  阅读(150)  评论(0)    收藏  举报
Web Analytics