P5825 排列计数 题解 / 二项式反演 容斥
题目传送门:P5825 排列计数。
考虑二项式反演,考虑计算钦定了 \(i\) 个上升位的答案,那么相当于分成了 \(n-k\) 个组,每组的元素均单调递增,相当于把 \(1\) 到 \(n\) 放入 \(n-k\) 个集合且可以集合为空的方案数。
设有 \(i\) 个集合的答案为 \(f_i\) 那么考虑对空集合个数容斥,假设我们钦定了 \(j\) 个空集,那么剩下的方案数显然为 \((i-j)^n\),那么 \(f_i = \sum_{j=0}^i (-1)^j {i\choose j} (i-j)^n\)。
最后答案直接二项式反演即可。
直接做是 \(O(n^2)\) 的,但我不会 poly /ll,会了,直接 ntt 即可。
#include<bits/stdc++.h>
#define int long long
#define double long double
using namespace std;
inline int read(){
char c=getchar();
int f=1,ans=0;
while(c<48||c>57) f=(c==45?f=-1:1),c=getchar();
while(c>=48&&c<=57) ans=(ans<<1)+(ans<<3)+(c^48),c=getchar();
return ans*f;
}
const int N=534288+10,mod=998244353,G=3,invG=(mod+1)/3;
int r[N],f[N],a[N],fac[N],invfac[N],n;
inline int qpow(int a,int b){int s=1;while(b) (b&1)?s=s*a%mod:1,a=a*a%mod,b>>=1;return s;}
void exgcd(int a,int b,int &x,int &y){if (b==0) x=1,y=0;else exgcd(b,a%b,y,x),y-=a/b*x;}
inline int inv(int a){int x,y;exgcd(a,mod,x,y);return (x%mod+mod)%mod;}
inline void ntt(int *a,int n,int op){
for (int i=0;i<n;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int i=2;i<=n;i<<=1){
int g1=qpow(op==1?G:invG,(mod-1)/i);
for (int j=0;j<n;j+=i){
int gk=1;
for (int k=0,x,y;k<(i>>1);k++,gk=gk*g1%mod) x=a[j+k],y=a[j+k+i/2]*gk%mod,a[j+k]=(x+y)%mod,a[j+k+i/2]=(x-y+mod)%mod;
}
}
}
inline void mul(int *a,int *b,int n,int m){
int len=1,lg=0;
while(len<=n+m) len<<=1,lg++;
for (int i=0;i<len;i++) r[i]=(r[i>>1]>>1)|((i&1)<<lg-1);
ntt(a,len,1),ntt(b,len,1);
for (int i=0;i<len;i++) a[i]=a[i]*b[i]%mod;
ntt(a,len,-1);
int tmp=inv(len);for (int i=0;i<len;i++) a[i]=a[i]*tmp%mod;
}
main(){
n=read();
fac[0]=invfac[0]=1;for (int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod,invfac[i]=invfac[i-1]*inv(i)%mod;
for (int i=0;i<=n;i++) f[i]=(i&1)?mod-invfac[i]:invfac[i],a[i]=qpow(i,n)*invfac[i]%mod;
mul(f,a,n,n);
for (int i=0;i<=n;i++) f[i]=f[i]*fac[i]%mod;for (int i=n+1;i<N;i++) f[i]=0;
for (int i=0;i<=n;i++) f[i]=f[i]*fac[n-i]%mod;
memset(a,0,sizeof(a));
for (int i=0;i<=n;i++) a[i]=((i&1)?(mod-1):1)*invfac[i]%mod;
mul(f,a,n,n);
for (int i=0;i<=n;i++) f[i]=f[i]*invfac[n-i]%mod;
for (int i=n;i>=0;i--) printf("%lld ",f[i]);
return 0;
}

浙公网安备 33010602011771号