gym102222 G. Factories
gym102222 G. Factories
题目大意:
给一棵n个点的树,选m个点,这m个点只能在叶子节点上,问着m个点中两两之间到达其余各点的距离和最小值是多少
题解:
任意两点的树上距离和问题应从边的贡献角度考虑。
树形dp
设 f[u][i] 表示以 u 为根的子树中,选了 i 个叶子节点的最优解,状态转移方程为:
f[u][i+j]=min(f[u][i+j],f[u][i]+f[v][j]+w∗j∗(j−m))
其中所加项为子节点和父节点之间的边的贡献
/* [HAOI2015]树上染色 的弱化版 */ #include<bits/stdc++.h> using namespace std; typedef long long ll; int T,n,m,cas; int main(){ for(scanf("%d",&T);T--;){ scanf("%d%d",&n,&m); vector<vector<pair<int,ll> > >e(n+1); for(int i=1,x,y,z;i<n;i++){ scanf("%d%d%d",&x,&y,&z); e[x].emplace_back(y,z); e[y].emplace_back(x,z); } if(m==1){printf("Case #%d: 0\n",++cas);continue;} if(n==2){printf("Case #%d: %lld\n",++cas,e[1][0].second);continue;} vector<vector<ll> > f(n+1,vector<ll>(m+1,1e14));vector<int>siz(n+1,0); function<void(int,int)>dfs=[&](int u,int fa)->void{ bool leaf=1; f[u][0]=0; for(auto t:e[u]){ int v=t.first;ll w=t.second; if(v==fa) continue; leaf=0; dfs(v,u); for(int i=min(siz[u],m);~i;i--){ for(int j=min(siz[v],m-i);~j;j--){ f[u][i+j]=min(f[u][i+j],f[u][i]+f[v][j]+w*(m-j)*j); } } siz[u]+=siz[v]; } if(leaf){f[u][1]=0;siz[u]=1;} }; int rt=0; for(int i=1;i<=n;i++) if(e[i].size()>1){rt=i;break;} dfs(rt,0); printf("Case #%d: %lld\n",++cas,f[rt][m]); } return 0; }
第二版
#include<bits/stdc++.h> using namespace std; const int N=1e5+5; const int M=N<<1; const int Kn=105; typedef long long ll; int T,n,k,cas; int tot,to[M],nxt[M],head[N],ind[N],siz[N];ll val[M];bool vis[N]; ll f[N][Kn]; inline void add(int x,int y,ll z){ ind[x]++;to[++tot]=y;val[tot]=z;nxt[tot]=head[x];head[x]=tot; } void dfs(int u,int fa){ for(int l=head[u];l;l=nxt[l]){ int v=to[l];ll w=val[l]; if(v==fa) continue; dfs(v,u); siz[u]+=siz[v]; for(int i=min(siz[u],k);i;i--){ for(int j=1,re=min(siz[v],i);j<=re;j++){ f[u][i]=min(f[u][i],f[u][i-j]+f[v][j]+w*(k-j)*j); } } } }; void init_dp(){ for(int i=1;i<=n;i++){ f[i][0]=0; for(int j=1;j<=k;j++) f[i][j]=1e17; if(ind[i]==1) siz[i]=1,f[i][1]=0; } } inline void Clear(){ tot=0; memset(ind,0,sizeof ind); memset(siz,0,sizeof siz); memset(head,0,sizeof head); } int main(){ for(scanf("%d",&T);T--;Clear()){ scanf("%d%d",&n,&k); for(int i=1,x,y,z;i<n;i++){ scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } int rt=1; for(int i=1;i<=n;i++) if(ind[i]>1){rt=i;break;} init_dp(); dfs(rt,0); printf("Case #%d: %lld\n",++cas,f[rt][k]); } return 0; }