codechef JIIT

考虑如何计算操作后的奇数个数。
假设在行操作了\(i\),列操作\(j\)次。
由补集转化,操作后奇数个数\(=im+jn-ij\)
\(f_i\)表示为行操作\(i\)次的答案,\(g_i\)表示列操作\(i\)次的答案,则答案就是符合要求的所有\(f_i*g_j\)
列出答案的EGF。
由于每行是相同的,强制让选择的\(i\)行在网格的前\(i\)
\(f_i=({e^x-e^{-x}\over 2})^i*({e^x+e^{-x}\over 2})^{n-i}[x^q]*q!*C_n^i\)
后面的\(C_n^i\)表示选择的方案数。
\(f_i=2^n(e^x-e^{-x})^i*(e^x+e^{-x})^{n-i}[x^q]*q!*C_n^i\)
如果使用二项式定理展开\((e^x-e^{-x})\)\((e^x+e^{-x})\),再暴力进行多项式乘法,则生成一个\(2n\)次关于\(e^x\)的多项式(指数可能是负的),时间复杂度\(O(n^3)\)
但是注意到\(f_{i+1}\)的后面的项等于\(f_i\)的多项式乘以\((e^x-e^{-x})\)再除以\((e^x+e^{-x})\),所以可以在\(O(n)\)的时间内得到下面的多项式,时间复杂度\(O(n^2)\)
\(g\)可以同理计算。
这样子已经可以通过本题,然而我们还有更为优秀的做法。
考虑容斥。(CTS2019 珍珠)
\(h_i\)表示钦定至少\(i\)行为为奇数,其它任意。
列出答案的EGF。
\(h_i=\sum C_i^j*g_j\)
\(h_i=({e^x-e^{-x}\over 2})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}({e^x-e^{-x}})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}\frac{1}{e^{xi}}({e^{2x}-1})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}({e^{2x}-1})^i*e^{x(n-2i)}[x^q]*q!*C_n^i\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}e^{2j}(-1)^{i-j}e^{x(n-2i)}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2i+2j)}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2(i-j))}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}i!\sum_{j=0}^{i}(-1)^{i-j}((n-2(i-j))^q\frac{1}{(i-j)!j!}\)
\(a_i=(-1)^i((n-2i))^q\frac{1}{i!},b_i=\frac{1}{i!}\)
\(a*b=h\)
考虑二项式反演,\(g_i=\sum_{j\geq i}C_{j}^i(-1)^{j-i}h_j=\frac{1}{i!}\sum_{j\geq i}\frac{j!}{(j-i)!}(-1)^{j-i}h_j\)
\(a_i=(-1)^i\frac{1}{i!},b_i=i!h_i\)
\(a,b\)的减法卷积就是\(g\)
计算答案考虑\(im+jn-ij\leq k\)
\(j(n-2i)\leq k-im\)
根据\((n-2i)\)的正负性分类讨论,使用前缀和计算。
时间复杂度\(O(n\log_2n)\)
细节:
在卷积的时候注意把vector数组resize,以防后面的项错误产生贡献

#include<bits/stdc++.h>
using namespace std;
#define mo 998244353
#define N 500010
#define ll unsigned long long
#define int long long
#define pl vector<int>
int qp(int x,int y){
	int r=1;
	for(;y;y>>=1,x=1ll*x*x%mo)
		if(y&1)r=1ll*r*x%mo;
	return r;
}
int rev[N],v,le,w[N],p[N],ans[N];
void deb(pl x){
	for(int i:x)cout<<i<<' ';
	puts("");
}
void init(int n){
	v=1;
	le=0;
	while(v<n)le++,v*=2;
	for(signed i=0;i<v;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(le-1));
	int g=qp(3,(mo-1)/v);
	w[v/2]=1;
	for(int i=v/2+1;i<v;i++)
		w[i]=1ull*w[i-1]*g%mo;
	for(signed i=v/2-1;~i;i--)
		w[i]=w[i*2];
}
void fft(int v,pl &a,int t){
	static unsigned long long b[N];
	int s=le-__builtin_ctz(v);
   	for(int i=0;i<v;i++)
   		b[rev[i]>>s]=a[i];
	int c=0;
	w[0]=1;
    for(signed i=1;i<v;i*=2,c++)
    	for(signed r=i*2,j=0;j<v;j+=r)
            for(signed k=0;k<i;k++){
               	int tx=b[j+i+k]*w[k+i]%mo;
            	b[j+i+k]=b[j+k]+mo-tx;
            	b[j+k]+=tx;
            }
    for(int i=0;i<v;i++)
    	a[i]=b[i]%mo;
    if(t==0)return;
    int iv=qp(v,mo-2);
    for(signed i=0;i<v;i++)
    	a[i]=1ull*a[i]*iv%mo;
    a.resize(v);
    reverse(a.begin()+1,a.end());
}
pl operator *(pl x,pl y){
	int s=x.size()+y.size()-1;
	if(x.size()<=30||y.size()<=30){
		pl r;
		r.resize(s);
		for(int i=0;i<x.size();i++)
			for(int j=0;j<y.size();j++)
				r[i+j]=(r[i+j]+x[i]*y[j])%mo;
		return r;
	}
	init(s);
	x.resize(v);
	y.resize(v);
	fft(v,x,0);
	fft(v,y,0);
	//deb(x);
	//deb(y);
	for(int i=0;i<v;i++)
		x[i]=x[i]*y[i]%mo;
	fft(v,x,1);
	x.resize(s);
	return x;
}
void ad(pl &x,pl y,int l){
	x.resize(max((int)x.size(),(int)y.size()+l));
	for(int i=0;i<y.size();i++)
		x[i+l]=(x[i+l]+y[i])%mo;
}
pl operator +(pl x,pl y){
	ad(x,y,0);
	return x;
}
int f[N],g[N],n,m,q,k,jc[N],ij[N],h[N],s[N];
int c(int y,int x){
	if(y<0||x<0||y<x)
		return 0;
	return jc[y]*ij[x]%mo*ij[y-x]%mo;
}
void cal(int *f,int l){
	pl x,y;
	x.resize(l+1);
	y.resize(l+1);
	for(int i=0;i<=l;i++){
		x[i]=qp(mo-1,i)*qp(l-2*i,q)%mo*ij[i]%mo;
		y[i]=ij[i];
	}
	x=x*y;
	for(int i=0;i<=l;i++)
		h[i]=x[i]*qp(qp(2,mo-2),i)%mo*c(l,i)%mo*jc[i]%mo;
	for(int i=0;i<=l;i++)
		x[i]=jc[i]*h[i]%mo;
	x.resize(l+1);
	for(int i=0;i<=l;i++)
		y[l-i]=qp(mo-1,i)*ij[i]%mo;
	x=x*y;
	for(int i=0;i<=l;i++)
		f[i]=x[i+l]*ij[i]%mo;
}
signed main(){
	int T;
	jc[0]=1;
	for(int i=1;i<N;i++)
		jc[i]=jc[i-1]*i%mo;
	ij[N-1]=qp(jc[N-1],mo-2);
	for(int i=N-1;i;i--)
		ij[i-1]=ij[i]*i%mo;
	scanf("%lld",&T);
	while(T--){
		memset(f,0,sizeof(f));
		memset(g,0,sizeof(g));
		scanf("%lld%lld%lld%lld",&n,&m,&q,&k);
		cal(f,n);
		cal(g,m);
		int va=0;
		s[0]=g[0];
		for(int i=1;i<=m;i++)
			s[i]=(s[i-1]+g[i])%mo;
		for(int i=0;i<=n;i++){
			int p=n-2*i;
			if(!p){
				if(k-i*m>=0)
					va=(va+f[i]*s[m])%mo;
			}
			if(p<0){
				int v=ceil((long double)(k-i*m)/(long double)(n-2*i));
				if(v<=0){
					va=(va+s[m]*f[i]%mo)%mo;
				}
				else{
					va=(va+(s[m]-s[v-1]+mo)%mo*f[i]%mo)%mo;
				}
			}
			if(p>0){
				int v=floor((long double)(k-i*m)/(long double)(n-2*i));
				if(v>=0)
					va=(va+f[i]*s[min(v,m)])%mo;
			}
		}
		printf("%lld\n",va);
	}
}
posted @ 2020-12-16 07:49  celerity1  阅读(102)  评论(0)    收藏  举报