严格最小生成树

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring> 
using namespace std;
const int N=500005;
typedef long long ll;
ll ans;
int n,m,cnt,tot;
int mm=0x7f7f7f;
int head[N],f[N][30],dep[N],ff[N],d1[N][30],d2[N][30];
struct node{
	int to,next,w;
}edge[N<<2];
struct Node{
	int u,v,w,vis;
}a[N];
void add(int u,int v,int w){
	edge[tot].to=v;
	edge[tot].next=head[u];
	edge[tot].w=w;
	head[u]=tot++;
}
int find(int x){
	if(ff[x]==x) return x;
	return ff[x]=find(ff[x]);
}
bool cmp(Node aa,Node bb){
	return aa.w<bb.w;
}
void dfs(int u,int fa){
	for(int i=1;(1<<i)<=dep[u];i++){
		f[u][i]=f[f[u][i-1]][i-1];
		d1[u][i]=max(d1[u][i-1],d1[f[u][i-1]][i-1]);
		if(d1[u][i-1]==d1[f[u][i-1]][i-1]) d2[u][i]=max(d2[u][i-1],d2[f[u][i-1]][i-1]);
		else{
			d2[u][i]=min(d1[u][i-1],d1[f[u][i-1]][i-1]);
			d2[u][i]=max(d2[u][i],max(d2[u][i-1],d2[f[u][i-1]][i-1]));
		}
	}
	for(int i=head[u];i!=-1;i=edge[i].next){
		int v=edge[i].to;
		if(v==fa) continue;
		f[v][0]=u;
		dep[v]=dep[u]+1;
		d1[v][0]=edge[i].w;
//		d2[v][0]=edge[i].w;
		dfs(v,u);
	}
}
int lca(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	if(dep[u]!=dep[v]){
		for(int j=20;j>=0;j--){
			if(dep[f[u][j]]>=dep[v]) u=f[u][j];
		}
	}
	if(u==v) return u;
	for(int j=20;j>=0;j--){
		if(f[u][j]!=f[v][j]){
			u=f[u][j];
			v=f[v][j];
		}
	}
	return f[u][0];
}
void cal(int u,int fa,int w){
	int mx1=0,mx2=0;
	int t=dep[u]-dep[fa];
	for(int i=0;i<=20;i++){
		if(t&(1<<i)){
			if(d1[u][i]>mx1){
				mx2=mx1;
				mx1=d1[u][i];
			}
			mx2=max(mx2,d2[u][i]);
			u=f[u][i];
		}
	}
	if(mx1!=w) mm=min(mm,w-mx1);
	else mm=min(mm,w-mx2);
} 
void solve(int id,int w){
	int x=a[id].u,y=a[id].v;
	int fa=lca(x,y);
	cal(x,fa,w); cal(y,fa,w);
}
int main(){
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&a[i].u,&a[i].v,&a[i].w);
	}
	for(int i=1;i<=n;i++) ff[i]=i;
	sort(a+1,a+m+1,cmp);
	for(int i=1;i<=m;i++){
		int r1=find(a[i].u),r2=find(a[i].v);
		if(r1==r2) continue;
		ff[r1]=r2;
		ans+=a[i].w;
		add(a[i].u,a[i].v,a[i].w);
		add(a[i].v,a[i].u,a[i].w);
		a[i].vis=1;
		cnt++;
		if(cnt==n-1) break;
	}
	dfs(1,0);
	for(int i=1;i<=m;i++){
		if(!a[i].vis){
			solve(i,a[i].w);
		}
	}
	printf("%lld\n",ans+mm);
	return 0;
}
posted @ 2021-05-30 20:15  dfydn  阅读(35)  评论(0编辑  收藏  举报