[HAOI2018]染色(容斥+NTT)

补充一篇详细得不能再详细的题解,比如让我自己看懂。

可能与前面的题解有些相同,我想补充一下自己的想法。

显然,最多 \(K\) 最大为 \(N=min(\lfloor \frac nS\rfloor,m)\)

首先,我们看到出现 \(S\) 次的颜色恰好 \(K\) 种的话,我们就可以考虑容斥,将其化为出现 \(S\) 次的颜色至少 \(K\) 种的方案数 \(f[K]\)

那么先选定在 \(m\) 中颜色中选定 \(i\) 种颜色,有 \(C_m^i\)

选定在 \(n\) 个位置中选定 \(iS\) 个位置,有 \(C_n^{iS}\)

但是 \(iS\) 个位置中随机排列的话,因为颜色相同交换算一种,所以有 \(\frac {iS!}{(S!)^i}\)

其他位置可以乱选,剩下 \(m-i\) 中颜色,\(n-iS\) 个位置,有 \((m-i)^{n-iS}\)

那么乘法原理,\(f[i]=C_m^iC_n^{iS}\frac {iS!}{(S!)^i}(m-i)^{n-iS}\)

现在定义 \(ans[i]\) 为出现 \(S\) 次颜色恰好 \(K\) 种的方案数,开始容斥。首先想到容斥系数 \((-1)^{j-i}\)

然后在 \(j\) 种颜色中 \(i\) 种颜色的方案数被算了 \(C_j^i\)

那么可以推出:

\(ans[i]=\sum_{j=i}^{N}(-1)^{j-i}C_j^if[j]\)

\(ans[i]=\sum_{j=i}^{N}(-1)^{j-i}\frac {j!}{i!(j-i)!}f[j]\)

\(ans[i]\times i!=\sum_{j=i}^{N}(-1)^{j-i}\frac {1}{(j-i)!}f[j]\times j!\)

定义 \(F(x)=\frac {(-1)^i}{i!}\ x^i,G(x)=f[i]\times i!\ x^i\)

但是这样还不能卷积。我们将 \(G(x)\) 系数翻转一下,\(ans[i]\times i!=F(x)*G(x)\)\(x^{N-i}\) 项的系数

那么最终 \(Ans=\sum_{i=0}^{N}w[i]\times ans[i]\)。用 \(NTT\) 实现多项式乘法,时间复杂度 \(O(n\log n)\)

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=100000+10;
const int maxm=10000000+10;
const int mod=1004535809;
int n,m,s,N,lim,w[maxn],f[maxn],a[maxn<<2],b[maxn<<2],r[maxn<<2],fac[maxm],inv[maxm];

inline int read(){
	register int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
	return (f==1)?x:-x;
}

int C(int n,int m){
	if(n<m) return 0;
	return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}

int fpow(int a,int b){
	int ret=1;
	for(;b;b>>=1,a=1ll*a*a%mod)
		if(b&1) ret=1ll*ret*a%mod;
	return ret;
}

void NTT(int *f,int n,int op){
	for(int i=0;i<n;i++)
		if(i<r[i]) swap(f[i],f[r[i]]);
	int buf,tmp,x,y;
	for(int len=1;len<n;len<<=1){
		tmp=fpow(3,(mod-1)/(len<<1));
		if(op==-1) tmp=fpow(tmp,mod-2);
		for(int i=0;i<n;i+=len<<1){
			buf=1;
			for(int j=0;j<len;j++){
				x=f[i+j];y=1ll*buf*f[i+j+len]%mod;
				f[i+j]=(x+y)%mod;f[i+j+len]=(x-y+mod)%mod;
				buf=1ll*buf*tmp%mod;
			}
		}
	}
	if(op==1) return ;
	int inv=fpow(n,mod-2);
	for(int i=0;i<n;i++) f[i]=1ll*f[i]*inv%mod;
}

int main()
{
	scanf("%d%d%d",&n,&m,&s);N=min(n/s,m);
	for(int i=0;i<=m;i++) scanf("%d",&w[i]);
	fac[0]=fac[1]=inv[0]=inv[1]=1;
	for(int i=2;i<=max(n,m);i++){
		fac[i]=1ll*fac[i-1]*i%mod;
		inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
	}
	for(int i=2;i<=max(n,m);i++) inv[i]=1ll*inv[i]*inv[i-1]%mod;
	for(int i=0;i<=N;i++) f[i]=1ll*C(m,i)*C(n,i*s)%mod*fac[i*s]%mod*fpow(inv[s],i)%mod*fpow(m-i,n-i*s)%mod;
	for(int i=0;i<=N;i++){
		a[i]=(((i&1)?-1:1)*inv[i]+mod)%mod;
		b[i]=1ll*f[i]*fac[i]%mod;
	}
	reverse(b,b+N+1);
	for(lim=1;lim<=(N<<1);lim<<=1);
	for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)?(lim>>1):0);
	NTT(a,lim,1);NTT(b,lim,1);
	for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,lim,-1);
	reverse(a,a+N+1);
	int ans=0;
	for(int i=0;i<=N;i++) ans=(ans+1ll*w[i]*a[i]%mod*inv[i]%mod)%mod;
	printf("%d\n",ans);
	return 0;
}
posted @ 2019-01-16 21:31  Owen_codeisking  阅读(317)  评论(0编辑  收藏  举报