树形dp-子树合并

刚知道我一直写的是假的树形dp
树性dp一类是树上背包,一道经典模板选课
我们大多写的是\(nm^2\)的算法,即先枚举点,再枚举背包容量,接着枚举子树选的容量大小
蓝书上的写法也是这样,对于此题由于数据小可以过,但是可以优化
考虑限制枚举范围,每次只有小于当前子树\(size\)的枚举才合法,于是可以进行剪枝

void dfs(int x)
{
	size[x]++;f[x][1]=w[x];
	for(int i=head[x];i;i=a[i].next)
	{
		int y=a[i].to;
		dfs(y);
		for(int j=min(m+1,size[x]);j;j--)
		  for(int k=1;k<=min(size[y],m+1-j);k++)
		   f[x][j+k]=max(f[x][j+k],f[x][j]+f[y][k]);
		size[x]+=size[y];
	}
}

这里用动态更新的上界来优化,表面上看没什么,事实上每对点仅会在lca处贡献答案,复杂度为\(nm\)
\(skyh\)学长似乎有更为本质化的证明
image
image
似乎有点像势能分析?
这个在别的题里面就很有用了,是一种重要的树形dp思路,也是树上背包的正确打开方式

模拟46 T2 数树

首先这个有容斥的思想
题目让求合法的,我们可以考虑钦定有多少条边不合法然后容斥
钦定\(i\)条边不合法如果是\(s_i\)的话,那么答案就是\(\sum_{i=0}^{n-1}s_i\times(n-i)!\times(-1)^i\)
容斥系数要乘阶乘的原因很简单,钦定\(i\)条边剩下就随便选就行了,关键的这个\(s\)怎么求
\(f_{x,i}\)代表在\(x\)子树中,选择\(i\)条不合法边的方案数,用\(0/1/2/3\)分别表示当前\(x\)点上下都没连,连上不连下,连下不连上,上下都连
转移就用上面的套路,先枚举子节点,然后分别枚举\(x\)\(y\)\(size\),进行转移
因为各个条件之间不能冲突,所以他选的一定会成链,按照这个转移一波就行,具体看代码
注意了,因为我们写的是子树合并dp,所以每次都是利用原有的dp进行一系列转移
我们在枚举每个子节点的时候,都要用已有状态更新一些其他状态,但由于\(j\)\(k\)都可能枚举到0,所以可能出现自己更新自己的情况,而这势必造成后续转移重复,因此我们需要消除这个影响
一句话:要开辅助数组!
最后用辅助数组更新原dp数组,由于每加入一个儿子新树的形态实际已经改变,因此原来的dp数组已经失去意义,所以直接覆盖而不是累加
感谢付队帮助理解,记得变量名别重!

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod=998244353;
const int N=5050;
struct node{
 	int from,to,next,op;
}a[2*N];
int head[N],mm=1;
inline void add(int x,int y,int op)
{
	a[mm].from=x;a[mm].to=y;a[mm].op=op;
	a[mm].next=head[x];head[x]=mm++;
}
int size[N],f[N][N][4],g[N][4];bool v[N];
void dfs(int x)
{
	v[x]=1;f[x][0][0]=1;
	for(int i=head[x];i;i=a[i].next)
	{
		int y=a[i].to;if(v[y])continue;
		dfs(y);memset(g,0,sizeof(g));
		for(int j=0;j<=size[x];j++)
		 for(int k=0;k<=size[y];k++)
		 {
		 	int sum=(f[y][k][0]+f[y][k][1]+f[y][k][2]+f[y][k][3])%mod;
		 	for(int p=0;p<4;p++)g[j+k][p]=(g[j+k][p]+sum*f[x][j][p]%mod)%mod;
		 	if(a[i].op)
		 	{
		 		g[j+k+1][2]=(g[j+k+1][2]+(f[y][k][0]+f[y][k][2])%mod*f[x][j][0]%mod)%mod;
		 		g[j+k+1][3]=(g[j+k+1][3]+(f[y][k][0]+f[y][k][2])%mod*f[x][j][1]%mod)%mod;
			}
			else 
			{
				g[j+k+1][1]=(g[j+k+1][1]+(f[y][k][0]+f[y][k][1])%mod*f[x][j][0]%mod)%mod;
				g[j+k+1][3]=(g[j+k+1][3]+(f[y][k][0]+f[y][k][1])%mod*f[x][j][2]%mod)%mod;
			}
		 }
	   size[x]+=size[y];
	   for(int j=0;j<=size[x];j++)
	    for(int k=0;k<4;k++)
	     f[x][j][k]=g[j][k];
	}
	size[x]++;
}
int jc[N];
signed main()
{	
	int n;cin>>n;jc[0]=1;
	for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mod;
	for(int i=1;i<n;i++)
	{
		int x,y;scanf("%lld%lld",&x,&y);
		add(x,y,1);add(y,x,0);
	}
	dfs(1);int ans=0;
	for(int i=0;i<n;i++)
	{
		int sum=(f[1][i][0]+f[1][i][1]+f[1][i][2]+f[1][i][3])%mod;
		if((i&1))ans=(ans-sum*jc[n-i]%mod+mod)%mod;
		else ans=(ans+sum*jc[n-i]%mod)%mod;
	}
	cout<<ans<<endl;
	return 0;
}
posted @ 2021-09-06 20:13  D'A'T  阅读(96)  评论(0)    收藏  举报