【BZOJ4033】【HAOI2015】树上染色 树形DP

题目描述

  给你一棵\(n\)个点的树,你要把其中\(k\)个点染成黑色,剩下\(n-k\)个点染成白色。要求黑点两两之间的距离加上白点两两之间距离的和最大。问你最大的和是多少。

  \(n\leq 2000\)

题解

  我们考虑树形DP。

  设\(f_{i,j}\)为以\(i\)为根的子树,染了\(j\)个黑点的最大收益。

  若一条边的一端有\(s_1\)个点,选了\(j_1\)个黑点,另一端有\(s_2\)个点,选了\(j_2\)个黑点,那么这条边的贡献就是

\[w\times(j_1\times j_2+(s_1-j_1)\times (s_2-j_2)) \]

  于是我们就可以从\(f_{x,i},f_{v,j}\)转移到\(f_{x,i+j}\)

  表面上看是\(O(n^3)\)的,因为要枚举选了几个黑点,实际上是\(O(n^2)\)的。

  转移可以看成两边各选一个点,这个点\(x\)就是两边的点的lca。因为总共有\(O(n^2)\)个lca,所以就是\(O(n^2)\)的。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
ll upmin(ll &a,ll b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(ll &a,ll b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
struct graph
{
	int v[5010];
	int w[5010];
	int t[5010];
	int h[2010];
	int n;
	graph()
	{
		memset(h,0,sizeof h);
		n=0;
	}
	void add(int x,int y,int z)
	{
		n++;
		v[n]=y;
		w[n]=z;
		t[n]=h[x];
		h[x]=n;
	}
};
graph g;
ll f[2010][2010];
ll h[2010];
int s[2010];
int n,k;
void dfs(int x,int fa)
{
	s[x]=1;
	f[x][0]=f[x][1]=0;
	int i,v,j,l;
	for(i=g.h[x];i;i=g.t[i])
		if(g.v[i]!=fa)
		{
			v=g.v[i];
			dfs(v,x);
			memset(h,0xc0,sizeof h);
			for(j=0;j<=s[x]&&j<=k;j++)
				for(l=0;l<=s[v]&&j+l<=k;l++)
					if(n-k-s[v]+l>=0)
						upmax(h[j+l],f[x][j]+f[v][l]+ll(g.w[i])*(ll(k-l)*l+ll(n-k-s[v]+l)*(s[v]-l)));
			s[x]+=s[v];
			for(j=0;j<=s[x]&&j<=k;j++)
				f[x][j]=h[j];
		}
}
int main()
{
	scanf("%d%d",&n,&k);
	int i,x,y,z;
	for(i=1;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		g.add(x,y,z);
		g.add(y,x,z);
	}
	memset(f,0xc0,sizeof f);
	dfs(1,0);
	printf("%lld\n",f[1][k]);
	return 0;
}
posted @ 2018-03-06 10:59  ywwyww  阅读(150)  评论(0编辑  收藏  举报