Atcoder Regular Contest 133 F - Random Transition(FWT+FFT)

格局打开了,FWT 居然还可以这么玩。

首先模型转化,每次以 \(\dfrac{x}{n}\) 的概率令 \(x\)\(1\)\(1-\dfrac{x}{n}\) 的概率令 \(x\)\(1\) 可以视作,有 \(n\) 个硬币摆成一排,其中 \(x\) 为正面向上的硬币个数,每次随机翻转一个硬币。

考虑将硬币朝上和朝下的状态视作一个二进制数并以此构建集合幂级数。考虑集合幂级数 \(F,G\)\([x^S]F(x)=a_{|S|},[x^S]G(x)=\dfrac{1}{n}[|S|=1]\),集合幂级数的乘法定义为异或卷积,那么我们要求的东西即为 \(F\times G^k\)

乍一看状态数是 \(2^n\) 的,不过发现此题中所有集合幂级数都满足一个性质:所有 popcount 相同的位的系数是相同的。因此可以用 \(O(n)\) 的信息存储一个集合幂级数。

接下来考虑如何进行异或卷积。异或卷积常用处理方法是 FWTxor 后转化为点值相乘,考虑如何 FWTxor。首先对于朴素的集合幂级数而言,有

\[\text{FWT}(F)_S=\sum\limits_{T}F_T(-1)^{|S\&T|} \]

考虑对于题目中的集合幂级数怎样 FWT。枚举 \(|T|\),以及有多少位满足 \(S,T\) 这一位都是 \(1\),以及有多少位满足 \(S\) 这一位是 \(1\)\(T\) 这一位是 \(0\),则可以列出以下式子:

\[\text{FWT}(F)_S=\sum\limits_{i=0}^nf_i\sum\limits_{x=0}^{i}\sum\limits_{y=0}^{n-i}\dbinom{i}{x}(-1)^x\dbinom{n-i}{y}[x+y=|S|] \]

其中 \(f_i\) 表示所有 \(|S|=i\)\(F_S\)(根据前面的推论它们都相等)

写成生成函数的形式就是

\[f'_k=[x^k]\sum\limits_{i=0}^nf_i(1-x)^i(1+x)^{n-i} \]

因此我们只用求出这个多项式的系数即可。

二项式定理把后面那一项拆一拆:

\[\sum\limits_{i=0}^nf_i(1-x)^i(1-x+2)^{n-i} \]

\[\sum\limits_{i=0}^nf_i\sum\limits_{j=0}^{n-i}(1-x)^{n-j}2^j\dbinom{n-i}{j} \]

\[\sum\limits_{i+j\le n}f_i(1-x)^{n-j}2^j\dfrac{(n-i)!}{(n-i-j)!j!} \]

首先考虑求出 \(g_j=\sum_{x-y=j}f_x\dfrac{(n-x)!}{(n-y)!}\),这样多项式可以写成

\[\sum\limits_{j=0}^ng_j2^j\dfrac{1}{j!}(1-x)^{n-j} \]

分治 FFT 即可。

知道了怎么 FWT,原问题也就迎刃而解。

代码中部分式子可能和上面推的不太一样。

const int MAXN=1e5;
const int MAXP=262144;
const int pr=3;
const int ipr=332748118;
const int MOD=998244353;
const int V=1e9;
const int INV2=MOD+1>>1;
int n,k,fac[MAXN+5],ifac[MAXN+5];
void init_fac(int n){
	for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++)ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i-1]*ifac[i]%MOD;
}
int binom(int n,int k){if(n<0||k<0||n<k)return 0;return 1ll*fac[n]*ifac[k]%MOD*ifac[n-k]%MOD;}
int qpow(int x,int e){int ret=1;for(;e;e>>=1,x=1ll*x*x%MOD)if(e&1)ret=1ll*ret*x%MOD;return ret;}
int rev[MAXP+5];
void FFT(vector<int>&a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i)for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
			int X=a[j+k],Y=1ll*a[(i>>1)+j+k]*w%MOD;
			a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
		}
	}if(type<0){
		int iv=qpow(len,MOD-2);
		for(int i=0;i<len;i++)a[i]=1ll*a[i]*iv%MOD;
	}
}
vector<int>conv(vector<int>a,vector<int>b){
	int LEN=1,lim=a.size()+b.size()-1;while(LEN<a.size()+b.size())LEN<<=1;
	a.resize(LEN,0);b.resize(LEN,0);FFT(a,LEN,1);FFT(b,LEN,1);
	for(int i=0;i<LEN;i++)a[i]=1ll*a[i]*b[i]%MOD;FFT(a,LEN,-1);
	while(a.size()>lim)a.ppb();return a;
}
pair<vector<int>,vector<int> > calc(vector<int>&h,int l,int r){
	if(l==r){return mp(vector<int>{h[l]},vector<int>{MOD-1,1});}
	int mid=l+r>>1;auto L=calc(h,l,mid),R=calc(h,mid+1,r);
	vector<int>res1=conv(R.fi,L.se);
	for(int i=0;i<L.fi.size();i++)res1[i]=(res1[i]+L.fi[i])%MOD;
	return mp(res1,conv(L.se,R.se));
}
vector<int>FWT(vector<int>a){
	vector<int>f(n+1),g(n+1),h;
	for(int i=0;i<=n;i++)f[i]=1ll*a[i]*fac[n-i]%MOD*qpow(MOD-1,i)%MOD,g[i]=ifac[i];
	h=conv(f,g);while(h.size()>n+1)h.ppb();
	for(int i=0;i<=n;i++)h[i]=1ll*h[i]*ifac[n-i]%MOD*qpow(2,n-i)%MOD;
	pair<vector<int>,vector<int> > res=calc(h,0,n);
	// for(int i=0;i<=n;i++)printf("%d%c",125ll*h[i]%MOD," \n"[i==n]);
	// for(int i=0;i<=n;i++)printf("%d%c",res.fi[i]," \n"[i==n]);
	return res.fi;
	// vector<int>b(n+1);
	// for(int c=0;c<=n;c++)for(int i=0;i<=n;i++)for(int j=0;j<=min(i,c);j++)
	// 	b[c]=(b[c]+1ll*binom(i,j)*qpow(MOD-1,j)%MOD*binom(n-i,c-j)%MOD*a[i])%MOD;
	// for(int i=0;i<=n;i++)printf("%d%c",b[i]," \n"[i==n]);
	// return b;
}
int main(){
#ifdef LOCAL
	freopen("in.txt","r",stdin);
	freopen("out.txt","w",stdout);
#endif
	init_fac(MAXN);scanf("%d%d",&n,&k);vector<int>f(n+1),g(n+1),h(n+1);
	for(int i=0;i<=n;i++)scanf("%d",&f[i]),f[i]=1ll*f[i]*qpow(V,MOD-2)%MOD;f=FWT(f);g[1]=1;g=FWT(g);
	for(int i=0;i<=n;i++)h[i]=1ll*f[i]*qpow(1ll*g[i]*ifac[n]%MOD*fac[i]%MOD*fac[n-i]%MOD,k)%MOD;
	h=FWT(h);for(int i=0;i<=n;i++)printf("%d%c",1ll*h[i]*qpow(INV2,n)%MOD," \n"[i==n]);
	return 0;
}
posted @ 2022-12-21 11:29  tzc_wk  阅读(132)  评论(0)    收藏  举报