树形dp-子树合并
刚知道我一直写的是假的树形dp
树性dp一类是树上背包,一道经典模板选课
我们大多写的是\(nm^2\)的算法,即先枚举点,再枚举背包容量,接着枚举子树选的容量大小
蓝书上的写法也是这样,对于此题由于数据小可以过,但是可以优化
考虑限制枚举范围,每次只有小于当前子树\(size\)的枚举才合法,于是可以进行剪枝
void dfs(int x)
{
size[x]++;f[x][1]=w[x];
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].to;
dfs(y);
for(int j=min(m+1,size[x]);j;j--)
for(int k=1;k<=min(size[y],m+1-j);k++)
f[x][j+k]=max(f[x][j+k],f[x][j]+f[y][k]);
size[x]+=size[y];
}
}
这里用动态更新的上界来优化,表面上看没什么,事实上每对点仅会在lca处贡献答案,复杂度为\(nm\)
\(skyh\)学长似乎有更为本质化的证明


似乎有点像势能分析?
这个在别的题里面就很有用了,是一种重要的树形dp思路,也是树上背包的正确打开方式
模拟46 T2 数树
首先这个有容斥的思想
题目让求合法的,我们可以考虑钦定有多少条边不合法然后容斥
钦定\(i\)条边不合法如果是\(s_i\)的话,那么答案就是\(\sum_{i=0}^{n-1}s_i\times(n-i)!\times(-1)^i\)
容斥系数要乘阶乘的原因很简单,钦定\(i\)条边剩下就随便选就行了,关键的这个\(s\)怎么求
设\(f_{x,i}\)代表在\(x\)子树中,选择\(i\)条不合法边的方案数,用\(0/1/2/3\)分别表示当前\(x\)点上下都没连,连上不连下,连下不连上,上下都连
转移就用上面的套路,先枚举子节点,然后分别枚举\(x\)和\(y\)的\(size\),进行转移
因为各个条件之间不能冲突,所以他选的一定会成链,按照这个转移一波就行,具体看代码
注意了,因为我们写的是子树合并dp,所以每次都是利用原有的dp进行一系列转移
我们在枚举每个子节点的时候,都要用已有状态更新一些其他状态,但由于\(j\),\(k\)都可能枚举到0,所以可能出现自己更新自己的情况,而这势必造成后续转移重复,因此我们需要消除这个影响
一句话:要开辅助数组!
最后用辅助数组更新原dp数组,由于每加入一个儿子新树的形态实际已经改变,因此原来的dp数组已经失去意义,所以直接覆盖而不是累加
感谢付队帮助理解,记得变量名别重!
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int mod=998244353;
const int N=5050;
struct node{
int from,to,next,op;
}a[2*N];
int head[N],mm=1;
inline void add(int x,int y,int op)
{
a[mm].from=x;a[mm].to=y;a[mm].op=op;
a[mm].next=head[x];head[x]=mm++;
}
int size[N],f[N][N][4],g[N][4];bool v[N];
void dfs(int x)
{
v[x]=1;f[x][0][0]=1;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].to;if(v[y])continue;
dfs(y);memset(g,0,sizeof(g));
for(int j=0;j<=size[x];j++)
for(int k=0;k<=size[y];k++)
{
int sum=(f[y][k][0]+f[y][k][1]+f[y][k][2]+f[y][k][3])%mod;
for(int p=0;p<4;p++)g[j+k][p]=(g[j+k][p]+sum*f[x][j][p]%mod)%mod;
if(a[i].op)
{
g[j+k+1][2]=(g[j+k+1][2]+(f[y][k][0]+f[y][k][2])%mod*f[x][j][0]%mod)%mod;
g[j+k+1][3]=(g[j+k+1][3]+(f[y][k][0]+f[y][k][2])%mod*f[x][j][1]%mod)%mod;
}
else
{
g[j+k+1][1]=(g[j+k+1][1]+(f[y][k][0]+f[y][k][1])%mod*f[x][j][0]%mod)%mod;
g[j+k+1][3]=(g[j+k+1][3]+(f[y][k][0]+f[y][k][1])%mod*f[x][j][2]%mod)%mod;
}
}
size[x]+=size[y];
for(int j=0;j<=size[x];j++)
for(int k=0;k<4;k++)
f[x][j][k]=g[j][k];
}
size[x]++;
}
int jc[N];
signed main()
{
int n;cin>>n;jc[0]=1;
for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mod;
for(int i=1;i<n;i++)
{
int x,y;scanf("%lld%lld",&x,&y);
add(x,y,1);add(y,x,0);
}
dfs(1);int ans=0;
for(int i=0;i<n;i++)
{
int sum=(f[1][i][0]+f[1][i][1]+f[1][i][2]+f[1][i][3])%mod;
if((i&1))ans=(ans-sum*jc[n-i]%mod+mod)%mod;
else ans=(ans+sum*jc[n-i]%mod)%mod;
}
cout<<ans<<endl;
return 0;
}

浙公网安备 33010602011771号