【BZOJ4042】【CERC2014】parades 状压DP

题目大意

  给你一棵\(n\)个点的树和\(m\)条路径要求你找出最多的路径,使得这些路径不共边。特别的,每个点的度数\(\leq 10\)

  \(n\leq 1000,m\leq \frac{n(n-1)}{2}\)

题解

  先对于每个点把相邻的边编号。

  考虑状压DP。

  设\(f_{i,j}\)为以\(i\)个点的子树内,状态为\(j\)的边的子树内的边也没有选(这些边也没选),所选的最多路径数。

  然后对于每个点有很多种选法,分为两类:1.某条边不选,选对应的子树;2.选\(1\)~\(2\)条边和对应的路径,那么这些路径经过的边都不能选。

  然后直接状压DP。

  对于每个点来说,总共有最多\(O(d^2)\)种转移。考虑每个儿子的贡献,就是\(O(d)\)

  时间复杂度:\(O(n^2+nd2^d)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
struct list
{
	int t[1000010];
	pii v[1000010];
	int h[1010];
	int n;
	void clear()
	{
		memset(h,0,sizeof h);
		n=0;
	}
	void add(int x,pii y)
	{
		n++;
		v[n]=y;
		t[n]=h[x];
		h[x]=n;
	}
};
list l;
int f[1010][1<<10];
int g[1010];
int c[1010][20];
int d[1010];
int ns[12][12];
int e[1010];
void dfs2(int x,int fa,int t,int s)
{
	int fc;
	int i;
	for(i=1;i<=d[x];i++)
		if(c[x][i]==fa)
			fc=i;
	g[x]=s+f[x][((1<<d[x])-1)^(1<<(fc-1))];
	e[x]=t;
	for(i=1;i<=d[x];i++)
		if(c[x][i]!=fa)
			dfs2(c[x][i],x,t,s+f[x][((1<<d[x])-1)^(1<<(fc-1))^(1<<(i-1))]);
}
int dd[1010];
int ff[1010];
int lca[1010][1010];
void dfs(int x,int fa,int dep)
{
	ff[x]=fa;
	dd[x]=dep;
	int i;
	for(i=1;i<=d[x];i++)
		if(c[x][i]!=fa)
			dfs(c[x][i],x,dep+1);
}
int getlca(int x,int y)
{
	if(x==y)
		return x;
	if(lca[x][y])
		return lca[x][y];
	if(dd[x]>dd[y])
		return lca[x][y]=getlca(ff[x],y);
	return lca[x][y]=getlca(x,ff[y]);
}
void dp(int x,int fa)
{
	int i;
	for(i=1;i<=d[x];i++)
		if(c[x][i]!=fa)
			dp(c[x][i],x);
	for(i=1;i<=d[x];i++)
		if(c[x][i]!=fa)
			dfs2(c[x][i],x,i,0);
	memset(ns,0,sizeof ns);
	int cx,cy,cs;
	for(i=l.h[x];i;i=l.t[i])
	{
		if(l.v[i].first==x)
		{
			cx=0;
			cy=e[l.v[i].second];
			cs=g[l.v[i].second];
		}
		else if(l.v[i].second==x)
		{
			cx=e[l.v[i].first];
			cy=0;
			cs=g[l.v[i].first];
		}
		else
		{
			cx=e[l.v[i].first];
			cy=e[l.v[i].second];
			cs=g[l.v[i].first]+g[l.v[i].second];
		}
		cs++;
		if(cx>cy)
			swap(cx,cy);
		ns[cx][cy]=max(ns[cx][cy],cs);
	}
	for(i=1;i<=d[x];i++)
		if(c[x][i]!=fa)
		{
			cx=0;
			cy=i;
			cs=f[c[x][i]][(1<<d[c[x][i]])-1];
			ns[cx][cy]=max(ns[cx][cy],cs);
		}
	int j,k;
	for(i=0;i<=d[x];i++)
		for(j=0;j<=d[x];j++)
			if(ns[i][j])
			{
				int s=0;
				if(i)
					s|=1<<(i-1);
				if(j)
					s|=1<<(j-1);
				for(k=0;k<(1<<d[x]);k++)
					if(!(k&s))
						f[x][k|s]=max(f[x][k|s],f[x][k]+ns[i][j]);
			}
}
void solve()
{
	memset(d,0,sizeof d);
	int n;
	scanf("%d",&n);
	int i,j;
	for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
			lca[i][j]=0;
	for(i=1;i<=n;i++)
		for(j=0;j<(1<<10);j++)
			f[i][j]=0;
	l.clear();
	int x,y;
	for(i=1;i<=n-1;i++)
	{
		scanf("%d%d",&x,&y);
		c[x][++d[x]]=y;
		c[y][++d[y]]=x;
	}
	dfs(1,0,1);
	int m;
	scanf("%d",&m);
	for(i=1;i<=m;i++)
	{
		scanf("%d%d",&x,&y);
		l.add(getlca(x,y),pii(x,y));
	}
	dp(1,0);
	int ans=0;
	for(i=1;i<=n;i++)
		ans=max(ans,f[i][(1<<d[i])-1]);
	printf("%d\n",ans);
}
int main()
{
#ifdef DEBUG
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
#endif
	int t;
	scanf("%d",&t);
	while(t--)
		solve();
	return 0;
}
posted @ 2018-03-05 21:09  ywwyww  阅读(312)  评论(0编辑  收藏  举报