LCA最近公共祖先模板(求树上任意两个节点的最短距离 || 求两个点的路进(有且只有唯一的一条))

原理可以参考大神

LCA_Tarjan (离线)

  TarjanTarjan 算法求 LCA 的时间复杂度为 O(n+q) ,是一种离线算法,要用到并查集。(注:这里的复杂度其实应该不是 O(n+q) ,还需要考虑并查集操作的复杂度 ,但是由于在多数情况下,路径压缩并查集的单次操作复杂度可以看做 O(1),所以写成了 O(n+q)。)

 

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
struct EDGE{ int v, nxt, w; }Edge[maxn<<1];
struct Query{ int v, id;
    Query(){};
    Query(int _v, int _id):v(_v),id(_id){};
}; vector<Query> q[maxn];

int Head[maxn], cnt;
int Fa[maxn];///并查集数组
int ans[maxn];///问询数数组大小要注意一下、不一定是 maxn
bool vis[maxn];///Tarjan算法中的标记数组
int n, m, s, qNum;///点、边、Tarjan递归起点、问询数

inline void init()
{
    memset(Head, -1, sizeof(Head));
    memset(vis, false, sizeof(vis));
    cnt = 0;
}

inline void AddEdge(int from, int to)
{
    Edge[cnt].v = to;
    Edge[cnt].nxt = Head[from];
    Head[from] = cnt++;
}

int Findset(int x)
{
    int root = x;
    while(Fa[root] != root) root = Fa[root];

    int tmp;
    while(Fa[x] != root){
        tmp = Fa[x];
        Fa[x] = root;
        x = tmp;
    }

    return root;
}

void Tarjan(int v, int f)
{
    Fa[v] = v;
    for(int i=Head[v]; i!=-1; i=Edge[i].nxt){
        int Eiv = Edge[i].v;
        if(Eiv == f) continue;
        Tarjan(Eiv, v);
        Fa[Findset(Eiv)] = v;
    }
    vis[v] = true;
    for(int i=0; i<q[v].size(); i++){
        if(vis[q[v][i].v])
            ans[q[v][i].id] = Findset(q[v][i].v);
    }
}

int main(void)
{
    init();
    scanf("%d %d %d %d", &n, &m, &s, &qNum);
    for(int i=1; i<=m; i++){
        int u, v;
        scanf("%d %d", &u, &v);
        AddEdge(u, v);
        AddEdge(v, u);
    }
    for(int i=0; i<q; i++){
        int u, v;
        scanf("%d %d", &u, &v);
        q[u].push_back(Query(v, i));
        q[v].push_back(Query(u, i));
    }
    Tarjan(s, -1);
    for(int i=0; i<q; i++) printf("%d\n", ans[i]);
    return 0;
}
View Code

 

 

 

倍增

  我们可以用倍增来在线求 LCA ,时间和空间复杂度分别是 O((n+q)logn) 和 O(nlogn) 。

  对于这个算法,我们从最暴力的算法开始:

    ①如果 aa 和 bb 深度不同,先把深度调浅,使他变得和浅的那个一样

    ②现在已经保证了 aa 和 bb 的深度一样,所以我们只要把两个一起一步一步往上移动,直到他们到达同一个节点,也就是他们的最近公共祖先了。

#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;
const int N=10000+5;
vector <int> son[N];
int T,n,depth[N],fa[N],in[N],a,b;
void dfs(int prev,int rt){
    depth[rt]=depth[prev]+1;
    fa[rt]=prev;
    for (int i=0;i<son[rt].size();i++)
        dfs(rt,son[rt][i]);
}
int LCA(int a,int b){
    if (depth[a]>depth[b])
        swap(a,b);
    while (depth[b]>depth[a])
        b=fa[b];
    while (a!=b)
        a=fa[a],b=fa[b];
    return a;
}
int main(){
    scanf("%d",&T);
    while (T--){
        scanf("%d",&n);
        for (int i=1;i<=n;i++)
            son[i].clear();
        memset(in,0,sizeof in);
        for (int i=1;i<n;i++){
            scanf("%d%d",&a,&b);
            son[a].push_back(b);
            in[b]++;
        }
        depth[0]=-1;
        int rt=0;
        for (int i=1;i<=n&&rt==0;i++)
            if (in[i]==0)
                rt=i;
        dfs(0,rt);
        scanf("%d%d",&a,&b);
        printf("%d\n",LCA(a,b));
    }
    return 0;
}
View Code

 

优化

 1. 把 aa 和 bb 移到同一深度(设 depthx 为节点 x 的深度),假设 depthadepthbdeptha≤depthb ,这个时候,之前预处理的 fafa 数组就派上用场了。从大到小枚举 kk ,如果 bb 向上跳 2k2k 得到的节点的深度 deptha≥deptha ,那么 bb 就往上跳。

  2.如果 a=ba=b ,那么显然 LCA 就是 aa。否则执行第 3 步。

  3.这一步的主要目的是 :分别找到最浅的 aa′ 和 bb′ ,并且 aba′≠b′ 。

    利用之前的那个性质,再利用倍增,从大到小枚举 kk ,如果对于当前的 kk , aa 和 bb 的第 2k个祖先不同,那么 aa 和 bb 都跳到其 2k 祖先的位置。LCA 就是 faa,0或者 fab,0

 

#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;
const int N=10000+5;
vector <int> son[N];
int T,n,depth[N],fa[N][20],in[N],a,b;
void dfs(int prev,int rt){
    depth[rt]=depth[prev]+1;
    fa[rt][0]=prev;
    for (int i=1;i<20;i++)
        fa[rt][i]=fa[fa[rt][i-1]][i-1];
    for (int i=0;i<son[rt].size();i++)
        dfs(rt,son[rt][i]);
}
int LCA(int x,int y){
    if (depth[x]<depth[y])
        swap(x,y);
    for (int i=19;i>=0;i--)
        if (depth[x]-(1<<i)>=depth[y])
            x=fa[x][i];
    if (x==y)
        return x;
    for (int i=19;i>=0;i--)
        if (fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int main(){
    scanf("%d",&T);
    while (T--){
        scanf("%d",&n);
        for (int i=1;i<=n;i++)
            son[i].clear();
        memset(in,0,sizeof in);
        for (int i=1;i<n;i++){
            scanf("%d%d",&a,&b);
            son[a].push_back(b);
            in[b]++;
        }
        depth[0]=-1;
        int rt=0;
        for (int i=1;i<=n&&rt==0;i++)
            if (in[i]==0)
                rt=i;
        dfs(0,rt);
        scanf("%d%d",&a,&b);
        printf("%d\n",LCA(a,b));
    }
    return 0;
}
View Code

 

RMQ

  现在来介绍一种 O(nlogn)O(nlog⁡n) 预处理,O(1)O(1) 在线查询的算法。

  RMQ 的意思大概是“区间最值查询”。顾名思义,用 RMQ 来求 LCA 是通过 RMQ 来实现的。

//CodeVS2370
#include <bits/stdc++.h>
#define time _____time
using namespace std;
const int N=50005;
struct Gragh{
    int cnt,y[N*2],z[N*2],nxt[N*2],fst[N];
    void clear(){
        cnt=0;
        memset(fst,0,sizeof fst);
    }
    void add(int a,int b,int c){
        y[++cnt]=b,z[cnt]=c,nxt[cnt]=fst[a],fst[a]=cnt;
    }
}g;
int n,m,depth[N],in[N],out[N],time;
int ST[N*2][20];
void dfs(int x,int pre){
    in[x]=++time;
    ST[time][0]=x;
    for (int i=g.fst[x];i;i=g.nxt[i])
        if (g.y[i]!=pre){
            depth[g.y[i]]=depth[x]+g.z[i];
            dfs(g.y[i],x);
            ST[++time][0]=x;
        }
    out[x]=time;
}
void Get_ST(int n){
    for (int i=1;i<=n;i++)
        for (int j=1;j<20;j++){
            ST[i][j]=ST[i][j-1];
            int v=i-(1<<(j-1));
            if (v>0&&depth[ST[v][j-1]]<depth[ST[i][j]])
                ST[i][j]=ST[v][j-1];
        }
}
int RMQ(int L,int R){
    int val=floor(log(R-L+1)/log(2));
    int x=ST[L+(1<<val)-1][val],y=ST[R][val];
    if (depth[x]<depth[y])
        return x;
    else
        return y;
}
int main(){
    scanf("%d",&n);
    for (int i=1,a,b,c;i<n;i++){
        scanf("%d%d%d",&a,&b,&c);
        a++,b++;
        g.add(a,b,c);
        g.add(b,a,c);
    }
    time=0;
    dfs(1,0);
    depth[0]=1000000;
    Get_ST(time);
    scanf("%d",&m);
    while (m--){
        int x,y;
        scanf("%d%d",&x,&y);
        if (in[x+1]>in[y+1])
            swap(x,y);
        int LCA=RMQ(in[x+1],in[y+1]);
        printf("%d\n",depth[x+1]+depth[y+1]-depth[LCA]*2);
    }
    return 0;
}
View Code

 

posted @ 2018-09-20 17:21  shuai_hui  阅读(464)  评论(0编辑  收藏  举报