洛谷 P3366 【模板】最小生成树 题解
这里介绍求最小生成树的最常见的两种解法。
Kruskal 算法
这是最好写最好想的一种写法,本质上是贪心+并查集。
我们把输入的边存起来,然后把边按边权为关键字从小到大排序,然后遍历这些边。我们贪心地尽量使用边权小的边,由于要生成的是树,意味着我们不能出现环,也就是两个点之间只能有一条路径。所以我们可以用并查集来维护这些点,当遍历到的边中的两个点已经在同一个连通块里面,则这条边不能使用,否则加入这条边,将两个点合并到一个连通块。
最后判断图是否连通,只需要看我们加入的边数是否达到 \(n-1\)。
证明:按边权从小到大排序,如果这条边可以选但是没有选,那么为了使这两个点能够联通,就需要在后面的边里选择直接或间接连接这两个点的边,又已知后面的边权更大,一定不优,所以排序后能选的时候就选。
时间复杂度 \(O(m\log m)\)。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=5100,M=2e5+100;
int n,m,fa[N],cnt,sum;
struct node
{
int x,y,z;
}a[M];
bool cmp(node a1,node a2)
{
return a1.z<a2.z;
}
int set_find(int dx)
{
return dx==fa[dx]?dx:fa[dx]=set_find(fa[dx]);
}
void set_merge(int dx,int dy)
{
int gx=set_find(dx),gy=set_find(dy);
if(gx!=gy) fa[gx]=gy;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) fa[i]=i;
for(int i=1;i<=m;i++) scanf("%d%d%d",&a[i].x,&a[i].y,&a[i].z);
sort(a+1,a+m+1,cmp);
for(int i=1;i<=m;i++)
{
if(set_find(a[i].x)!=set_find(a[i].y))
{
set_merge(a[i].x,a[i].y);
cnt++,sum+=a[i].z;
}
if(cnt==n-1) break;
}
cnt<n-1?printf("orz"):printf("%d",sum);
return 0;
}
Prim 算法
这种方法和最短路中的 Dijkstra 比较相似,本质上也是贪心。
我们可以把最终的最小生成树视为一个集合(下文称 mst 点集),一开始它只有结点 \(1\),后面每次遍历加入新的点,最终形成我们要的最小生成树。我们定义 \(dis_i\) 为与 mst 点集直接相连的结点 \(i\) 到 mst 点集的距离,如果结点 \(i\) 与 mst 点集并不直接相连,则 \(dis_i=+\infty\)。由于 \(1\) 号结点一开始就在集合里,应初始化 \(dis_1=0\)。每次遍历寻找 \(dis\) 最小的点加入,然后更新与这个点相邻的点的 \(dis\)。对于判连通,如果某一次在找新点时发现找不到,那么图是不连通的。
证明:与 Dijkstra 的证明类似,每次都要往 mst 点集加入一个新点,这个新点的作用是更新点集外的点到 mst 点集的距离,假设目前距离最短的点为 \(x\),但我们没有选择让 \(x\) 加入而是让另一个点 \(y\) 加入,那么因为 \(y\) 到点集的距离更长,再去更新 \(x\) 不可能使 \(x\) 更短,所以每次选择距离最短的点加入最优。
时间复杂度 \(O(n^2)\)。稠密图适合使用 Prim。但一般情况下 Kruskal 跑得比 Prim 快,即使 Prim 加上堆优化,常数也比 Kruskal 大,所以在求解最小生成树时,我们一般使用 Kruskal。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=5100,M=2e5+100,INF=0x3f3f3f3f;
int n,m,k,t[N],u,v,w,dis[N],mk,sum;
bool flag[N];
struct node
{
int id,last,val;
}a[M*2];
void add(int a1,int a2,int a3)
{
a[++k].id=a2;
a[k].last=t[a1];
a[k].val=a3;
t[a1]=k;
}
int main()
{
scanf("%d%d",&n,&m);
while(m--)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w),add(v,u,w);
}
memset(dis,0x3f,sizeof dis);
dis[1]=0;
for(int i=1;i<=n;i++)
{
mk=0;
for(int j=1;j<=n;j++)
{
if(!flag[j]&&dis[j]<dis[mk]) mk=j;
}
if(dis[mk]==INF)//图不连通
{
printf("orz");
return 0;
}
//取结点mk
sum+=dis[mk];
flag[mk]=true;
for(int j=t[mk];j;j=a[j].last)//更新与mk相连的点离mst点集的最短距离
{
if(!flag[a[j].id]&&a[j].val<dis[a[j].id]) dis[a[j].id]=a[j].val;
}
}
printf("%d",sum);
return 0;
}

浙公网安备 33010602011771号