[BZOJ4987] Tree
题目
从前有棵树。
题解
先考虑几个显而易见的性质:
1.选出的点一定是相邻的(不然距离会更大)
2.对于选出的点,如果从ak再走回a1,那么就相当于每条边经过了两次
由于题目没有包含dis(ak,a1),因此就相当于选出的点中的一条链可以只经过一次,其余的需要经过两次。
(如果包含的话就是个很水的树形背包)
就是说,在选点的同时还要选出只计算一次的那条链。
涉及到路径的一般考虑路径拼接的dp方式,
即,设一个状态用来记录当前已经决定了多少个路径的端点,即0,1,2三种取值
又因为要取k个,需要树形背包
结合起来就是:
设$f[i][j][k]$表示以i为根的子树中选择点i,共选出j条边,且包含的链端点数目为k的最小代价。
接下来讨论转移方程
对于每一个节点id,枚举每一个儿子t,在dfs完后合并
cost[i]表示id到t的距离
0的只能从0的来
$dp[id][j+k][0]=min(dp[id][j][0]+dp[t][k][0]+cost[i]*2)$
1的可以分类讨论端点是在已经枚举的子树内还是新加的子树内
$dp[id][j+k][1]=min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i])$
2的有(0,2),(1,1),(2,0)三种情况
$dp[id][j+k][2]=min($
$dp[id][j][2]+dp[t][k][0]+cost[i]*2,$
$dp[id][j][1]+dp[t][k][1]+cost[i],$
$dp[id][j][0]+dp[t][k][2]+cost[i]*2)$
注意要倒序枚举,因为每个物品只能选一个
时间复杂度
咋一看像是$n^3$的,其实是$n^2$
对于节点u,设其每个儿子子树大小为$a_k$,总大小为A
复杂度就是$a_1+a_2*a_1+a_3*(a_1+a_2)+a_4*(a_1+a_2+a_3)$
这样不好分析,我们将上式×2
$a_1+a1*(a_2+a_3+a_4)+a_2*(a_1+a_3+a_4)+a_3*(a_1+a_2+a_4)+a_4*(a_1+a_2+a_3)$
$a_1+\sum_{i=1}^{k} a_i*A-a_i^2$
忽略第一项,剩下的可以跟父亲抵消掉
最后只剩下root的A
也就是$n^2$
代码
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
#define N 10000
int head[N],cost[N],to[N],nxt[N],cnt,dp[3010][3010][3],n,K,ind,sz[N];
void connect(int a,int b,int c)
{
to[++cnt]=b,cost[cnt]=c,nxt[cnt]=head[a],head[a]=cnt;
to[++cnt]=a,cost[cnt]=c,nxt[cnt]=head[b],head[b]=cnt;
}
void dfs(int id,int fa)
{
sz[id]=1;
dp[id][1][1]=dp[id][1][0]=0;
//cout<<id<<" "<<fa<<endl;
for(int i=head[id];i;i=nxt[i])
{
int t=to[i];
if(t==fa) continue;
dfs(t,id);
for(int j=sz[id];j>=0;j--)
{
for(int k=1;k<=sz[t];k++)
{
dp[id][j+k][0]=min(dp[id][j+k][0],dp[id][j][0]+dp[t][k][0]+cost[i]*2);
dp[id][j+k][1]=min(dp[id][j+k][1],min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i]));
int minn=min(min(dp[id][j][2]+dp[t][k][0]+cost[i]*2,dp[id][j][1]+dp[t][k][1]+cost[i]),dp[id][j][0]+dp[t][k][2]+cost[i]*2);
dp[id][j+k][2]=min(dp[id][j+k][2],minn);
}
}
sz[id]+=sz[t];
}
//cout<<id<<":\n";
// for(int i=1;i<=sz[id];i++) printf("%d %d %d\n",dp[id][i][0],dp[id][i][1],dp[id][i][2]);
}
int main()
{
cin>>n>>K;
for(int i=1;i<n;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
connect(a,b,c);
}
memset(dp,0x3f,sizeof(dp));
dfs(1,0);
int ans=0x7fffffff;
for(int i=1;i<=n;i++) ans=min(ans,dp[i][K][2]);
cout<<ans;
}

浙公网安备 33010602011771号