LCA（最近公共祖先）算法：倍增法

#include <bits/stdc++.h>

using namespace std;

// LCA (lowest common ancestor) queries on a rooted tree via binary lifting.
//
// Input (stdin):
//   n m s        number of nodes (ids 1..n), number of queries, root id
//   n-1 lines    undirected tree edges "u v"
//   m lines      queries "x y"
// Output (stdout): one line per query containing LCA(x, y).
//
// Preprocessing is O(n log n); each query is O(log n).
// NOTE(review): assumes the edges form a valid tree and that every id
// (including the root s) is in [1, n]; ids are not range-checked.
int main() {
  int n, m, s;
  if (scanf("%d%d%d", &n, &m, &s) != 3) {
    return 0;
  }

  // Number of jump levels. 2^LOG > n guarantees every depth difference
  // (at most n - 1) is a sum of distinct jumps 2^0 .. 2^(LOG-1).
  int LOG = 1;
  while ((1 << LOG) <= n) {
    ++LOG;
  }

  std::vector<std::vector<int>> adj(n + 1);
  for (int i = 1; i < n; ++i) {
    int u, v;
    if (scanf("%d%d", &u, &v) != 2) {
      return 0;
    }
    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  // up[k][v] is the 2^k-th ancestor of v, or 0 if it does not exist
  // (0 is never a node id). Level-major layout: only LOG allocations,
  // and the table build below walks each level contiguously.
  std::vector<int> dep(n + 1, 0);
  std::vector<std::vector<int>> up(LOG, std::vector<int>(n + 1, 0));

  // Iterative DFS with an explicit stack, so deep (path-like) trees
  // cannot overflow the call stack. Each child's depth and parent are
  // fixed when its parent is expanded, so visiting order does not matter.
  std::vector<int> stk;
  stk.reserve(n);
  stk.push_back(s);
  while (!stk.empty()) {
    const int u = stk.back();
    stk.pop_back();
    for (const int v : adj[u]) {
      if (v == up[0][u]) {  // skip the edge back to the parent
        continue;
      }
      dep[v] = dep[u] + 1;
      up[0][v] = u;
      stk.push_back(v);
    }
  }

  // Two 2^(k-1) jumps make one 2^k jump. up[*][0] is never written, so
  // jumping past the root keeps yielding 0 without an explicit check.
  for (int k = 1; k < LOG; ++k) {
    for (int v = 1; v <= n; ++v) {
      up[k][v] = up[k - 1][up[k - 1][v]];
    }
  }

  while (m--) {
    int x, y;
    if (scanf("%d%d", &x, &y) != 2) {
      break;
    }
    if (dep[x] < dep[y]) {
      std::swap(x, y);
    }

    // Lift the deeper node x to y's depth using the binary digits of the gap.
    for (int diff = dep[x] - dep[y], k = 0; diff > 0; diff >>= 1, ++k) {
      if (diff & 1) {
        x = up[k][x];
      }
    }

    if (x != y) {
      // Jump both nodes as high as possible while staying strictly below
      // their LCA; afterwards they are distinct children of the LCA.
      for (int k = LOG - 1; k >= 0; --k) {
        if (up[k][x] != up[k][y]) {
          x = up[k][x];
          y = up[k][y];
        }
      }
      x = up[0][x];
    }
    printf("%d\n", x);
  }
  return 0;
}
posted @ 2023-06-02 16:37  hacker_dvd  阅读(28)  评论(0)    收藏  举报