[TJOI2019]唱、跳、rap和篮球——NTT+生成函数+容斥

题目链接:

[TJOI2019]唱、跳、rap和篮球

 

直接求不好求,我们考虑容斥,求出至少有$i$个聚集区间的方案数$ans_{i}$,那么最终答案就是$\sum\limits_{i=0}^{n}(-1)^i\ ans_{i}$

那么现在只需要考虑至少有$i$个聚集区间的方案数,我们枚举这$i$个区间的起始点位置,一共有$C_{n-3i}^{i}$种方案(可以看作是刚开始先将每个区间后三个位置去掉,从剩下$n-3i$个位置中选出$i$个区间起点,然后再在每个起点后面加上三个位置)。

那么剩下的$n-4i$个位置就是随便放这四种学生,假设第$j$种学生放了$a_{j}$个、一共有$num_{j}$个,那么方案数就是$\frac{(n-4i)!}{\prod_{j=1}^{4}a_{j}!}$。

由此可以构造出这四种学生的生成函数,以第一种学生为例:$\sum\limits_{j=0}^{num_{1}-i}\frac{x^j}{j!}$

将四个生成函数分别用$NTT$乘在一起然后取$x^{n-4i}$前的系数乘上$(n-4i)!$即可得到$n-4i$个位置随便放的方案数。

#include<set>
#include<map>
#include<cmath>
#include<stack>
#include<queue>
#include<bitset>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int mod=998244353;
int f[3000];
int g[3000];
int inv[2000];
int fac[2000];
int mask;
int n,a,b,c,d;
int ans;
int mn,mx;
int quick(int x,int y)
{
	int res=1;
	while(y)
	{
		if(y&1)
		{
			res=1ll*res*x%mod;
		}
		x=1ll*x*x%mod;
		y>>=1;
	}
	return res;
}
void NTT(int *a,int len,int opt)
{
	for(int i=0,k=0;i<len;i++)
	{
		if(i>k)
		{
			swap(a[i],a[k]);
		}
		for(int j=len>>1;(k^=j)<j;j>>=1);
	}
	for(int i=2;i<=len;i<<=1)
	{
		int t=i>>1;
		int x=quick(3,(mod-1)/i);
		if(opt==-1)
		{
			x=quick(x,mod-2);
		}
		for(int j=0;j<len;j+=i)
		{
			int w=1;
			for(int k=j;k<j+t;k++)
			{
				int tmp=1ll*a[k+t]*w%mod;
				a[k+t]=(a[k]-tmp+mod)%mod;
				a[k]=(a[k]+tmp)%mod;
				w=1ll*w*x%mod;
			}
		}
	}
	if(opt==-1)
	{
		int x=quick(len,mod-2);
		for(int i=0;i<len;i++)
		{
			a[i]=1ll*a[i]*x%mod;
		}
	}
}
int C(int n,int m)
{
	return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int solve(int x)
{
	memset(f,0,sizeof(f));
	memset(g,0,sizeof(g));
	for(int i=0;i<=a-x;i++)
	{
		f[i]=inv[i];
	}
	for(int i=0;i<=b-x;i++)
	{
		g[i]=inv[i];
	}
	NTT(f,mask,1);
	NTT(g,mask,1);
	for(int i=0;i<mask;i++)
	{
		f[i]=1ll*f[i]*g[i]%mod;
	}
	memset(g,0,sizeof(g));
	for(int i=0;i<=c-x;i++)
	{
		g[i]=inv[i];
	}
	NTT(g,mask,1);
	for(int i=0;i<mask;i++)
	{
		f[i]=1ll*f[i]*g[i]%mod;
	}
	memset(g,0,sizeof(g));
	for(int i=0;i<=d-x;i++)
	{
		g[i]=inv[i];
	}
	NTT(g,mask,1);
	for(int i=0;i<mask;i++)
	{
		f[i]=1ll*f[i]*g[i]%mod;
	}
	NTT(f,mask,-1);
	return 1ll*f[n-4*x]*fac[n-4*x]%mod*C(n-3*x,x)%mod;
}
int main()
{
	inv[1]=inv[0]=fac[0]=1;
	for(int i=1;i<=1000;i++)
	{
		fac[i]=1ll*fac[i-1]*i%mod;
	}
	for(int i=2;i<=1000;i++)
	{
		inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
	}
	for(int i=1;i<=1000;i++)
	{
		inv[i]=1ll*inv[i-1]*inv[i]%mod;
	}
	mask=1;
	scanf("%d%d%d%d%d",&n,&a,&b,&c,&d);
	mn=min(n/4,min(min(a,b),min(c,d)));
	mx=max(max(a,b),max(c,d));
	while(mask<=(mx<<2))
	{
		mask<<=1;
	}
	for(int i=0;i<=mn;i++)
	{
		if(i&1)
		{
			ans=(ans-solve(i)+mod)%mod;
		}
		else
		{
			ans=(ans+solve(i))%mod;
		}
	}
	printf("%d",ans);
}
posted @ 2019-05-03 21:56  The_Virtuoso  阅读(560)  评论(0编辑  收藏  举报