AcWing356 次小生成树(lca)

通过lca来计算两个点之间的最大最小值,这样比暴力要快

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=3e5+10;
const int inf=0x3f3f3f3f;
int h[N],ne[N],e[N],w[N],idx;
int p[N];
int depth[N];
int f[N][20];
int d1[N][20],d2[N][20];
int n,m;
int dis[N];
void add(int a,int b,int c){
    e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}
struct node{
    int a,b,c;
    int f;
    bool operator <(const node &t) const{
        return c<t.c;
    }
}s[N];
int find(int x){
    if(x!=p[x]){
        p[x]=find(p[x]);
    }
    return p[x];
}
ll kruscal(){
    sort(s+1,s+1+m);
    ll res=0;
    for(int i=1;i<=m;i++){
        int pa=find(s[i].a),pb=find(s[i].b);
        if(pa!=pb){
           p[pa]=pb;
           res+=s[i].c;
           s[i].f=1;
        }
    }
    return res;
}
void build(){
    int i;
    memset(h,-1,sizeof h);
    for(i=1;i<=m;i++){
        if(s[i].f){
            add(s[i].a,s[i].b,s[i].c);
            add(s[i].b,s[i].a,s[i].c);
        }
    }
}

void bfs(){
    memset(depth,0x3f,sizeof depth);
    int i;
    depth[0]=0;
    depth[1]=1;
    queue<int> q;
    q.push(1);
    while(q.size()){
        int t=q.front();
        q.pop();
        for(i=h[t];i!=-1;i=ne[i]){
            int j=e[i];
            if(depth[j]>depth[t]+1){
                depth[j]=depth[t]+1;
                q.push(j);
                f[j][0]=t;
                int k;
                d1[j][0]=w[i],d2[j][0]=-inf;
                for(k=1;k<=18;k++){
                    int pa=f[j][k-1];
                    int distance[4]={d1[j][k-1],d2[j][k-1],d1[pa][k-1],d2[pa][k-1]};
                    f[j][k]=f[pa][k-1];
                    d1[j][k] =d2[j][k]=-inf;
                    for(int u=0;u<4;u++){
                        if(distance[u]>d1[j][k]){
                            d2[j][k]=d1[j][k];
                            d1[j][k]=distance[u];
                        }
                        else if(distance[u]!=d1[j][k]&&distance[u]>d2[j][k]){
                            d2[j][k]=distance[u];
                        }
                    }
                }
            }
        }
    }
}
int lca(int a,int b,int c){
    if(depth[a]<depth[b])
        swap(a,b);
    int i;
    int cnt=0;
    for(i=18;i>=0;i--){
        if(depth[f[a][i]]>=depth[b]){
           dis[cnt++]=d1[a][i];
            dis[cnt++]=d2[a][i];
            a=f[a][i]; 
        }
        
    }
    if(a!=b){
        for(i=18;i>=0;i--){
            if(f[a][i]!=f[b][i]){
                dis[cnt++]=d1[a][i];
                dis[cnt++]=d2[a][i];
                dis[cnt++]=d1[b][i];
                dis[cnt++]=d2[b][i];
                a=f[a][i];
                b=f[b][i];
            }
        }
        dis[cnt++]=d1[a][0];
        dis[cnt++]=d1[b][0];
    }
    int df=-inf,ds=-inf;
    for(i=0;i<cnt;i++){
        if(dis[i]>df){
            ds=df,df=dis[i];
        }
        else if(dis[i]!=df&&dis[i]>ds){
            ds=dis[i];
        }
    }
    if(c>df)
    return c-df;
    else 
    return c-ds;
    
}
int main(){
    int i;
    cin>>n>>m;
    for(i=0;i<=n;i++)
    p[i]=i;
    for(i=1;i<=m;i++){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        s[i]=node{a,b,c};
    }
    ll sum=kruscal();
    build();
    bfs();
    ll res=1e18;
    for(i=1;i<=m;i++){
        if(!s[i].f){
           int a=s[i].a,b=s[i].b,c=s[i].c;
           res=min(res,sum+lca(a,b,c));
        }
    }
    cout<<res<<endl;
}
View Code

 

posted @ 2020-05-20 12:28  朝暮不思  阅读(189)  评论(0)    收藏  举报