CF1799H Tree Cutting 解题报告
\(\texttt{Description}\)
给你一个有 \(n\) 个点的树和一个长度为 \(k\) 的序列 \(s_i\),你要执行以下操作 \(k\) 次:
-
选出目前存在的一条边。
-
删除这条边。
-
选出分裂出的两个联通块的其中一块。
-
删除这个联通块,并写下剩余的联通块的大小。
问有多少种不同的操作方案使得写下的序列等于给出的序列,答案对 \(998244353\) 取模。
\(n\le 5000,k\le 6\)。
\(\texttt{Solution}\)
没看懂官方题解在说什么,于是自己搞了一个做法,复杂度比标算劣一些。
考虑操作 \(u\) 到其父亲的边,一种可能是保留以 \(u\) 为根的子树,一种可能是删除以 \(u\) 为根的子树。如果是后者,将符合题意的条件从保留的联通块大小为 \(s_i\),改成删去大小为 \(s_{i-1}-s_i\) 的联通块,容易发现这是等价的。
这样转化的好处是,每次操作都是对一个子树操作,方便树形 dp。
状压,记 \(f(u,i,s)\) 表示以 \(u\) 为根的子树,子树内最早的保留子树操作是 \(i\)(如果不存在,则认为 \(i=k+1\)),操作的是以 \(u\) 为根的子树的子树的集合为 \(s\)。
合并两个子树的时候,要求至多有一个子树满足 \(i\le k\),且两个子树的集合 \(s1,s2\) 满足 \(s1\cap s2=\emptyset\)。
在合并完所有儿子后,再考虑操作 \(u\) 的子树,枚举对应的操作 \(j\)。
如果这是一个保留子树的操作,则要求 \(i>j\);如果是一个删除子树的操作,则要求 \(j>\forall l\in s\)。同时,可以简单计算出操作 \(j\) 之前 \(u\) 的子树大小,从而判断出是否满足条件。
时间复杂度 \(O(n(k3^k+k^22^k))\),可以通过高维前缀和优化到 \(O(nk^22^k)\),但是没有必要。
\(\texttt{Code}\)
需要注意的是,代码中的 \(s_i\) 是从 \(0\) 开始编号的,所以上述的值部分要减一。
#include<bits/stdc++.h>
#define IL inline
#define reg register
#define N 5050
#define mod 998244353
IL int read()
{
reg int x=0; reg char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
return x;
}
IL int Add(reg int x,reg int y){return x+y<mod?x+y:x+y-mod;}
IL void Pls(reg int &x,reg int y){x=Add(x,y);}
IL int Mul(reg int x,reg int y){reg long long r=1ll*x*y; return r<mod?r:r%mod;}
int n,k,c[7],d[7];
std::vector<int>G[N];
IL void add(reg int u,reg int v){G[u].push_back(v),G[v].push_back(u);}
int full,f[N][7][1<<6|5],high[1<<6|5],sum[1<<6|5],sz[N],ans;
void dfs(reg int u,reg int fa=0)
{
sz[u]=1;
for(reg auto v:G[u])if(v!=fa)dfs(v,u),sz[u]+=sz[v];
static int g[7][1<<6|5]; f[u][k][0]=1;
for(reg auto v:G[u])if(v!=fa)
{
for(reg int j=0,s;j<=k;++j)for(s=0;s<=full;++s)g[j][s]=0;
for(reg int s1=0,s2,s;s1<=full;++s1)for(s=s1;s<=full;s=(s+1)|s1)
s2=s^s1,Pls(g[k][s],Mul(f[u][k][s1],f[v][k][s2]));
for(reg int j=k,s1,s2,s;j--;)for(s1=0;s1<=full;++s1)if(f[u][j][s1]||f[v][j][s1])for(s=s1;s<=full;s=(s+1)|s1)
{
if(high[s2=s^s1]>j)continue;
Pls(g[j][s],Mul(f[u][j][s1],f[v][k][s2])),Pls(g[j][s],Mul(f[v][j][s1],f[u][k][s2]));
}
for(reg int j=0,s;j<=k;++j)for(s=0;s<=full;++s)f[u][j][s]=g[j][s];
}
if(u==1)return;
for(reg int j=0,s;j<=k;++j)for(s=0;s<=full;++s)g[j][s]=0;
for(reg int j=0,l,s,S,x;j<=k;++j)for(s=0;s<=full;++s)if(f[u][j][s])for(l=0;l<j;++l)if(~s>>l&1)
{
S=s&(1<<l)-1,x=sz[u]-sum[S];
if(x==c[l])Pls(g[l][s^(1<<l)],f[u][j][s]);
if(x==d[l]&&l>high[s])Pls(g[j][s^(1<<l)],f[u][j][s]);
}
for(reg int j=0,s;j<=k;++j)for(s=0;s<=full;++s)Pls(f[u][j][s],g[j][s]);
}
main()
{
n=read();
for(reg int i=n;--i;add(read(),read()));
k=read(),full=(1<<k)-1;
for(reg int i=0;i<k;++i)c[i]=read(),d[i]=i?c[i-1]-c[i]:n-c[i];
high[0]=-1;
for(reg int s=0,i;s<=full;++s)for(i=0;i<k;++i)if(s>>i&1)high[s]=i;
for(reg int s=1;s<=full;++s)sum[s]=sum[s^(1<<high[s])]+d[high[s]];
dfs(1);
for(reg int i=0;i<=k;++i)Pls(ans,f[1][i][full]);
printf("%d\n",ans);
}

浙公网安备 33010602011771号