CodeForces - 1499F Diameter Cuts(树形dp)
题意:
给出一棵树,求有多少种删边的方案,使得删完后每个树的直径都不超过 k k k 。
题解:
设 d p [ u ] [ j ] dp[u][j] dp[u][j] 表示以 u u u 为根的子树,并且从 u u u 开始的最长链为 j j j 的方案数 。在 d f s dfs dfs的过程中,求出每个子树的最长链。
然后就是要当前的链长 j j j和子树的链长 z z z ,可以推出状态转移方程:
n u m [ m a x ( j , z + 1 ) ] + = d p [ u ] [ j ] ⋅ d p [ v ] [ z ] / / 子 树 v 与 u 连 接 , 转 移 的 时 候 要 判 断 j + z + 1 ≤ k num[max(j,z+1)]+=dp[u][j] \cdot dp[v][z] //子树v与u连接,转移的时候要判断j+z+1 \leq k num[max(j,z+1)]+=dp[u][j]⋅dp[v][z]//子树v与u连接,转移的时候要判断j+z+1≤k
n u m [ j ] + = d p [ u ] [ j ] ⋅ d p [ v ] [ z ] / / 子 树 v 不 与 u 连 接 num[j]+=dp[u][j] \cdot dp[v][z]//子树v不与u连接 num[j]+=dp[u][j]⋅dp[v][z]//子树v不与u连接
d p [ u ] [ j ] = n u m [ j ] dp[u][j] =num[j] dp[u][j]=num[j]
那么这里的 n u m num num数组是做啥用的呢?
因为是每次都是新加了一个子树对 d p dp dp进行更新,即要用当前 d p dp dp 数组重新更新自己,但是这样会导致可能用了更新过的 d p dp dp数组再更新自己,这样明显会出错,并且我们可能需要用当前的,但是却被更新过了,这样就不好维护,所以我们要重新开一个数组临时存一下。
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<stack>
#include<set>
#include<ctime>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int mod=998244353;
const int MAXN=5e3+5;
const int inf=0x3f3f3f3f;
struct node
{
int to;
int next;
/* data */
}e[MAXN<<1];
int head[MAXN];
ll dp[MAXN][MAXN];
ll num[MAXN];
int cnt;
int n,k;
void add(int u,int v)
{
e[cnt].to=v;
e[cnt].next=head[u];
head[u]=cnt++;
}
int dfs(int u,int f)
{
int max_link=0;
dp[u][0]=1;
for(int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].to;
if(v==f) continue;
int son=dfs(v,u);
memset(num,0,sizeof num);
for(int j=0;j<=min(k,max_link);j++)
{
for(int z=0;z<=min(k,son);z++)
{
if(j+z+1<=k) num[max(j,z+1)]=(num[max(j,z+1)]+dp[u][j]*dp[v][z]%mod)%mod;
num[j]=(num[j]+dp[u][j]*dp[v][z]%mod)%mod;
}
}
max_link=max(max_link,son+1);
for(int j=0;j<=min(k,max_link);j++) dp[u][j]=num[j];
}
return max_link;
}
int main()
{
//==========================================
#ifndef ONLINE_JUDGE
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
#endif
//==========================================
memset(head,-1,sizeof head);
cin>>n>>k;
for(int i=1;i<=n-1;i++)
{
int u,v;
cin>>u>>v;
add(u,v);
add(v,u);
}
dfs(1,-1);
ll ans=0;
for(int i=0;i<=k;i++)
{
ans=(ans+dp[1][i])%mod;
}
cout<<ans<<endl;
}

浙公网安备 33010602011771号