最小斯坦纳树 学习笔记

神秘点歌台(?

解决范围

一张没有负权边的图有 \(k\) 个关键点,求一个方案使得连接 \(k\) 个关键点的边权和最小。

流程

首先,这个方案一定是一个包含这 \(k\) 个关键点的生成树。因为如果有环去掉环上最大边还能减小这个方案的边权和。

\(f_{i,S}\) 表示 \(i\) 为根节点的子树上包含点集 \(S\) 的最小代价。
我们对 \(i\) 的度数进行分类讨论:

  • 当生成树中 \(i\) 的度数不为 \(1\) 时:
    我们有 \(f_{i,S}=\min(f_{i,T_1}+f_{i,T_2})\),其中 \(T_1\)\(S\) 的子集,\(T_2\)\(T_1\)\(S\) 中的绝对补集(应该是这么说的吧(?)。
    相当于是枚举自己的不同组合并相加,也就是若干个生成树中该顶点度数大于等于 \(1\) 的方案相加。
    枚举子集时间复杂度 \(O(3^k)\),故该部分时间复杂度 \(O(n\times 3^k)\)
  • 当生成树中 \(i\) 的度数为 \(1\) 时:
    我们有 \(f_{i,S}=\min(f_{j,S}+w_{i,j})\),其中 \(i,j\) 有一条边使得其相连。
    相当于是枚举每一个能到的顶点,并把自己加入进去。
    我们发现这个和最短路的松弛操作很像,于是我们把在上个讨论中被成功更新的点跑最短路(如果没更新的话再怎么跑也不会有影响啊ww)。
    \(S\) 的情况数一共有 \(2^k\) 种,用 SPFA 实现最短路的话时间复杂度为 \(O(nm2^k)\)

代码

洛谷模板 P6192 【模板】最小斯坦纳树为例。

#include<bits/stdc++.h>

#define pii pair<int,int> 
#define pll pair<long long,long long> 
#define ll long long
#define i128 __int128

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lb(x) ((x)&-(x))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define in4(a,b,c,d) a=read(), b=read(), c=read(), d=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 110

int n,m,k;
vector<pii >G[maxn];
int p[maxn];
int f[maxn][(1<<10)+114];

queue<int> q;
bool vis[maxn];

void spfa(int s) {
	while(!q.empty()) {
		int u=q.front(); q.pop();
		vis[u]=0;
		for(auto [v,w]:G[u]) if(f[u][s]+w<f[v][s]){
			f[v][s]=f[u][s]+w;
			if(!vis[v]) q.push(v),vis[v]=1;
		}
	}
}

void work() {
	mem(f,0x3f);
	in3(n,m,k);
	For(i,1,m) {
		int u,v,d;
		in3(u,v,d);
		G[u].push_back({v,d});
		G[v].push_back({u,d});
	}
	For(i,1,k) {
		p[i]=read();
		f[p[i]][(1<<(i-1))]=0;
	}
	For(S0,0,(1<<k)-1) {
		For(i,1,n) {
			for(int S1=S0;S1;S1=(S1-1)&S0) {
				f[i][S0]=min(f[i][S0],f[i][S1]+f[i][S0^S1]);
			}
			if(f[i][S0]<1e9) q.push(i),vis[i]=1;
		}
		spfa(S0);
	}
	cout<<f[p[1]][(1<<k)-1]<<'\n';
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}
posted @ 2025-04-29 20:21  coding_goat_qwq  阅读(26)  评论(0)    收藏  举报