Atcoder Regular Contest 134 F - Flipping Coins(构造双射+分治 FFT)

构造双射的好题。

首先考虑这个 \(k\) 怎么求。称一条极长的满足 \(i<p[i]<p[p[i]]<\cdots\) 的序列为一条“链”,通过手玩,\(k\) 就是长度为奇数的链的数量。

直接求还是不容易,我们考虑先对排列进行一些变换,对于一个排列 \(p\),我们将排列中每个置换环按最小元素为关键字从大到小排列,并且对于每个置换环而言,从其最小元素开始,按顺序依次写下这个置换环内所有元素。这样可以得到另一个排列 \(q\),容易发现 \(p\to q\) 的映射是双射。这样我们相当于数 \(q\) 中长度为奇数的极长连续上升段个数。

对于每个上升段,我们考虑 ABAB 依次给元素标号,最后落单的元素标 C,这样 \(k\) 的值就是 C 的个数。然后对 ABC 序列进行 DP,每次加入一段形如 CC...CAB 的块。容易发现,对于含 \(len\) 个 C 的块,确定块内元素的相对大小关系的方案恰好有 \(len+1\) 种:A 一定是这里面最小的元素,确定 B 的大小之后,剩余 C 递减排列。因此考虑 \(dp_i\) 表示放到位置 \(i\) 的所有可能的排列个数乘以对应的 \(W^k\) 之和,那么有转移 \(dp_i=\sum\limits_{j<i-1}dp_j·W^{i-j-2}·\dfrac{1}{(i-j)!}·(i-j-1)\)。分治 FFT 即可。

时间复杂度 \(n\log^2n\)

const int MAXN=2e5;
const int MAXP=524288;
const int pr=3;
const int ipr=332748118;
const int MOD=998244353;
int qpow(int x,int e){int ret=1;for(;e;e>>=1,x=1ll*x*x%MOD)if(e&1)ret=1ll*ret*x%MOD;return ret;}
int n,w,rev[MAXP+5],ifac[MAXN+5],fac[MAXN+5];
void init_fac(int n){
	for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++)ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i-1]*ifac[i]%MOD;
}
void NTT(vector<int>&a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++)if(rev[i]<i)swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i)for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
			int X=a[j+k],Y=1ll*a[(i>>1)+j+k]*w%MOD;
			a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
		}
	}if(type==-1){
		int iv=qpow(len,MOD-2);
		for(int i=0;i<len;i++)a[i]=1ll*a[i]*iv%MOD;
	}
}
vector<int>conv(vector<int>a,vector<int>b){
	int LEN=1;while(LEN<a.size()+b.size())LEN<<=1;a.resize(LEN,0);b.resize(LEN,0);
	NTT(a,LEN,1);NTT(b,LEN,1);for(int i=0;i<LEN;i++)a[i]=1ll*a[i]*b[i]%MOD;
	NTT(a,LEN,-1);return a;
}
int dp[MAXN+5],a[MAXN+5];
void solve(int l,int r){
	if(l==r)return;int mid=l+r>>1;solve(l,mid);
	vector<int>v1,v2,v3;for(int i=l;i<=mid;i++)v1.pb(dp[i]);
	for(int i=0;i<=r-l+1;i++)v2.pb(a[i]);v3=conv(v1,v2);
	for(int i=mid+1;i<=r;i++)dp[i]=(dp[i]+v3[i-l])%MOD;
	solve(mid+1,r);
}
int main(){
#ifdef LOCAL
	freopen("in.txt","r",stdin);
	freopen("out.txt","w",stdout);
#endif
	scanf("%d%d",&n,&w);dp[0]=1;init_fac(MAXN);
	for(int i=2;i<=n;i++)a[i]=1ll*qpow(w,i-2)*(i-1)%MOD*ifac[i]%MOD;
	// for(int i=2;i<=n;i++)printf("%d %d\n",i,a[i]);
	solve(0,n);int res=0;
	for(int i=0;i<=n;i++)res=(res+1ll*dp[i]*ifac[n-i]%MOD*fac[n]%MOD*qpow(w,n-i))%MOD;
	printf("%d\n",res);
	return 0;
}
posted @ 2022-12-22 11:07  tzc_wk  阅读(127)  评论(0)    收藏  举报