20251220 - LCA 总结
20251220 - LCA 总结
定义
在有根树上,对于两个点 u 和 v,这两个点的所有公共祖先中,距离根节点最远的节点,就是 u 和v 的最近公共祖先(LCA)。
对于一个点集,这些点的所有公共祖先中,距离根节点最远的节点,就是这些点的LCA。
暴力做法
方法一:
首先把每一个点的祖先序列给求出来,再求 LCA。
Joler 老师直呼,太暴力了!
预处理时间复杂度:\(O(n)\)。
单次询问时间复杂度:\(O(n)\)。
代码:
int solve(int x, int y) {
if (x == 1 || y == 1) return 1;
vector <int> v1, v2;
for (; x != 1; x = fa[x]) v1.push_back(x);
v1.push_back(1);
reverse(v1.begin(), v1.end());
for (; y != 1; y = fa[y]) v2.push_back(y);
v2.push_back(1);
reverse(v2.begin(), v2.end());
int ans = 1;
for (int i = 1; i <= (int)v1.size() && i <= (int)v2.size(); i++) {
if(v1[i] == v2[i]) {
ans = v1[i];
}else {
break;
}
}
return ans;
}
方法二:
让深度较大的点向上跳,直到他们的深度相同,再同时向上跳。
int solve(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (1) {
if (dep[x] == dep[y]) break;
x = fa[x];
}
while (x != y) {
x = fa[x], y = fa[y];
}
return x;
}
倍增
预处理出所有点的深度,以及每个点的 \(2^k\) 级父节点。
基于暴力枚举的LCA,使用倍增优化向上跳的过程。
预处理时间复杂度:\(O(n\log_2n)\)。
单次询问时间复杂度:\(O(\log_2n)\)。
代码:
int dep[N], fa[N][21];
inline void dfs(int u, int from) {
for (auto v : edges[u]) {
if (v == from) continue;
dep[v] = dep[u] + 1;
fa[v][0] = u;
dfs(v, u);
}
}
void init() {
dfs(1, 1);
for (int j = 1; j <= 20; j++) {
for (int i = 1; i <= n; i++) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
}
int solve(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
for (int i = 0; d; i++, d >>= 1) {
if (d & 1) {
x = fa[x][i];
}
}
if (x == y) return x;
for (int i = 20; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
dfs 序
就是欧拉序少了回溯。
dfs 序有一个性质,\([l, r]\) 的 \(dfn\) 的最小的深度的父亲就是 \(l\) 和 \(r\) 的 LCA。
但是,如果 \(r\) 是 \(l\) 的祖先,那么就挂了,所以就搞 \([l + 1, r]\) 就好了。
代码:
vector <int> edges[N];
inline void dfs(int u, int from) {
dfn[u] = ++idx;
id[idx] = u;
for (auto v : edges[u]) {
if (v == from) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
int MIN(int x, int y) {
if (dep[x] < dep[y]) return x;
return y;
}
void init() {
lg[0] = -1;
for (int i = 1; i <= n; i++) {
f[i][0] = id[i];
lg[i] = lg[i >> 1] + 1;
}
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = MIN(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
}
int get_lca(int x, int y) {
if (x == y) return x;
int l = dfn[x], r = dfn[y];
if (l > r) swap(l, r);
l++;
int j = lg[r - l + 1];
return fa[MIN(f[l][j], f[r - (1 << j) + 1][j])];
}
Tarjan 算法
Tarjan 算法是一种 离线算法,需要使用 并查集 记录某个结点的祖先结点。
如果一个点访问过了,把 y 并到 x 上。
时间复杂度:\(O(n + mα(n + m, n))\)。
int find(int x) {
if (fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
void Tarjan(int u) {
vis[u] = 1;
for (auto v : edges[u]) {
if (vis[v]) continue;
Tarjan(v);
fa[v] = u;
}
for (auto [v, id] : query[u]) {
if (vis[v]) ans[id] = find(v);
}
}
例题:
B - 大量的工作沟通
求出每个点 LCA,再向上跳求最大的编号。
代码:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, q, fa[N][21], dep[N];
vector <int> edges[N];
inline void dfs(int u, int from) {
for (auto v : edges[u]) {
if (v == from) continue;
fa[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
int get_lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
for (int i = 0; d; i++, d >>= 1) {
if (d & 1) {
x = fa[x][i];
}
}
if (x == y) return x;
for (int i = 20; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
void solve() {
n = read();
for (int i = 1; i < n; i++) {
int x = read();
edges[x].push_back(i);
edges[i].push_back(x);
}
dfs(0, 0);
for (int j = 1; j <= 20; j++) {
for (int i = 0; i < n; i++) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
q = read();
while (q--) {
int t = read();
int lca = 0;
for (int i = 0; i < t; i++) {
int x = read();
if (i == 0) lca = x;
else lca = get_lca(lca, x);
}
// printf("%d\n", lca);
int ans = 0;
while (lca != 0) {
ans = max(ans, lca);
lca = fa[lca][0];
}
printf("%d\n", ans);
}
}
int main() {
int oT_To = 1;
while (oT_To--) solve();
return 0;
}
D - Milk Visits S
思路一:
记录每一个点到根节点的 G 或 H 的数量,再求一下路径和。
警示后人:请一定要判断 LCA 节点。
思路二
倍增 LCA 的同时求出是否是纯的 G 或 H 就好了。
思路三(考场做法)
把 G 当做点权 1,H 当做点权 0。
如果路径最大值是 1,那么就存在 G。
如果路径最小值是 0,那么就存在 H。
路径最大值和路径最小值求法详见 CF609E Minimum spanning tree for each edge。
代码:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7, M = N * 2;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int idx = 0;
struct TreeNode {
int to;
TreeNode *next;
} *head[M], G[M];
void add(int x, int y) {
idx++;
G[idx].to = y;
G[idx].next = head[x];
head[x] = &G[idx];
}
int n, q, a[N], dp[N][21], f[N][21], v[N][21], dep[N];
inline void dfs(int u, int from) {
for (TreeNode *nxt = head[u]; nxt; nxt = nxt->next) {
int v = nxt->to;
if (v == from) continue;
dp[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
void init() {
memset(f, 127, sizeof(f));
for (int i = 1; i <= n; i++) {
f[i][0] = a[i];
v[i][0] = a[i];
}
for (int j = 1; j < 20; j++)
for (int i = 1; i <= n; i++) {
dp[i][j] = dp[dp[i][j - 1]][j - 1];
f[i][j] = min(f[i][j - 1], f[dp[i][j - 1]][j - 1]);
v[i][j] = max(v[i][j - 1], v[dp[i][j - 1]][j - 1]);
}
}
int Lca(int x, int y) {
int ans = inf;
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
for (int i = 0; d; i++, d >>= 1) {
if (d & 1) {
ans = min(ans, f[x][i]);
x = dp[x][i];
}
}
if (x != y) {
for (int i = 19; i >= 0; i--) {
if (dp[x][i] != dp[y][i]) {
ans = min({ans, f[x][i], f[y][i]});
x = dp[x][i];
y = dp[y][i];
}
}
ans = min({ans, f[x][0], f[y][0]});
x = dp[x][0];
}
return min(ans, f[x][0]);
}
int lca(int x, int y) {
int ans = 0;
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
for (int i = 0; d; i++, d >>= 1) {
if (d & 1) {
ans = max(ans, v[x][i]);
x = dp[x][i];
}
}
if (x != y) {
for (int i = 19; i >= 0; i--) {
if (dp[x][i] != dp[y][i]) {
ans = max({ans, v[x][i], v[y][i]});
x = dp[x][i];
y = dp[y][i];
}
}
ans = max({ans, v[x][0], v[y][0]});
x = dp[x][0];
}
return max(ans, v[x][0]);
}
char str[N];
void solve() {
n = read(), q = read();
scanf("%s", str + 1);
for (int i = 1; i <= n; i++) {
a[i] = str[i] == 'G' ? 1 : 0;
}
for (int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y);
add(y, x);
}
dfs(1, 1);
init();
while (q--) {
int x, y;
char ch;
scanf("%d %d %c", &x, &y, &ch);
// cout << x << " " << y << " " << ch << endl;
if (ch == 'G') {
int d = lca(x, y);
if (d == 1) putchar('1');
else putchar('0');
}else {
int d = Lca(x, y);
if (d == 0) putchar('1');
else putchar('0');
}
}
}
int main() {
int oT_To = 1;
while (oT_To--) solve();
return 0;
}
I - The Shortest Statement
可以发现,如果搞出一个生成树,那么最多剩下 \(21\) 条边,\(42\) 个点。
如果最短路是在生成树上,用 LCA 秒了。
如果是剩下的点,暴力跑 Dijkstra 就好了。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1LL << 60)
#define pb push_back
typedef pair<int, int> PII;
#define int ll
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, m, f[N], dep[N], fa[N][21], dc[N];
vector <pair<int, int>> edges[N], G[N];
int find(int x) {
if (f[x] == x) return x;
return f[x] = find(f[x]);
}
inline void dfs(int u, int from) {
for (auto [v, w] : edges[u]) {
if (v == from) continue;
fa[v][0] = u;
dep[v] = dep[u] + 1;
dc[v] = dc[u] + w;
dfs(v, u);
}
}
int get_lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
for (int i = 0; d; i++, d >>= 1) {
if (d & 1) {
x = fa[x][i];
}
}
if (x == y) return x;
for (int i = 20; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
void solve() {
n = read(), m = read();
for (int i = 1; i <= n; i++) {
f[i] = i;
}
set <int> st;
for (int i = 1; i <= m; i++) {
int x = read(), y = read(), w = read();
int xx = find(x), yy = find(y);
G[x].push_back({y, w});
G[y].push_back({x, w});
if (xx != yy) {
edges[x].push_back({y, w});
edges[y].push_back({x, w});
f[xx] = yy;
}else {
st.insert(x);
st.insert(y);
}
}
dfs(1, 1);
for (int j = 1; j <= 20; j++) {
for (int i = 1; i <= n; i++) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
map <int, vector<int>> mp;
auto Dijkstra = [&](int s) {
auto &dist = mp[s];
dist = vector <int> (n + 1, inf);
priority_queue <pair<int, int>, vector <pair<int, int>>, greater<pair<int, int>>> q;
vector <int> vis(n + 1, 0);
dist[s] = 0;
q.push({0, s});
while (!q.empty()) {
int u = q.top().second;
q.pop();
if (vis[u]) continue;
vis[u] = 1;
for (auto [v, w] : G[u]) {
if (dist[u] + w < dist[v]) {
dist[v] = dist[u] + w;
q.push({dist[v], v});
}
}
}
return;
};
for (auto i : st) Dijkstra(i);
int q = read();
while (q--) {
int x = read(), y = read();
int ans = dc[x] + dc[y] - 2LL * dc[get_lca(x, y)];
for (auto i : st) {
ans = min(ans, mp[i][x] + mp[i][y]);
}
printf("%lld\n", ans);
}
}
signed main() {
int oT_To = 1;
while (oT_To--) solve();
return 0;
}
后记
LCA 可以求树上路径,是一个很好的算法!(Tarjan 太巨了!)

浙公网安备 33010602011771号