CF125E MST company (凸优化+MST)

qwq自闭的一个题

我来修锅辣!!!!!!

这篇题解!可以\(hack\)全网大部分的做法!!!

首先,我们可以把原图中的边,分成两类,一类是与\(1\)相连,另一类是不与\(1\)相连。

原题就转化成选择\(k\)条关键边的\(MST\)

那么我们可以按照tree I 那个题的思路来考虑这个题。

由于是\(MST\),所以函数满足下凸,那么对于这种恰好选\(k\)个的问题,我们可以直接凸优化。

\(erf\)一个值,然后把所有与1相连的边都加上这个值。

对于相等权值的来说,我们优先把不与1相连的边排在前面。

那么这种情况我们二分出来的那个\(mid\)
满足两个条件
1.满足下界,也就是说,能选到<=k条边的最大的\(mid\)
2.在这种情况下,\(mst\)上面的与1相连的边,是“必须出现在MST”上的边。

那么我们考虑该怎么统计方案。

首先,对于那些一定要出现在\(MST\)上的关键边,我们先把他们加入\(ans\)(只加入与1相连的边),然后对于剩下的边,把与1相连的边加上\(mid\)后,进行\(MST\),如果已经选够了\(k\)条边,那么对于剩下的关键边就直接跳过。

这样做正确的原因是,我们首先把必须要出现暗在\(MST\)上的边加入了\(ans\),然后对于剩下的边,只会分成两种,可能在MST上,或者是不可能在MST上,那剩下的部分直接用贪心的思路来做\(MST\)就是没错的。

而网上大多数题解是错的,qwq
所以这个问题困扰了我很久

#include<bits/stdc++.h>
#define pb push_back
#define mk make_pair
#define ll long long
#define int long long
using namespace std;

inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
}
const int maxn = 3e5+1e2;
struct Edge{
    int u,v;
    ll w;
    int tag,num;
}; 
Edge e[maxn];
int fa[maxn];
int n,m,k;
int val;
int find(int x)
{
    if (fa[x]!=x) fa[x]=find(fa[x]);
    return fa[x];
}

bool cmp(Edge a,Edge b)
{
    if(a.w==b.w) return a.tag>b.tag;
    return a.w<b.w;
}

bool cmp1(Edge a,Edge b)
{
	if (a.w==b.w) return a.tag<b.tag;
	return a.w<b.w;
}

int solve()
{
    for (int i=1;i<=n;i++) fa[i]=i;
    sort(e+1,e+1+m,cmp1);
    int tot=0;
    for (int i=1;i<=m;i++)
    {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        fa[f1]=fa[f2];
        tot+=e[i].tag; 
    }
    return tot;
}

vector<int> v;

int lyf[maxn*2];

signed main()
{
   n=read(),m=read(),k=read();
   ll l = -1e10,r=1e10;
   int ymh=0;
   for (int i=1;i<=m;i++)
   {
   	   e[i].u=read(),e[i].v=read(),e[i].w=read(),e[i].num=i;
   	   if (e[i].u==1 || e[i].v==1) e[i].tag=1;
   	   if (e[i].tag==1) ymh++;
   }
   ll ans=0;
   //cout<<ymh<<endl;
   while(r>=l)
   {
   	  ll mid = (l+r)/2;
   	  for (int i=1;i<=m;i++) if(e[i].tag) e[i].w+=mid;
   	  int tmp = solve();
   	  //cerr<<"*"<<mid<<" "<<tmp<<endl;
      if (tmp<=k) r=mid-1,ans=mid;
      else l=mid+1;
      for (int i=1;i<=m;i++) if (e[i].tag) e[i].w-=mid;
   }
   //cerr<<ans<<endl;
   int ptx=0;
   for (int i=1;i<=m;i++) if (e[i].tag) e[i].w+=ans;
   for (int i=1;i<=n;i++) fa[i]=i;
   sort(e+1,e+1+m,cmp1);
   for (int i=1;i<=m;i++)
   {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        fa[f1]=fa[f2];
        if (e[i].tag==1) lyf[e[i].num]=1;
   }
   for (int i=1;i<=n;i++) fa[i]=i;
   for (int i=1;i<=m;i++)
   {
   	  if (lyf[e[i].num])
	  {
   	  	int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        fa[f1]=fa[f2];
        ptx+=e[i].tag;
		v.pb(e[i].num);
	  }
   }
   //cout<<ptx<<endl;
   if (ptx>k) 
   {
   	 cout<<-1;
   	 return 0;
   }
   //for (int i=1;i<=n;i++) fa[i]=i;
   //cerr<<ptx<<endl;
   sort(e+1,e+1+m,cmp);
   for (int i=1;i<=m;i++)
   {
   	 int f1 = find(e[i].u);
     int f2 = find(e[i].v);
     if (f1==f2) continue;
     if (ptx==k && e[i].tag==1) continue;
     fa[f1]=fa[f2];
     v.pb(e[i].num);
     if (e[i].tag==1) ptx++;
   }
   //cerr<<ptx<<endl;
   if(ptx!=k || v.size()!=n-1) 
   {
     cout<<-1;
     return 0;
   }
   cout<<n-1<<endl;
   for (int i=0;i<v.size();i++) cout<<v[i]<<" ";
   return 0;
}
//final
posted @ 2019-01-01 09:38  y_immortal  阅读(242)  评论(0编辑  收藏  举报