算法总结—最近公共祖先1

最近公共祖先定义

树上的两个点 \(a\)\(b\),它们的祖先中相同且深度最深的节点就是 \(a\)\(b\) 的最近公共祖。

求法

1.暴力

\(a\)\(b\) 的所有祖先都标记上,在从下往上找第一个两个都标记了的点。单次询问的时间复杂度 \(\mathcal{O}(dep_a+dep_b)\)

1.5.暴力优化

  • 如果 \(a\) 的深度 \(dep_a\)\(b\) 的深度 \(dep_b\) 大,让 \(a\) 不停的往上跳,直至 \(dep_a=dep_b\)

  • 如果 \(a\) 的深度 \(dep_a\)\(b\) 的深度 \(dep_b\) 小,让 \(b\) 不停的往上跳,直至 \(dep_a=dep_b\)

  • 如果 \(dep_a=dep_b\)\(a\)\(b\) 同时往上跳,直至 \(a=b\)

单次询问的时间复杂度 \(\mathcal{O}(dep_a+dep_b-2dep_{\text{LCA}})\)

2.倍增

预处理

预处理出 每个结点 \(i\) 的深度 \(dep_i\),和每个节点 \(i\) 往上走 \(2_j\) 步所到达的点 \(dp_{i,j}\)

\(dep\) 数组和好处理 \(dep_u=dep_{fa}+1\)。那么 \(dp\) 数组怎么转移呢,首先一个节点 \(i\) 往上跳 \(2_j\) 次等于 \(i\) 节点往上跳 \(2_{j-1}\) 所到达的节点往上再跳 \(2_{j-1}\) 次,所以 \(dp_{i,j}=dp_{dp_{i,j-1},j-1}\)\(dp_{i,0}=fa\)

预处理这一步的时间复杂度是 \(\mathcal{O}(n\log n)\)

查询

为了方便操作,可以先让 \(a\) 的深度更小:

if(dep[a]>dep[b]){
    swap(a,b);
}

再让 \(b\) 跳到和 \(a\) 同一深度,但不是一个节点一个节点去跳,而是利用 \(dp\) 数组跳,也就是对 \(dep_b-dep_a\) 进行二进制拆分:

for(int i=20;i>=0;i--){
    if(dep[dp[b][i]]>=dep[a]){
        y=dp[b][i];
    }
}

再次利用 \(dp\) 数组,让 \(a\)\(b\) 同时往上跳,但不可以重合,因为重合了只能保证是公共祖先,不能保证深度最深:

for(int i=20;i>=0;i--){
    if(dp[a][i]!=dp[b][i]){
        a=dp[a][i];
        b=dp[b][i];
    }
}

答案是 \(dp_{a,0}\)

单次询问的时间复杂度 \(\mathcal{O}(\log n)\)

【模板】最近公共祖先(LCA)

代码

#include<bits/stdc++.h>
using namespace std;
vector<int>G[500005];
int dep[500005];
int dp[500005][21];
void dfs(int u,int fa){
    dp[u][0]=fa;
    dep[u]=dep[fa]+1;
    for(int i=1;(1<<i)<=dep[u];i++){
        dp[u][i]=dp[dp[u][i-1]][i-1];
    }
    for(auto i:G[u]){
        int v=i;
        if(v==fa){
            continue;
        }
        dfs(v,u);
    }
    return;
}
int query(int x,int y){
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    for(int i=20;i>=0;i--){
        if(dep[dp[y][i]]>=dep[x]){
            y=dp[y][i];
        }
    }
    if(x==y){
        return x;
    }
    for(int i=20;i>=0;i--){
        if(dp[x][i]!=dp[y][i]){
            x=dp[x][i];
            y=dp[y][i];
        }
    }
    return dp[x][0];
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n,q,s;
    cin>>n>>q>>s;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(s,0);
    while(q--){
        int x,y;
        cin>>x>>y;
        cout<<query(x,y)<<"\n";
    }
    return 0;
}

点的距离

先求出 \(x\)\(y\) 的最近公共祖先 \(LCA\),答案就是 \(x\)\(LCA\) 的距离加上 \(y\)\(LCA\) 的距离 \(dep_x-dep_{LCA}+dep_y-dep_{LCA}\)

代码

#include<bits/stdc++.h>
using namespace std;
vector<int>G[100005];
int dep[100005];
int dp[100005][21];
void dfs(int u,int fa){
    dp[u][0]=fa;
    dep[u]=dep[fa]+1;
    for(int i=1;(1<<i)<=dep[u];i++){
        dp[u][i]=dp[dp[u][i-1]][i-1];
    }
    for(auto i:G[u]){
        int v=i;
        if(v==fa){
            continue;
        }
        dfs(v,u);
    }
    return;
}
int query(int x,int y){
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    for(int i=20;i>=0;i--){
        if(dep[dp[y][i]]>=dep[x]){
            y=dp[y][i];
        }
    }
    if(x==y){
        return x;
    }
    for(int i=20;i>=0;i--){
        if(dp[x][i]!=dp[y][i]){
            x=dp[x][i];
            y=dp[y][i];
        }
    }
    return dp[x][0];
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1,0);
    int q;
    cin>>q;
    while(q--){
        int x,y;
        cin>>x>>y;
        int lca=query(x,y);
        cout<<dep[x]-dep[lca]+dep[y]-dep[lca]<<"\n";
    }
    return 0;
}

Dis

先求出每个节点 \(i\) 到根节点的距离 \(dis_i\),然后求出 \(x\)\(y\) 的最近公共祖先 \(LCA\),答案就是 \(x\)\(LCA\) 的距离加上 \(y\)\(LCA\) 的距离 \(dis_x-dis_{LCA}+dis_y-dis_{LCA}\)

代码

#include<bits/stdc++.h>
using namespace std;
struct edge{
    int v,w;
};
vector<edge>G[10005];
int dep[10005];
int dp[10005][21];
int dis[10005];
void dfs(int u,int fa){
    dp[u][0]=fa;
    dep[u]=dep[fa]+1;
    for(int i=1;(1<<i)<=dep[u];i++){
        dp[u][i]=dp[dp[u][i-1]][i-1];
    }
    for(auto i:G[u]){
        int v=i.v;
        int w=i.w;
        if(v==fa){
            continue;
        }
        dis[v]=dis[u]+w;
        dfs(v,u);
    }
    return;
}
int query(int x,int y){
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    for(int i=20;i>=0;i--){
        if(dep[dp[y][i]]>=dep[x]){
            y=dp[y][i];
        }
    }
    if(x==y){
        return x;
    }
    for(int i=20;i>=0;i--){
        if(dp[x][i]!=dp[y][i]){
            x=dp[x][i];
            y=dp[y][i];
        }
    }
    return dp[x][0];
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n,q;
    cin>>n>>q;
    for(int i=1;i<n;i++){
        int u,v,w;
        cin>>u>>v>>w;
        G[u].push_back({v,w});
        G[v].push_back({u,w});
    }
    dfs(1,0);
    while(q--){
        int x,y;
        cin>>x>>y;
        int lca=query(x,y);
        cout<<dis[x]-dis[lca]+dis[y]-dis[lca]<<"\n";
    }
    return 0;
}

祖孙询问

先求出 \(x\)\(y\) 的最近公共祖先 \(LCA\)

  • 如果 \(x=LCA\),则 \(x\)\(y\) 的祖先,输出 1

  • 如果 \(y=LCA\),则 \(y\)\(x\) 的祖先,输出 2

  • 否则 \(x\)\(y\) 没有任何关系输出 0

代码

#include<bits/stdc++.h>
using namespace std;
vector<int>G[100005];
int dep[100005];
int dp[100005][21];
void dfs(int u,int fa){
    dp[u][0]=fa;
    dep[u]=dep[fa]+1;
    for(int i=1;(1<<i)<=dep[u];i++){
        dp[u][i]=dp[dp[u][i-1]][i-1];
    }
    for(auto i:G[u]){
        int v=i;
        if(v==fa){
            continue;
        }
        dfs(v,u);
    }
    return;
}
int query(int x,int y){
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    for(int i=20;i>=0;i--){
        if(dep[dp[y][i]]>=dep[x]){
            y=dp[y][i];
        }
    }
    if(x==y){
        return x;
    }
    for(int i=20;i>=0;i--){
        if(dp[x][i]!=dp[y][i]){
            x=dp[x][i];
            y=dp[y][i];
        }
    }
    return dp[x][0];
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin>>n;
    int rt=0;
    for(int i=1;i<=n;i++){
        int u,v;
        cin>>u>>v;
        if(v==-1){
            rt=u;
            continue;
        }
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(rt,0);
    int q;
    cin>>q;
    while(q--){
        int x,y;
        cin>>x>>y;
        int lca=query(x,y);
        if(lca==x){
            cout<<1<<"\n";
        }else if(lca==y){
            cout<<2<<"\n";
        }else{
            cout<<0<<"\n";
        }
    }
    return 0;
} 
posted @ 2025-03-24 12:46  LRRabcd  阅读(23)  评论(0)    收藏  举报