[HAOI2018]染色(容斥+NTT)
补充一篇详细得不能再详细的题解,比如让我自己看懂。
可能与前面的题解有些相同,我想补充一下自己的想法。
显然,最多 \(K\) 最大为 \(N=min(\lfloor \frac nS\rfloor,m)\)
首先,我们看到出现 \(S\) 次的颜色恰好 \(K\) 种的话,我们就可以考虑容斥,将其化为出现 \(S\) 次的颜色至少 \(K\) 种的方案数 \(f[K]\)
那么先选定在 \(m\) 中颜色中选定 \(i\) 种颜色,有 \(C_m^i\) 种
选定在 \(n\) 个位置中选定 \(iS\) 个位置,有 \(C_n^{iS}\) 种
但是 \(iS\) 个位置中随机排列的话,因为颜色相同交换算一种,所以有 \(\frac {iS!}{(S!)^i}\) 种
其他位置可以乱选,剩下 \(m-i\) 中颜色,\(n-iS\) 个位置,有 \((m-i)^{n-iS}\) 种
那么乘法原理,\(f[i]=C_m^iC_n^{iS}\frac {iS!}{(S!)^i}(m-i)^{n-iS}\)
现在定义 \(ans[i]\) 为出现 \(S\) 次颜色恰好 \(K\) 种的方案数,开始容斥。首先想到容斥系数 \((-1)^{j-i}\)
然后在 \(j\) 种颜色中 \(i\) 种颜色的方案数被算了 \(C_j^i\) 次
那么可以推出:
\(ans[i]=\sum_{j=i}^{N}(-1)^{j-i}C_j^if[j]\)
\(ans[i]=\sum_{j=i}^{N}(-1)^{j-i}\frac {j!}{i!(j-i)!}f[j]\)
\(ans[i]\times i!=\sum_{j=i}^{N}(-1)^{j-i}\frac {1}{(j-i)!}f[j]\times j!\)
定义 \(F(x)=\frac {(-1)^i}{i!}\ x^i,G(x)=f[i]\times i!\ x^i\)
但是这样还不能卷积。我们将 \(G(x)\) 系数翻转一下,\(ans[i]\times i!=F(x)*G(x)\) 中 \(x^{N-i}\) 项的系数
那么最终 \(Ans=\sum_{i=0}^{N}w[i]\times ans[i]\)。用 \(NTT\) 实现多项式乘法,时间复杂度 \(O(n\log n)\)
\(Code\ Below:\)
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=100000+10;
const int maxm=10000000+10;
const int mod=1004535809;
int n,m,s,N,lim,w[maxn],f[maxn],a[maxn<<2],b[maxn<<2],r[maxn<<2],fac[maxm],inv[maxm];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
int C(int n,int m){
if(n<m) return 0;
return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int fpow(int a,int b){
int ret=1;
for(;b;b>>=1,a=1ll*a*a%mod)
if(b&1) ret=1ll*ret*a%mod;
return ret;
}
void NTT(int *f,int n,int op){
for(int i=0;i<n;i++)
if(i<r[i]) swap(f[i],f[r[i]]);
int buf,tmp,x,y;
for(int len=1;len<n;len<<=1){
tmp=fpow(3,(mod-1)/(len<<1));
if(op==-1) tmp=fpow(tmp,mod-2);
for(int i=0;i<n;i+=len<<1){
buf=1;
for(int j=0;j<len;j++){
x=f[i+j];y=1ll*buf*f[i+j+len]%mod;
f[i+j]=(x+y)%mod;f[i+j+len]=(x-y+mod)%mod;
buf=1ll*buf*tmp%mod;
}
}
}
if(op==1) return ;
int inv=fpow(n,mod-2);
for(int i=0;i<n;i++) f[i]=1ll*f[i]*inv%mod;
}
int main()
{
scanf("%d%d%d",&n,&m,&s);N=min(n/s,m);
for(int i=0;i<=m;i++) scanf("%d",&w[i]);
fac[0]=fac[1]=inv[0]=inv[1]=1;
for(int i=2;i<=max(n,m);i++){
fac[i]=1ll*fac[i-1]*i%mod;
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=2;i<=max(n,m);i++) inv[i]=1ll*inv[i]*inv[i-1]%mod;
for(int i=0;i<=N;i++) f[i]=1ll*C(m,i)*C(n,i*s)%mod*fac[i*s]%mod*fpow(inv[s],i)%mod*fpow(m-i,n-i*s)%mod;
for(int i=0;i<=N;i++){
a[i]=(((i&1)?-1:1)*inv[i]+mod)%mod;
b[i]=1ll*f[i]*fac[i]%mod;
}
reverse(b,b+N+1);
for(lim=1;lim<=(N<<1);lim<<=1);
for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)?(lim>>1):0);
NTT(a,lim,1);NTT(b,lim,1);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
NTT(a,lim,-1);
reverse(a,a+N+1);
int ans=0;
for(int i=0;i<=N;i++) ans=(ans+1ll*w[i]*a[i]%mod*inv[i]%mod)%mod;
printf("%d\n",ans);
return 0;
}