洛谷 P4516 [JSOI2018]潜入行动
一眼树形 \(dp\)
本题有 \(2\) 大难点。
难点之一是状态的设计,这里需要四维状态,\(dp[i][j][0/1][0/1]\) 表示在以 \(i\) 为根的子树内放了 \(j\) 个监听器,\(i\) 号点是否放了监听器,\(i\) 号点是否被它的儿子覆盖,在这种情况下的方案数。
设计好了状态,转移也就水到渠成了。
\(dp[u][j][0][0]\) 只能从 \(dp[v][j][0][1]\) 转移:\(i\) 号节点没放监听设备也没被覆盖,说明它的儿子都没放监听设备,并且它的儿子只能被它的儿子的儿子所覆盖。
\(dp[u][j][0][1]\) 可以从 \(dp[v][j][0][1]\) 和 \(dp[v][j][1][1]\) 转移过来。但还需减掉 \(dp[u][j][0][0]\) 的情况:\(i\) 号节点没放监听设备但被覆盖,说明它所有儿子都没放监听器,至于它的儿子有没有被覆盖,怎么样都行。
\(dp[u][j][1][0]\) 可以从 \(dp[v][j][0][0]\) 和 \(dp[v][j][0][1]\) 转移过来:\(i\) 号节点放了监听设备但没被覆盖,说明它至少一个儿子放了监听器,并且它的儿子只能被它的儿子的儿子所覆盖。
\(dp[u][j][1][1]\) 可以从 \(dp[v][j][0/1][0/1]\) 转移过来。但还需减掉 \(dp[u][j][1][0]\) 的情况。
至于第二维,合并两个子树的时候跑个树上背包就可以了。
难点之二是复杂度的计算。
说实话这题一开始我想到正解了可不知道它能过。
暴力合并其实是 \(\mathcal O(nk)\) 而不是 \(\mathcal O(nk^2)\) 的,下面给出简单的证明(开始抄题解ing):
- 若合并两个大小 \(>k\) 的子树,由于这样的子树最多 \(\frac{n}{k}\) 个,暴力合并复杂度是 \(nk\) 的。
- 若合并一棵大小 \(\leq k\) 的和一棵大小 \(>k\) 的子树,这样那个大小 \(\leq k\) 的子树就变成了大小 \(>k\) 的子树。由于每个点最多只在它的某个祖先处被合并一次,这样复杂度均摊也是 \(nk\) 的。
- 若合并两棵大小 \(\leq k\) 的子树,那相当于对两棵子树中每个点都合并了一次。而合并之后得到的子树的大小 \(\leq 2k\),故每个点最多与 \(2k\) 个这样的点进行了合并,故复杂度还是 \(nk\) 的。
证明比较玄乎,大概看看即可。
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define mp make_pair
typedef pair<int,int> pii;
typedef long long ll;
const ll MOD=1e9+7;
int n,k,siz[100005];
vector<int> g[100005];
int dp[100005][105][2][2];
int tmp[105][2][2];
inline void dfs(int x,int f){
siz[x]=1;
for(int i=0;i<g[x].size();i++){
int y=g[x][i];if(y==f) continue;
dfs(y,x);
}
dp[x][0][0][0]=dp[x][1][1][0]=dp[x][0][0][1]=dp[x][1][1][1]=1;
for(int i=0;i<g[x].size();i++){
int y=g[x][i];if(y==f) continue;
memset(tmp,0,sizeof(tmp));
for(int j=0;j<=min(siz[y],k);j++) for(int l=0;l<=min(siz[x],k-j);l++){
tmp[j+l][0][0]=(tmp[j+l][0][0]+1ll*dp[x][l][0][0]*dp[y][j][0][1]%MOD)%MOD;
tmp[j+l][0][1]=(tmp[j+l][0][1]+1ll*dp[x][l][0][1]*(dp[y][j][0][1]+dp[y][j][1][1])%MOD)%MOD;
tmp[j+l][1][0]=(tmp[j+l][1][0]+1ll*dp[x][l][1][0]*(dp[y][j][0][0]+dp[y][j][0][1])%MOD)%MOD;
tmp[j+l][1][1]=(tmp[j+l][1][1]+1ll*dp[x][l][1][1]*(((dp[y][j][0][0]+dp[y][j][0][1])%MOD+dp[y][j][1][0])%MOD+dp[y][j][1][1])%MOD)%MOD;
}
for(int j=0;j<=min(siz[x]+siz[y],k);j++){
dp[x][j][0][0]=tmp[j][0][0];dp[x][j][0][1]=tmp[j][0][1];
dp[x][j][1][0]=tmp[j][1][0];dp[x][j][1][1]=tmp[j][1][1];
}
siz[x]+=siz[y];
}
for(int j=0;j<=k;j++){
dp[x][j][0][1]=(dp[x][j][0][1]-dp[x][j][0][0]+MOD)%MOD;
dp[x][j][1][1]=(dp[x][j][1][1]-dp[x][j][1][0]+MOD)%MOD;
}
// for(int j=0;j<=k;j++) for(int p=0;p<2;p++) for(int q=0;q<2;q++){
// printf("%d %d %d %d %d\n",x,j,p,q,dp[x][j][p][q]);
// }
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++){
int u,v;scanf("%d%d",&u,&v);
g[u].pb(v);g[v].pb(u);
}
dfs(1,0);int ans=0;
for(int x=0;x<2;x++) ans=(ans+dp[1][k][x][1])%MOD;
printf("%d\n",ans);
return 0;
}
/*
5 3
1 2
1 3
2 4
2 5
6 3
1 2
1 3
2 4
2 5
3 6
*/