【NOIP2017 提高组】 宝藏 题解 (洛谷P3959)
原题:
https://www.luogu.com.cn/problem/P3959
简化下题意,就是给定一张图,我们需要求出这张图的一棵有根生成树,满足生成树中各边与该边深度的乘积之和最小。
该题的暴力算法非常显然,穷举树的根,从已经在树中的点向外dfs即可。下面的代码是不加任何优化的裸暴力:
#include<bits/stdc++.h> using namespace std; #define il inline #define ll long long il int read() { int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar();} while(ch<='9'&&ch>='0') s=s*10+ch-'0',ch=getchar(); return s*w; } const int N=20; const int M=2010; struct edge{ int t,nex,v; } e[M]; int head[N],tot; void add(int x,int y,int v) { e[++tot].t=y; e[tot].nex=head[x]; e[tot].v=v; head[x]=tot; } ll ans=10000000010; int n,m; bool vis[N]; int s[20],top,dep[20]; void check() { printf("114514 "); } void dfs(ll sum,int num) { if(num==n) { ans=min(ans,sum); return; } for(int i=1;i<=top;i++) { int x=s[i]; for(int j=head[x];j;j=e[j].nex) { int y=e[j].t,z=e[j].v; if(!vis[y]) { vis[y]=1; s[++top]=y; dep[y]=dep[x]+1; dfs(sum+dep[x]*z,num+1); top--; vis[y]=0; dep[y]=0; } } } } int main() { n=read(); m=read(); for(int i=1;i<=m;i++) { int a=read(),b=read(),v=read(); add(a,b,v); add(b,a,v); } for(int i=1;i<=n;i++) { vis[i]=1; s[++top]=i; dep[i]=1; dfs(0,1); //check(); vis[i]=0; dep[i]=0; top--; } printf("%lld ",ans); return 0; }
这个暴力可以拿30pts。但如果我们在每次进入dfs时加上以下判断来剪枝,我们就能拿到60pts:
if(sum>=ans) return;
(很明白的剪枝,不讲了)顺便,这个代码在经过合理的剪枝后是可以AC本题的。详情参考 https://www.luogu.com.cn/blog/user54022/solution-p3959。生动的说明了爆搜剪枝的重要性()
下面来考虑正解。看到这个范围,我们第一反应就是状压DP。我们可以把要求的生成树看做一个点集,每个顶点是否已经在这个点集中用0与1表示。这样我们就得到了能够表示当前阶段的状态。该如何转移呢?我们可以另开一个数组dep,来记录每个点在当前状态的深度,仿照之前爆搜的思路,我们可以写出状态转移方程:f[s| (1 << j)]=min(f[s],dep[i]*dist[i][j]),i为s中的点,j为与i有边相连且不在s中的点。对于每一个状态,不断从已扩展的点向外搜索,我们会发现其实这个方程更适合用dfs来求解。所以我们得到如下代码:
#include<bits/stdc++.h> using namespace std; #define il inline #define ll long long il int read() { int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar();} while(ch<='9'&&ch>='0') s=s*10+ch-'0',ch=getchar(); return s*w; } const int N=13; const int M=2010; ll ans=10000000010,f[1 << N]; int dist[N][N],dep[N]; int n,m; bool vis[N]; void dfs(int s) { for(int i=1;i<=n;i++) { if(s & (1 << (i-1))) { for(int j=1;j<=n;j++) { if(!(s & (1 << (j-1)))) { if(f[s | (1<< (j-1))]>f[s]+(ll)dep[i]*(ll)dist[i][j]) { int kx=dep[j]; dep[j]=dep[i]+1; f[s | (1<< (j-1))]=f[s]+(ll)dep[i]*(ll)dist[i][j]; dfs(s | (1<< (j-1))); dep[j]=kx; } } } } } } int main() { n=read(); m=read(); memset(dist,0x3f,sizeof(dist)); for(int i=1;i<=m;i++) { int a=read(),b=read(),v=read(); dist[a][b]=min(v,dist[a][b]); dist[b][a]=min(v,dist[b][a]); } for(int i=1;i<=n;i++) { memset(f,0x3f,sizeof(f)); memset(dep,0x3f,sizeof(dep)); int root=1 << (i-1); f[root]=0; dep[i]=1; dfs(root); ans=min(f[(1<<n)-1],ans); } printf("%lld ",ans); return 0; }
这个代码可以获得满分,复杂度O(2n*n3)
完。

浙公网安备 33010602011771号