题解:P14254 分割(divide)
题目:
交集的限制且 \(b_1\) 深度最小,说明要选 \(k\) 个点深度相同的点(可以从固定 \(b_1\) 选其他 \(k-1\) 个点的角度理解,略)。
手玩样例发现选点的时候我们被子树内最深深度限制,称 \(x\) 子树内最深深度为 \(h_x\)。
把每层的点拎出来:
\(b_1\) 和 \(1\) 为根的点很特殊,所以考虑枚举 \(h_1\),然后发现 \(\{b_2…b_{k+1} \}\) 需要保证 \(h_i\ge h_1\) 且 \(\exists h_i=h_1\)。
\(1\) 为根的树更为厉害,它可以把所有别人没选的 \(h\) 都并起来,所以我们只需要给他预留一个 \(h_i\ge h_1\) 的位置就好了。(也就是满足 \(len+cnt-1\ge k\),变量定义见下方)
假设 \(h_1\) 的数量为 \(cnt\),\(h_i>h_1\) 的数量为 \(len\)。
有两种情况:
-
\(\exists h_i=h_1\)
答案贡献为(我们容斥掉 \(\nexists h_i=h_1\)):
\(cnt C_{cnt+len-1}^{k-1}(k-1)!-cnt C_{len}^{k-1}(k-1)!\) -
\(\nexists h_i=h_1,h_{k+1}=h_1\)
这个条件在 \(len=k-1\) 时成立。
答案贡献为:
\(cnt\times (k-1)!\)
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read()
{
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
const int mo=998244353,QAQ=1e6+19,inf=1e6;
int ksm(int x,int k)
{
int da=1;
for(;k;k>>=1,x=x*x%mo) if(k&1) da=da*x%mo;
return da;
}
int jc[QAQ];
vector<int> dian[QAQ];
int n,k,fa[QAQ],shen[QAQ],mi[QAQ];
vector<int> cun[QAQ];
void dfs1(int x)
{
shen[x]=shen[fa[x]]+1;
int ez=0;
for(int v:dian[x])
{
if(v==fa[x]) continue;
dfs1(v);
mi[x]=max(mi[x],mi[v]);
ez++;
}
if(!ez) mi[x]=shen[x];
cun[shen[x]].push_back(mi[x]);
}
int cnm(int n,int m)
{
if(m>n) return 0;
return jc[n]*ksm(jc[n-m]*jc[m]%mo,mo-2)%mo;
}
int ans;
signed main()
{
jc[0]=1;
for(int i=1;i<=inf;i++) jc[i]=jc[i-1]*i%mo;
cin>>n>>k;
for(int i=2;i<=n;i++) fa[i]=read(),dian[fa[i]].push_back(i);
dfs1(1);
for(int dep=1;dep<=mi[1];dep++)
{
if(cun[dep].size()<k+1) continue;
sort(cun[dep].begin(),cun[dep].end());
int cnt=1,len=cun[dep].size(),zong=cun[dep].size();
len--;
for(int i=1;i<zong;i++)
{
if(cun[dep][i]!=cun[dep][i-1])
{
if(len+cnt-1>=k)
{
if(len==k-1) ans=(ans+cnt*jc[k-1])%mo;
ans=(ans+cnt*cnm(cnt+len-1,k-1)%mo*jc[k-1]%mo)%mo;
ans=(ans-(cnt*cnm(len,k-1)%mo*jc[k-1]%mo))%mo;
cnt=1;
}
}
else cnt++;
len--;
}
if(len+cnt-1>=k)
{
if(len==k-1) ans=(ans+cnt*jc[k-1])%mo;
if(cnt>=2)
ans=(ans+cnt*cnm(cnt+len-1,k-1)%mo*jc[k-1]%mo)%mo,
ans=(ans-(cnt*cnm(len,k-1)%mo*jc[k-1]%mo))%mo;
}
}
cout<<(ans+mo)%mo;
return 0;
}

浙公网安备 33010602011771号