20251220 - LCA 总结

20251220 - LCA 总结

定义

在有根树上,对于两个点 uv,这两个点的所有公共祖先中,距离根节点最远的节点,就是 uv 的最近公共祖先(LCA)。

对于一个点集,这些点的所有公共祖先中,距离根节点最远的节点,就是这些点的LCA。

暴力做法

方法一:

首先把每一个点的祖先序列给求出来,再求 LCA。

Joler 老师直呼,太暴力了!

预处理时间复杂度:\(O(n)\)

单次询问时间复杂度:\(O(n)\)

代码:

int solve(int x, int y) {
    if (x == 1 || y == 1) return 1;
    vector <int> v1, v2;
    for (; x != 1; x = fa[x]) v1.push_back(x);
    v1.push_back(1);
    reverse(v1.begin(), v1.end());
    for (; y != 1; y = fa[y]) v2.push_back(y);
    v2.push_back(1);
    reverse(v2.begin(), v2.end());
    int ans = 1;
    for (int i = 1; i <= (int)v1.size() && i <= (int)v2.size(); i++) {
        if(v1[i] == v2[i]) {
            ans = v1[i];
        }else {
            break;
        }
    }
    return ans;
}

方法二:

让深度较大的点向上跳,直到他们的深度相同,再同时向上跳。

int solve(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
   	while (1) {
        if (dep[x] == dep[y]) break;
        x = fa[x];
    }
    while (x != y) {
        x = fa[x], y = fa[y];
    }
    return x;
}

倍增

预处理出所有点的深度,以及每个点的 \(2^k\) 级父节点。

基于暴力枚举的LCA,使用倍增优化向上跳的过程。

预处理时间复杂度:\(O(n\log_2n)\)

单次询问时间复杂度:\(O(\log_2n)\)

代码:

int dep[N], fa[N][21];
inline void dfs(int u, int from) {
    for (auto v : edges[u]) {
        if (v == from) continue;
        dep[v] = dep[u] + 1;
        fa[v][0] = u;
        dfs(v, u);
    }
}
void init() {
    dfs(1, 1);
    for (int j = 1; j <= 20; j++) {
        for (int i = 1; i <= n; i++) {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
}
int solve(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    int d = dep[x] - dep[y];
    for (int i = 0; d; i++, d >>= 1) {
        if (d & 1) {
            x = fa[x][i];
        }
    }
    if (x == y) return x;
    for (int i = 20; i >= 0; i--) {
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}

dfs 序

就是欧拉序少了回溯。

dfs 序有一个性质,\([l, r]\)\(dfn\) 的最小的深度的父亲就是 \(l\)\(r\) 的 LCA。

但是,如果 \(r\)\(l\) 的祖先,那么就挂了,所以就搞 \([l + 1, r]\) 就好了。

代码:

vector <int> edges[N];
inline void dfs(int u, int from) {
  dfn[u] = ++idx;
  id[idx] = u;
  for (auto v : edges[u]) {
    if (v == from) continue;
    fa[v] = u;
    dep[v] = dep[u] + 1;
    dfs(v, u);
  }
}
int MIN(int x, int y) {
  if (dep[x] < dep[y]) return x;
  return y;
}
void init() {
  lg[0] = -1;
  for (int i = 1; i <= n; i++) {
    f[i][0] = id[i];
    lg[i] = lg[i >> 1] + 1;
  }
  for (int j = 1; (1 << j) <= n; j++) {
    for (int i = 1; i + (1 << j) - 1 <= n; i++) {
      f[i][j] = MIN(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
    }
  }
}
int get_lca(int x, int y) {
  if (x == y) return x;
  int l = dfn[x], r = dfn[y];
  if (l > r) swap(l, r);
  l++;
  int j = lg[r - l + 1];
  return fa[MIN(f[l][j], f[r - (1 << j) + 1][j])];
}

Tarjan 算法

Tarjan 算法是一种 离线算法,需要使用 并查集 记录某个结点的祖先结点。

如果一个点访问过了,把 y 并到 x 上。

时间复杂度:\(O(n + mα(n + m, n))\)

int find(int x) {
    if (fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}
void Tarjan(int u) {
    vis[u] = 1;
    for (auto v : edges[u]) {
        if (vis[v]) continue;
        Tarjan(v);
        fa[v] = u;
    }
    for (auto [v, id] : query[u]) {
        if (vis[v]) ans[id] = find(v);
    }
}

例题:

B - 大量的工作沟通

求出每个点 LCA,再向上跳求最大的编号。

代码:

#include <bits/stdc++.h>

using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
  int x = 0, f = 1;
  char ch = getchar();
  while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
  while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
  return x * f;
}
int n, q, fa[N][21], dep[N];
vector <int> edges[N];
inline void dfs(int u, int from) {
  for (auto v : edges[u]) {
    if (v == from) continue;
    fa[v][0] = u;
    dep[v] = dep[u] + 1;
    dfs(v, u);
  }
}
int get_lca(int x, int y) {
  if (dep[x] < dep[y]) swap(x, y);
  int d = dep[x] - dep[y];
  for (int i = 0; d; i++, d >>= 1) {
    if (d & 1) {
      x = fa[x][i];
    }
  }
  if (x == y) return x;
  for (int i = 20; i >= 0; i--) {
    if (fa[x][i] != fa[y][i]) {
      x = fa[x][i];
      y = fa[y][i];
    }
  }
  return fa[x][0];
}
void solve() { 
  n = read();
  for (int i = 1; i < n; i++) {
    int x = read();
    edges[x].push_back(i);
    edges[i].push_back(x);
  }
  dfs(0, 0);
  for (int j = 1; j <= 20; j++) {
    for (int i = 0; i < n; i++) {
      fa[i][j] = fa[fa[i][j - 1]][j - 1];
    }
  }
  q = read();
  while (q--) {
    int t = read();
    int lca = 0;
    for (int i = 0; i < t; i++) {
      int x = read();
      if (i == 0) lca = x;
      else lca = get_lca(lca, x);
    }
    // printf("%d\n", lca);
    int ans = 0;
    while (lca != 0) {
      ans = max(ans, lca);
      lca = fa[lca][0];
    }
    printf("%d\n", ans); 
  }
}
int main() {
  int oT_To = 1;
  while (oT_To--) solve();
  return 0;
}

D - Milk Visits S

思路一:

记录每一个点到根节点的 G 或 H 的数量,再求一下路径和。

警示后人:请一定要判断 LCA 节点。

思路二

倍增 LCA 的同时求出是否是纯的 G 或 H 就好了。

思路三(考场做法)

把 G 当做点权 1,H 当做点权 0。

如果路径最大值是 1,那么就存在 G。

如果路径最小值是 0,那么就存在 H。

路径最大值和路径最小值求法详见 CF609E Minimum spanning tree for each edge

代码:

#include <bits/stdc++.h>

using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7, M = N * 2;
const int P = 998244353;
int read() {
	int x = 0, f = 1;
	char ch = getchar();
	while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
	return x * f;
}
int idx = 0;
struct TreeNode {
	int to;
	TreeNode *next;
} *head[M], G[M];
void add(int x, int y) {
	idx++;
	G[idx].to = y;
	G[idx].next = head[x];
	head[x] = &G[idx];
}
int n, q, a[N], dp[N][21], f[N][21], v[N][21], dep[N];
inline void dfs(int u, int from) {
	for (TreeNode *nxt = head[u]; nxt; nxt = nxt->next) {
		int v = nxt->to;
		if (v == from) continue;
		dp[v][0] = u;
		dep[v] = dep[u] + 1;
		dfs(v, u);
	}
}
void init() {
	memset(f, 127, sizeof(f));
	for (int i = 1; i <= n; i++) {
		f[i][0] = a[i];
    v[i][0] = a[i];
	}
	for (int j = 1; j < 20; j++)
		for (int i = 1; i <= n; i++) {
			dp[i][j] = dp[dp[i][j - 1]][j - 1];
			f[i][j] = min(f[i][j - 1], f[dp[i][j - 1]][j - 1]);
      		v[i][j] = max(v[i][j - 1], v[dp[i][j - 1]][j - 1]);
		}
} 
int Lca(int x, int y) {
	int ans = inf;
	if (dep[x] < dep[y]) swap(x, y);
	int d = dep[x] - dep[y];
	for (int i = 0; d; i++, d >>= 1) {
		if (d & 1) {
			ans = min(ans, f[x][i]);
			x = dp[x][i];
		}
	}
	if (x != y) {
		for (int i = 19; i >= 0; i--) {
			if (dp[x][i] != dp[y][i]) {
				ans = min({ans, f[x][i], f[y][i]});
				x = dp[x][i];
				y = dp[y][i];
			}
		}
		ans = min({ans, f[x][0], f[y][0]});
		x = dp[x][0];
	}
	return min(ans, f[x][0]);
}
int lca(int x, int y) {
	int ans = 0;
	if (dep[x] < dep[y]) swap(x, y);
	int d = dep[x] - dep[y];
	for (int i = 0; d; i++, d >>= 1) {
		if (d & 1) {
			ans = max(ans, v[x][i]);
			x = dp[x][i];
		}
	}
	if (x != y) {
		for (int i = 19; i >= 0; i--) {
			if (dp[x][i] != dp[y][i]) {
				ans = max({ans, v[x][i], v[y][i]});
				x = dp[x][i];
				y = dp[y][i];
			}
		}
		ans = max({ans, v[x][0], v[y][0]});
		x = dp[x][0];
	}
	return max(ans, v[x][0]);
}
char str[N];
void solve() {
  n = read(), q = read();
  scanf("%s", str + 1);
  for (int i = 1; i <= n; i++) {
    a[i] = str[i] == 'G' ? 1 : 0;
  }
  for (int i = 1; i < n; i++) {
    int x = read(), y = read();
    add(x, y);
    add(y, x);
  } 
  dfs(1, 1);
  init();
  while (q--) {
    int x, y;
    char ch;
    scanf("%d %d %c", &x, &y, &ch);
    // cout << x << " " << y << " " << ch << endl;
    if (ch == 'G') {
      int d = lca(x, y);
      if (d == 1) putchar('1');
      else putchar('0');
    }else {
      int d = Lca(x, y);
      if (d == 0) putchar('1');
      else putchar('0');
    }
  }
}
int main() {
	int oT_To = 1;
  while (oT_To--) solve();
	return 0;
}

I - The Shortest Statement

可以发现,如果搞出一个生成树,那么最多剩下 \(21\) 条边,\(42\) 个点。

如果最短路是在生成树上,用 LCA 秒了。

如果是剩下的点,暴力跑 Dijkstra 就好了。

#include <bits/stdc++.h>

using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1LL << 60)
#define pb push_back
typedef pair<int, int> PII;
#define int ll
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
  int x = 0, f = 1;
  char ch = getchar();
  while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
  while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
  return x * f;
}
int n, m, f[N], dep[N], fa[N][21], dc[N];
vector <pair<int, int>> edges[N], G[N];
int find(int x) {
  if (f[x] == x) return x;
  return f[x] = find(f[x]);
}
inline void dfs(int u, int from) {
  for (auto [v, w] : edges[u]) {
    if (v == from) continue;
    fa[v][0] = u;
    dep[v] = dep[u] + 1;
    dc[v] = dc[u] + w;
    dfs(v, u);
  }
}
int get_lca(int x, int y) {
  if (dep[x] < dep[y]) swap(x, y);
  int d = dep[x] - dep[y];
  for (int i = 0; d; i++, d >>= 1) {
    if (d & 1) {
      x = fa[x][i];
    }
  }
  if (x == y) return x;
  for (int i = 20; i >= 0; i--) {
    if (fa[x][i] != fa[y][i]) {
      x = fa[x][i];
      y = fa[y][i];
    }
  }
  return fa[x][0];
}
void solve() {
  n = read(), m = read();
  for (int i = 1; i <= n; i++) {
    f[i] = i;
  }
  set <int> st;
  for (int i = 1; i <= m; i++) {
    int x = read(), y = read(), w = read();
    int xx = find(x), yy = find(y);
    G[x].push_back({y, w}); 
    G[y].push_back({x, w});
    if (xx != yy) {
      edges[x].push_back({y, w});
      edges[y].push_back({x, w});
      f[xx] = yy;
    }else {
      st.insert(x);
      st.insert(y);
    }
  }
  dfs(1, 1);
  for (int j = 1; j <= 20; j++) {
    for (int i = 1; i <= n; i++) {
      fa[i][j] = fa[fa[i][j - 1]][j - 1];
    }
  }
  map <int, vector<int>> mp;
  auto Dijkstra = [&](int s) {
    auto &dist = mp[s];
    dist = vector <int> (n + 1, inf);
    priority_queue <pair<int, int>, vector <pair<int, int>>, greater<pair<int, int>>> q;
    vector <int> vis(n + 1, 0);
    dist[s] = 0;
    q.push({0, s});
    while (!q.empty()) {
      int u = q.top().second;
      q.pop();
      if (vis[u]) continue;
      vis[u] = 1;
      for (auto [v, w] : G[u]) {
        if (dist[u] + w < dist[v]) {
          dist[v] = dist[u] + w;
          q.push({dist[v], v});
        }
      }
    }
    return;
  };
  for (auto i : st) Dijkstra(i);
  int q = read();
  while (q--) {
    int x = read(), y = read();
    int ans = dc[x] + dc[y] - 2LL * dc[get_lca(x, y)];
    for (auto i : st) {
      ans = min(ans, mp[i][x] + mp[i][y]);
    }
    printf("%lld\n", ans);
  }
} 
signed main() {
  int oT_To = 1;
  while (oT_To--) solve();
  return 0;
}

后记

LCA 可以求树上路径,是一个很好的算法!(Tarjan 太巨了!)

posted @ 2025-12-21 13:23  AKCoder  阅读(5)  评论(0)    收藏  举报