【XSY2666】排列问题 DP 容斥原理 分治FFT

题目大意

  有\(n\)种颜色的球,第\(i\)种有\(a_i\)个。设\(m=\sum a_i\)。你要把这\(m\)个小球排成一排。有\(q\)个询问,每次给你一个\(x\),问你有多少种方案使得相邻的小球同色的对数为\(x\)

  \(n\leq 10000,m\leq 200000\)

题解

  我们考虑把这些小球分段,每段内所有小球颜色相同,但相邻两段的小球颜色可以相同。

  设第\(i\)种颜色有\(b_i\)段,那么分\(j\)段的方案数是\(\frac{(\sum b_i)!}{\sum(bi!)}=\frac{j!}{\sum(bi!)}\)

  那么先DP,设\(f_{i,j}\)为前\(i\)种颜色,分了\(j\)段的方案数\(\div b_i!\)显然枚举第\(i\)中颜色分\(k\)段得

\[f_{i,j}+=f_{i-1,j-k}\times \binom{a_i-1}{k-1}\times\frac{1}{k!} \]

  那个组合数是插板法得到的。

  这个DP的时间复杂度是\(O(m^2)\)(因为枚举第\(i\)种颜色时\(k=1\ldots a_i,j=1\ldots s_i\)\(s\)\(a\)的前缀和))

  然后这个东西可以分治FFT优化到\(O(m\log m\log n)\)

  这样我们得到了分成\(i\)段的方案数\(g_i=f_{n,i}\times i!\),但相邻两段可能颜色相同。我们还要减掉这种情况。

  就是对于一种实际上分成 \(j\) 段的方案,它在分成 \(i\) 段的方案数中会被计算 \(\binom{m-j}{m-i}\) 次(就是在 \(m-j\) 个间隔中取 \(m-i\) 个)。

  答案 \(ans_i=g_i-\sum_{j<i}ans_j\binom{m-j}{i-j}\)

  可以简单暴力的通过分治FFT优化到\(O(m\log^2 m)\)。但有更好的做法。

  考虑容斥。其实总的\(g_j\)\(ans_i\)的贡献就是\({(-1)}^{i-j}\binom{m-j}{i-j}\)。直接FFT一次就可以得到答案。

\[\begin{align} ans_{k->i}&=\sum_{j=k}^{i-1}{(-1)^{j-k}}\binom{m-k}{j-k}\binom{m-j}{i-j}\\ &=\sum_{j=k}^{i-1}{(-1)^{j-k}}\frac{(m-k)!(m-j)!}{(j-k)!(m-j)!(i-j)!(m-i)!}\\ &=\sum_{j=k}^{i-1}{(-1)^{j-k}}\frac{(m-k)!}{(j-k)!(i-j)!(m-i)!}\\ &=\frac{(m-k)!}{(m-i)!(i-k)!}\sum_{j=k}^{i-1}{(-1)^{j-k}}\frac{(i-k)!}{(i-j)!(j-k)!}\\ &=\binom{m-k}{i-k}\sum_{j=k}^{i-1}{(-1)^{j-k}}\binom{i-k}{j-k}\\ &=\binom{m-k}{i-k}{(-1)}^{i-k} \end{align} \]

  那么相邻的小球同色的对数为\(x\)的答案就是\(ans_{m-x}\)

  时间复杂度:\(O(m\log m\log n+q)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
#include<vector>
#include<queue>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
void put(int x)
{
	if(!x)
	{
		putchar('0');
		return;
	}
	static int c[20];
	int t=0;
	while(x)
	{
		c[++t]=x%10;
		x/=10;
	}
	while(t)
		putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const int p=998244353;
int fp(int a,int b)
{
	int s=1;
	for(;b;b>>=1,a=1ll*a*a%p)
		if(b&1)
			s=1ll*s*a%p;
	return s;
}
int inv[600010];
int fac[600010];
int ifac[600010];
namespace ntt
{
	const int g=3;
	int rev[600010];
	int w1[600010];
	int w2[600010];
	int n;
	void init(int m)
	{
		n=1;
		while(n<=m)
			n<<=1;
		int i;
		rev[0]=0;
		for(i=1;i<n;i++)
			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		for(i=1;i<=n;i<<=1)
		{
			w1[i]=fp(g,(p-1)/i);
			w2[i]=fp(w1[i],p-2);
		}
	}
	void ntt(int *a,int t)
	{
		int i,j,k;
		int u,v,w,wn;
		for(i=0;i<n;i++)
			if(rev[i]<i)
				swap(a[i],a[rev[i]]);
		for(i=2;i<=n;i<<=1)
		{
			wn=(t==1?w1[i]:w2[i]);
			for(j=0;j<n;j+=i)
			{
				w=1;
				for(k=j;k<j+i/2;k++)
				{
					u=a[k];
					v=1ll*a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
					w=1ll*w*wn%p;
				}
			}
		}
		if(t==-1)
		{
			int inv=fp(n,p-2);
			for(i=0;i<n;i++)
				a[i]=1ll*a[i]*inv%p;
		}
	}
};
int g[600010];
int h[600010];
int ans[600010];
int a[600010];
int s[600010];
int n,m;
void add(int &a,int b)
{
	a=(a+b)%p;
}
typedef vector<int> vec;
vec mul(vec &a,vec &b)
{
	static int c[600010],d[600010];
	int n1=a.size()-1;
	int n2=b.size()-1;
	int m=n1+n2+1;
	ntt::init(m);
	int i;
	for(i=0;i<=n1;i++)
		c[i]=a[i];
	for(i=n1+1;i<ntt::n;i++)
		c[i]=0;
	for(i=0;i<=n2;i++)
		d[i]=b[i];
	for(i=n2+1;i<ntt::n;i++)
		d[i]=0;
	ntt::ntt(c,1);
	ntt::ntt(d,1);
	for(i=0;i<ntt::n;i++)
		c[i]=1ll*c[i]*d[i]%p;
	ntt::ntt(c,-1);
	vec s(n1+n2+1);
	for(i=1;i<=n1+n2;i++)
		s[i]=c[i];
	return s;
}
vec solve(int l,int r)
{
	if(l==r)
	{
		vec s(a[l]+1);
		int i;
		for(i=1;i<=a[l];i++)
			s[i]=1ll*ifac[i-1]*ifac[i]%p*ifac[a[l]-i]%p;
		return s;
	}
	int mid=(l+r)>>1;
	vec s1=solve(l,mid);
	vec s2=solve(mid+1,r);
	return mul(s1,s2);
}
int c[600010];
int d[600010];
priority_queue<pii,vector<pii>,greater<pii> > q;
void gao()
{
	int i;
	c[0]=0;
	for(i=1;i<=m;i++)
		c[i]=g[i];
	for(i=0;i<=m;i++)
	{
		d[i]=ifac[i];
		if(i&1)
			d[i]=-d[i];
	}
	ntt::init(2*m);
	for(i=m+1;i<ntt::n;i++)
		c[i]=d[i]=0;
	ntt::ntt(c,1);
	ntt::ntt(d,1);
	for(i=0;i<ntt::n;i++)
		c[i]=1ll*c[i]*d[i]%p;
	ntt::ntt(c,-1);
	for(i=1;i<=m;i++)
		g[i]=c[i];
}
int t=0;
vec f[20010];
int main()
{
	open("c");
	scanf("%d",&n);
	int i;
	for(i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
		s[i]=s[i-1]+a[i];
	}
	m=s[n];
	inv[0]=inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
	for(i=2;i<=m;i++)
	{
		inv[i]=-1ll*p/i*inv[p%i]%p;
#ifndef ONLINE_JUDGE
		inv[i]=(inv[i]+p)%p;
#endif
		fac[i]=1ll*fac[i-1]*i%p;
		ifac[i]=1ll*ifac[i-1]*inv[i]%p;
	}
//	f[0][0]=1;
	int times=1;
	for(i=1;i<=n;i++)
		times=1ll*times*fac[a[i]-1]%p;
//	for(i=1;i<=n;i++)
//	{
//		times=times*fac[a[i]-1]%p;
//		for(j=1;j<=s[i];j++)
//		{
//			for(k=1;k<=a[i]&&k<=j;k++)
//				add(f[i][j],f[i-1][j-k]*ifac[k-1]%p*ifac[a[i]-k]%p*ifac[k]%p);
////				add(f[i][j],f[i-1][j-k]*c(a[i]-1,k-1)%p*ifac[k]%p);
////			f[i][j]=f[i][j]*fac[a[i]-1]%p;
//		}
//	}
	int j;
	for(i=1;i<=n;i++)
	{
		f[i].resize(a[i]+1);
		for(j=1;j<=a[i];j++)
			f[i][j]=1ll*ifac[j-1]*ifac[j]%p*ifac[a[i]-j]%p;
		q.push(pii(a[i],i));
	}
	t=n;
	for(i=1;i<n;i++)
	{
		int n1=q.top().first;
		int x=q.top().second;
		q.pop();
		int n2=q.top().first;
		int y=q.top().second;
		q.pop();
		f[++t]=mul(f[x],f[y]);
		f[x].clear();
		f[y].clear();
		q.push(pii(n1+n2+1,t));
	}
	vec ss=f[t];
//	vec ss=solve(1,n);
	for(i=1;i<=m;i++)
		g[i]=1ll*ss[i]*fac[i]%p*times%p;
#ifndef ONLINE_JUDGE
	for(i=1;i<=m;i++)
		add(g[i],p);
#endif
//		g[i]=f[n][i]*fac[i]%p*times%p;	
	for(i=1;i<=m;i++)
		g[i]=1ll*g[i]*fac[m-i]%p;
	gao();
	for(i=1;i<=m;i++)
	{
		g[i]=1ll*g[i]*ifac[m-i]%p;
		add(g[i],p);
	}
//	for(i=1;i<=m;i++)
//	{
//		for(j=1;j<i;j++)
//			add(ans[i],h[j]%p*ifac[i-j]%p);
//		ans[i]=-ans[i]*ifac[m-i]%p;
//		ans[i]=(ans[i]+g[i])%p;
//			add(ans[i],-ans[j]*c(m-j,i-j));
//		add(ans[i],p);
//		h[i]=ans[i]*fac[m-i]%p;
//	}
	int q;
	int x;
	scanf("%d",&q);
	while(q--)
	{
		scanf("%d",&x);
		printf("%lld\n",g[m-x]);
	}
	return 0;
}
posted @ 2018-03-06 11:22  ywwyww  阅读(650)  评论(0编辑  收藏  举报