题解 树
题目:
给定一颗 \(n\) 个节点的树,树上的边权要不是 \(1\) 要不是 \(2\)。有 \(m\) 组询问,对于每组询问 \(u,v,c\),问在每步最多走距离 \(c\) 的条件下,从 \(u\) 到 \(v\) 最少要多少步。中途不能停留在边上。
\(1\le n,m\le 50000,2\le c\)
题解:
首先每次一定是能走多远走多远。
然后看到五万的核 善数据范围,可以联想到根号分治。
首先定一个阈值 \(S\)。
- \(c\leq S\)
预处理出 \(nxt[i][j]\) 表示步长为 \(j\) 时,由 \(i\) 往祖先方向走一步,会到达哪个点。
但是如果直接非常暴躁地根据 \(nxt\) 跳肯定不行,所以我们要再预处理出 \(jump[i][j]\) 表示步长为 \(j\) 时 \(i\) 往上走 \(S\) 步会到达哪个点。
其实叫 jump 很形象,走了很多步,就有了跳一步的距离。
- \(c> S\)
这个时候可以用你能想到的最暴躁的方法暴力,比如一步一步走。这里一步可以走很远,所以用倍增优化。
复杂度:\(O(nS+\frac{n^2}{S}\log n)\),所以 \(S\) 取 \(\sqrt{n\log n}\) 时最优。
但是别忘了有个 \(O(nS)\) 的空间复杂度,这题空限只有 \(128\) MB,所以实践可得 \(S\) 取 \(300\) 是在不 MLE 时的最大块长。
#pragma GCC optimize(3)
#include <cstdio>
struct Edge {int to, nxt, w;} e[100005];
const int S = 300;
int fa[50005][17], dep[50005], dis[50005], head[50005], nxt[50005][S + 5], jump[50005][S + 5], tot;
int stk[50005], top;
inline void AddEdge(int u, int v, int w) {
e[++ tot].to = v, e[tot].nxt = head[u], e[tot].w = w, head[u] = tot;
}
void dfs(int u) {
dep[u] = dep[fa[u][0]] + 1;
for (int i = 1; i <= 16; ++ i) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = 2, v = (u == 1 ? 1 : fa[u][0]); i <= S; ++ i) {
if (v != 1 && dis[u] - dis[fa[v][0]] <= i) v = fa[v][0];
nxt[u][i] = v;
}
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].to != fa[u][0]) fa[e[i].to][0] = u, dis[e[i].to] = dis[u] + e[i].w, dfs(e[i].to);
}
void DFS(int u, int len) {
stk[++ top] = u;
jump[u][len] = stk[top > S ? top - S : 1];
for (int i = head[u]; i; i = e[i].nxt) DFS(e[i].to, len);
-- top;
}
int LCA(int u, int v) {
if (dep[u] < dep[v]) u ^= v ^= u ^= v;
int t = dep[u] - dep[v];
for (int i = 0; i <= 16; ++ i)
if (t & 1 << i) u = fa[u][i];
if (u == v) return u;
for (int i = 16; i >= 0; -- i)
if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
int main() {
int n, m;
scanf("%d", &n);
for (int i = 1, u, v, w; i < n; ++ i)
scanf("%d%d%d", &u, &v, &w), AddEdge(u, v, w), AddEdge(v, u, w);
dfs(1);
for (int i = 2; i <= S; ++ i) {
for (int j = 1; j <= n; ++ j) head[j] = 0;
tot = 0;
for (int j = 2; j <= n; ++ j) AddEdge(nxt[j][i], j, 0);
DFS(1, i);
}
scanf("%d", &m);
while (m --) {
int u, v, c, lca;
scanf("%d%d%d", &u, &v, &c);
if (u == v) {puts("0"); continue;}
lca = LCA(u, v);
if (c <= S) {
int ans = 1;
while (dep[jump[u][c]] > dep[lca]) u = jump[u][c], ans += S;
while (dep[nxt[u][c]] > dep[lca]) u = nxt[u][c], ++ ans;
while (dep[jump[v][c]] > dep[lca]) v = jump[v][c], ans += S;
while (dep[nxt[v][c]] > dep[lca]) v = nxt[v][c], ++ ans;
if (dis[u] + dis[v] - 2 * dis[lca] > c) ++ ans;
printf("%d\n", ans);
} else {
int ans = 1;
while (true) {
int Nxt = u;
for (int i = 16, j = c, k = fa[Nxt][16]; i >= 0 && j; -- i)
if (dis[Nxt] - dis[k = fa[Nxt][i]] <= j) j -= dis[Nxt] - dis[k], Nxt = k;
if (dep[Nxt] > dep[lca]) u = Nxt, ++ ans;
else break;
}
while (true) {
int Nxt = v;
for (int i = 16, j = c, k = fa[Nxt][16]; i >= 0 && j; -- i)
if (dis[Nxt] - dis[k = fa[Nxt][i]] <= j) j -= dis[Nxt] - dis[k], Nxt = k;
if (dep[Nxt] > dep[lca]) v = Nxt, ++ ans;
else break;
}
if (dis[u] + dis[v] - 2 * dis[lca] > c) ++ ans;
printf("%d\n", ans);
}
}
return 0;
}