[AGC021F] Trinity

若一个格子左、上、下都有黑格子,那么该格子是否为黑色是不影响最后的三元组的,因此只用统计这样的格子为白色的情况,这样就能考虑到所有三元组了。

考虑按列 \(DP\),设 \(f(i,j)\) 表示考虑前 \(i\) 列,已经有 \(j\) 行至少有一个黑色格子的行的方案数,最终答案为 \(\sum \binom{n}{i}f(m,i)\)。转移就是每次新增一列,考虑新增的有黑色格子的行数 \(k\),即 \(f(i,j)\) 转移到 \(f(i+1,j+k)\)

\(k=0\),那么这一列对三元组的贡献就只有这一列的最小行标和最大行标,相当于从已有的 \(j\) 行中选出不超过 \(2\) 个,贡献为 \(1+j+\binom{j}{2}\)

\(k>1\),考虑原有的行是否在这新加的一列中放黑格子,因为黑格子左、上、下一定有一个方向没有黑格子,所以若是原有的行放黑格子,最多有两行放,且必须是在最小行标或者最大行标的位置。讨论一下最小行标和最大行标是来自新加入的行还是来自原有的行,得贡献为 \(\binom{j+k}{k}+2\binom{j+k}{k+1}+\binom{j+k}{k+2}=\binom{j+k+2}{k+2}\),这个组合数也可以直接用组合意义来说明。

发现转移是卷积的形式,用 \(NTT\) 优化后可以做到 \(O(nm\log n)\)

#include<bits/stdc++.h>
#define maxn 64010
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,m,lim=1,inv,ans;
int rev[maxn];
ll f[maxn],g[maxn],h[maxn],fac[maxn],ifac[maxn];
ll qp(ll x,ll y)
{
	ll v=1;
	while(y)
	{
		if(y&1) v=v*x%p;
		x=x*x%p,y>>=1;
	}
	return v;
}
void NTT(ll *a,int type)
{
	for(int i=0;i<lim;++i)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(int len=1;len<lim;len<<=1)
	{
		ll wn=qp(3,(p-1)/(len<<1));
		for(int i=0;i<lim;i+=len<<1)
		{
			ll w=1;
			for(int j=i;j<i+len;++j,w=w*wn%p)
			{
				ll x=a[j],y=w*a[j+len]%p;
				a[j]=(x+y)%p,a[j+len]=(x-y+p)%p;
			}
		}
	}
	if(type==1) return;
	for(int i=0;i<lim;++i) a[i]=a[i]*inv%p;
	reverse(a+1,a+lim);
}
void init()
{
	fac[0]=ifac[0]=f[0]=1;
	for(int i=1;i<=n+2;++i) fac[i]=fac[i-1]*i%p;
	ifac[n+2]=qp(fac[n+2],p-2);
	for(int i=n+1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
	for(int i=1;i<=n;++i) h[i]=ifac[i+2];
	while(lim<=(n<<1)) lim<<=1;
	for(int i=0;i<lim;++i) rev[i]=(rev[i>>1]>>1)|((i&1)?lim>>1:0);
	inv=qp(lim,p-2),NTT(h,1);
}
int main()
{
	read(n),read(m),init();
	while(m--)
	{
		for(int i=0;i<=n;++i) g[i]=f[i]*ifac[i]%p,f[i]=f[i]*(((ll)i*i+i+2)/2%p)%p;
		for(int i=n+1;i<lim;++i) g[i]=0;
		NTT(g,1);
		for(int i=0;i<lim;++i) g[i]=g[i]*h[i]%p;
		NTT(g,-1);
		for(int i=1;i<=n;++i) f[i]=(f[i]+g[i]*fac[i+2]%p)%p;
	}
	for(int i=0;i<=n;++i) ans=(ans+f[i]*fac[n]%p*ifac[i]%p*ifac[n-i]%p)%p;
	printf("%d",ans);
    return 0;
}
posted @ 2021-03-27 15:21  lhm_liu  阅读(116)  评论(0编辑  收藏  举报