bzoj 4555 NTT优化子集斯特林

题目大意

读入n
\(f(n)=\sum_{i=0}^n\sum_{j=0}^i\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)

分析

\(f(n)=\sum_{i=0}^n\sum_{j=0}^i\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)
因为斯特林三角中\(j>i\)时值为0,j枚举上界可以改为n
\(f(n)=\sum_{i=0}^n\sum_{j=0}^n\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)
改下求和顺序
\(f(n)=\sum_{j=0}^n2^j*j!\sum_{i=0}^n\left\{\begin{matrix}i \\ j\end{matrix}\right\}\)
关于斯特林三角形总和公式的推导见我上一篇博客

solution

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL Q=998244353;
const int N=262144;
const int M=262145;

inline int rd(){
	int x=0;bool f=1;char c=getchar();
	for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
	for(;isdigit(c);c=getchar()) x=x*10+c-48;
	return f?x:-x;
}

int n;
int rev[N];
LL g;
LL fac[M];
LL ifac[M];
LL inv[M];
LL a[N];
LL b[N];
LL c[N];

LL pwr(LL x,LL tms,LL mod){
	LL res=1;
	for(;tms>0;tms>>=1){
		if(tms&1) res=res*x%mod;
		x=x*x%mod;
	}
	return res;
}

void NTT(LL *a,int fl){
	int i,j,k;
	LL Wn,W,u,v;
	for(i=0;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(i=2;i<=N;i<<=1){
		if(fl==1) Wn=pwr(g,(Q-1)/i,Q);
		else Wn=pwr(inv[g],(Q-1)/i,Q);
		for(j=0;j<N;j+=i){
			for(W=1,k=j;k<j+i/2;k++,W=W*Wn%Q){
				u=a[k];
				v=a[k+i/2]*W%Q;
				a[k]=(u+v)%Q;
				a[k+i/2]=((u-v)%Q+Q)%Q;
			}
		}
	}
	if(fl==-1)
		for(i=0;i<N;i++) a[i]=a[i]*inv[N]%Q;
}

bool judge(LL x,LL mm){
	for(int i=2;i*i<=mm;i++)
		if((mm-1)%i==0&&pwr(x,(mm-1)/i,mm)==1) return 0;
	return 1; 
}

LL getrt(LL mm){
	if(mm==2)return 1;
	for(int i=2;;i++)
		if(judge(i,mm)) return i;
}

int main(){
	int i,kd;
	
	n=rd();
	
	for(i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(N>>1):0);
	for(inv[1]=1,i=2;i<M;i++) inv[i]=(Q-Q/i)*inv[Q%i]%Q;
	for(fac[0]=1,i=1;i<M;i++) fac[i]=fac[i-1]*i%Q;
	for(ifac[0]=1,i=1;i<M;i++) ifac[i]=ifac[i-1]*inv[i]%Q;
	
	for(i=0;i<=n;i++){
		kd=(i&1)?-1:1;
		a[i]=((kd*ifac[i])%Q+Q)%Q;
	}
	
	b[0]=1;b[1]=n+1;
	for(i=2;i<=n;i++){
		b[i]=((pwr(i,n+1,Q)-1)%Q+Q)%Q*inv[i-1]%Q*ifac[i]%Q;
	}
	
	g=getrt(Q);
	NTT(a,1);
	NTT(b,1);
	for(i=0;i<N;i++) c[i]=a[i]*b[i]%Q;
	NTT(c,-1);
	
	LL ans=0;
	for(i=0;i<=n;i++)
		ans=(ans+(pwr(2,i,Q)*fac[i]%Q*c[i]%Q))%Q;
	
	printf("%lld\n",ans);
	
	return 0;
}
posted @ 2017-02-27 08:52  _zwl  阅读(331)  评论(0编辑  收藏  举报