题解:P14254 分割(divide)

题目:


交集的限制且 \(b_1\) 深度最小,说明要选 \(k\) 个点深度相同的点(可以从固定 \(b_1\) 选其他 \(k-1\) 个点的角度理解,略)。

手玩样例发现选点的时候我们被子树内最深深度限制,称 \(x\) 子树内最深深度为 \(h_x\)


把每层的点拎出来
\(b_1\)\(1\) 为根的点很特殊,所以考虑枚举 \(h_1\),然后发现 \(\{b_2…b_{k+1} \}\) 需要保证 \(h_i\ge h_1\)\(\exists h_i=h_1\)
\(1\) 为根的树更为厉害,它可以把所有别人没选的 \(h\) 都并起来,所以我们只需要给他预留一个 \(h_i\ge h_1\) 的位置就好了。(也就是满足 \(len+cnt-1\ge k\),变量定义见下方)


假设 \(h_1\) 的数量为 \(cnt\)\(h_i>h_1\) 的数量为 \(len\)

有两种情况:

  • \(\exists h_i=h_1\)
    答案贡献为(我们容斥掉 \(\nexists h_i=h_1\)):
    \(cnt C_{cnt+len-1}^{k-1}(k-1)!-cnt C_{len}^{k-1}(k-1)!\)

  • \(\nexists h_i=h_1,h_{k+1}=h_1\)
    这个条件在 \(len=k-1\) 时成立。
    答案贡献为:
    \(cnt\times (k-1)!\)

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
int read()
{
	int x=0,f=1;
	char c=getchar();
	while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return x*f;
} 
const int mo=998244353,QAQ=1e6+19,inf=1e6;
int ksm(int x,int k)
{
	int da=1;
	for(;k;k>>=1,x=x*x%mo) if(k&1) da=da*x%mo;
	return da;
}
int jc[QAQ];
vector<int> dian[QAQ];
int n,k,fa[QAQ],shen[QAQ],mi[QAQ];
vector<int> cun[QAQ];
void dfs1(int x)
{
	shen[x]=shen[fa[x]]+1;
	int ez=0;
	for(int v:dian[x])
	{
		if(v==fa[x]) continue;
		dfs1(v);
		mi[x]=max(mi[x],mi[v]);
		ez++;
	}
	if(!ez) mi[x]=shen[x];
	cun[shen[x]].push_back(mi[x]);
}
int cnm(int n,int m)
{
	if(m>n) return 0;
	return jc[n]*ksm(jc[n-m]*jc[m]%mo,mo-2)%mo;
}
int ans;
signed main()
{
	jc[0]=1;
	for(int i=1;i<=inf;i++) jc[i]=jc[i-1]*i%mo;
	cin>>n>>k;
	for(int i=2;i<=n;i++) fa[i]=read(),dian[fa[i]].push_back(i);
	dfs1(1);
	for(int dep=1;dep<=mi[1];dep++)
	{
		if(cun[dep].size()<k+1) continue;
		sort(cun[dep].begin(),cun[dep].end());
		int cnt=1,len=cun[dep].size(),zong=cun[dep].size();
		len--;
		for(int i=1;i<zong;i++)
		{
			if(cun[dep][i]!=cun[dep][i-1])
			{
				if(len+cnt-1>=k)
				{
					if(len==k-1) ans=(ans+cnt*jc[k-1])%mo;
					ans=(ans+cnt*cnm(cnt+len-1,k-1)%mo*jc[k-1]%mo)%mo;
					ans=(ans-(cnt*cnm(len,k-1)%mo*jc[k-1]%mo))%mo;
					cnt=1;
				}
			}
			else cnt++;
			len--;
		}
		if(len+cnt-1>=k)
		{
			if(len==k-1) ans=(ans+cnt*jc[k-1])%mo;
			if(cnt>=2)
				ans=(ans+cnt*cnm(cnt+len-1,k-1)%mo*jc[k-1]%mo)%mo,
				ans=(ans-(cnt*cnm(len,k-1)%mo*jc[k-1]%mo))%mo;
		}
	}
	cout<<(ans+mo)%mo;
	return 0;
}
posted @ 2025-10-18 21:47  _a1a2a3a4a5  阅读(9)  评论(0)    收藏  举报