Codeforces 1553I - Stairs(分治 NTT+容斥)

Codeforces 题面传送门 & 洛谷题面传送门

u1s1 感觉这道题放到 D1+D2 里作为 5250 分的 I 有点偏简单了吧

首先一件非常显然的事情是,如果我们已知了排列对应的阶梯序列,那么排列中每个极长的连续阶梯段就已经确定了,具体来说,由于显然极大的连续段之间不能相交,因此假设 \(a\)\(p\) 的阶梯序列,对于 \(a\) 数组中每个值相同的极大连续段 \([l,r]\),显然我们只能每 \(a_l\) 个元素将其划分成 \([l,l+a_l-1],[l+a_l,l+2a_l-1],\cdots\) 这样 \(\dfrac{r-l+1}{a_l}\) 个连续段,当然如果 \((r-l+1)\bmod a_l\ne 0\) 那答案肯定就是 \(0\) 了。因此这样我们考虑将所有极大连续段的长度排成一个序列 \(b_1,b_2,\cdots,b_m\),那么我们如果将每个极大段中最小的数拎出来离散化,肯定会形成一个 \(1\sim m\) 的排列,那么我们的任务就是给每个数安排上一个 \(1\sim m\) 的编号,使得这些编号组成一个 \(1\sim m\)​ 的排列,并给每个连续段钦定一个顺序(从大到小 or 从小到大),值得注意的是,如果 \(b_i=1\),那么从小到大和从大到小是相同的,只有一种钦定方式,记 \(c_i\) 为第 \(i\) 个连续段的编号,那么一个钦定方式合法当且仅当不存在以下两种情形:

  • \(\exists i\in[1,m-1],s.t.c_{i+1}=c_i+1\) 且第 \(i\) 个连续段和第 \(i+1\) 个连续段都是从小到大
  • \(\exists i\in[1,m-1],s.t.c_{i+1}=c_i-1\)​ 且第 \(i\)​ 个连续段和第 \(i+1\)​ 个连续段都是从大到小

直接做肯定不行,因此考虑容斥。我们考虑将这些连续段划分成 \(k\)​ 段,并且每一段我们钦定这一段中的数必须构成一个大的阶梯序列,那么对于一个 \(k\)​,其容斥系数自然就是 \((-1)^{m-k}\)​。那么怎么求对于一个 \(k\)​ 的答案呢?考虑 \(dp\)​,\(dp_{i,j}\)​ 表示前 \(i\)​ 个数划分成 \(j\)​ 段的方案数,转移就枚举上一段的结束位置 \(l\)​,如果 \(i-l=1\)​ 且 \(b_i=1\)​ 那么 \(dp_{i,j}\leftarrow dp_{l,j-1}\)​,否则 \(dp_{i,j}\leftarrow dp_{l,j-1}·2\)​,一目了然。最终答案 \(ans=\sum\limits_{i=1}^m(-1)^{m-i}i!·dp_{m,i}\)

这样直接做是三方的,前缀和优化一下可以做到平方,但还是不够,考虑继续优化。首先注意到最后统计答案涉及阶乘,而这个东西不像 \((-1)^i\)​​​​ 能够搞在状态里面,因此第二维“划分成多少段”不能省。那么怎么办呢?就分治 NTT 呗。注意到一段连续的连续段如果贡献为 \(2\)​​​,那么它不论怎么拼贡献还是 \(2\)​​​,因此当我们分治一段区间 \([l,r]\)​​​,我们设 \(f_{k,x,y}\)​​​ 表示有多少种划分 \([l,r]\)​​​ 中连续段的方法,满足最左边一个连续段的贡献为 \(x\)​,最右边的一个连续段的贡献为 \(y\)​,其中 \(0\)​ 表示 \(1\)​,\(1\)​ 表示 \(2\)​​​,合并左右区间时就分 \(mid,mid+1\)​ 被划分在同一段和不划分在同一段中两种情况即可。最后就按照上面的式子计算答案即可。

时间复杂度 \(n\log^2n\),不过由于自带 \(48\) 的常数所以喜提最劣解/qd

个别细节见代码吧:

const int MAXN=1e5;
const int MAXP=1<<18;
const int pr=3;
const int ipr=332748118;
const int MOD=998244353;
const int INV2=499122177;
int qpow(int x,int e){
	int ret=1;
	for(;e;e>>=1,x=1ll*x*x%MOD) if(e&1) ret=1ll*ret*x%MOD;
	return ret;
}
int rev[MAXP+5];
void NTT(vector<int> &a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i){
			for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
				int X=a[j+k],Y=1ll*w*a[(i>>1)+j+k]%MOD;
				a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
			}
		}
	} if(!~type){
		int ivn=qpow(len,MOD-2);
		for(int i=0;i<len;i++) a[i]=1ll*a[i]*ivn%MOD;
	}
}
vector<int> conv(vector<int> a,vector<int> b){
	int LEN=1;while(LEN<a.size()+b.size()) LEN<<=1;
	a.resize(LEN,0);b.resize(LEN,0);NTT(a,LEN,1);NTT(b,LEN,1);
	for(int i=0;i<LEN;i++) a[i]=1ll*a[i]*b[i]%MOD;NTT(a,LEN,-1);
	return a;
}
int n,m,a[MAXN+5],b[MAXN+5],fac[MAXN+5];
void init_fac(int n){for(int i=(fac[0]=1);i<=n;i++) fac[i]=1ll*fac[i-1]*i%MOD;}
struct dat{
	vector<int> v[2][2];
	void resize(int x){
		for(int i=0;i<2;i++) for(int j=0;j<2;j++)
			v[i][j].resize(x+1);
	}
	vector<int>* operator [](int x){return v[x];}
};
dat solve(int l,int r){
	if(r-l+1<=3){
		dat ret;ret.resize(r-l+1);
		if(r-l+1==2){
			ret[b[l]][b[r]][2]=(b[l]+1)*(b[r]+1);
			ret[1][1][1]=2;
		} else {
			ret[b[l]][b[r]][3]=(b[l]+1)*(b[r]+1)*(b[l+1]+1);
			ret[b[l]][1][2]=(b[l]+1<<1);ret[1][b[r]][2]+=(b[r]+1<<1);
			ret[1][1][1]=2;
		} return ret;
	} int mid=l+r>>1;
	dat L=solve(l,mid),R=solve(mid+1,r);
	dat ret;ret.resize(r-l+1);
	for(int i=0;i<2;i++) for(int j=0;j<2;j++){
		for(int x=0;x<2;x++) for(int y=0;y<2;y++){
			vector<int> tmp=conv(L[i][x],R[y][j]);
//			printf("%d %d %d %d\n",i,x,y,j);
			for(int t=1;t<=r-l+1;t++) ret[i][j][t]=(ret[i][j][t]+tmp[t])%MOD;
			for(int t=1;t<=r-l;t++){
				if(x+y==2) ret[i][j][t]=(ret[i][j][t]+1ll*tmp[t+1]*INV2)%MOD;
				else if(x+y==1) ret[i][j][t]=(ret[i][j][t]+tmp[t+1])%MOD;
				else ret[i][j][t]=(ret[i][j][t]+2ll*tmp[t+1])%MOD;
			}
		}
	} return ret;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int l=1,r;l<=n;l=r){
		r=l;while(a[l]==a[r]) ++r;
		if((r-l)%a[l]) return puts("0"),0;
		for(int j=1;j<=(r-l)/a[l];j++) b[++m]=(a[l]>1);
	} init_fac(m);
	if(m==1) return printf("%d\n",(b[1])?2:1),0;
	dat res=solve(1,m);int ans=0;
	for(int i=1;i<=m;i++){
		int sum=0;
		for(int j=0;j<2;j++) for(int k=0;k<2;k++) sum=(sum+res[j][k][i])%MOD;
		if((m-i)&1) ans=(ans-1ll*sum*fac[i]%MOD+MOD)%MOD;
		else ans=(ans+1ll*sum*fac[i])%MOD;
	} printf("%d\n",ans);
	return 0;
}
posted @ 2021-09-03 23:10  tzc_wk  阅读(50)  评论(0)    收藏  举报