codechef JIIT
考虑如何计算操作后的奇数个数。
假设在行操作了\(i\),列操作\(j\)次。
由补集转化,操作后奇数个数\(=im+jn-ij\)
令\(f_i\)表示为行操作\(i\)次的答案,\(g_i\)表示列操作\(i\)次的答案,则答案就是符合要求的所有\(f_i*g_j\)。
列出答案的EGF。
由于每行是相同的,强制让选择的\(i\)行在网格的前\(i\)行
\(f_i=({e^x-e^{-x}\over 2})^i*({e^x+e^{-x}\over 2})^{n-i}[x^q]*q!*C_n^i\)。
后面的\(C_n^i\)表示选择的方案数。
\(f_i=2^n(e^x-e^{-x})^i*(e^x+e^{-x})^{n-i}[x^q]*q!*C_n^i\)
如果使用二项式定理展开\((e^x-e^{-x})\)和\((e^x+e^{-x})\),再暴力进行多项式乘法,则生成一个\(2n\)次关于\(e^x\)的多项式(指数可能是负的),时间复杂度\(O(n^3)\)
但是注意到\(f_{i+1}\)的后面的项等于\(f_i\)的多项式乘以\((e^x-e^{-x})\)再除以\((e^x+e^{-x})\),所以可以在\(O(n)\)的时间内得到下面的多项式,时间复杂度\(O(n^2)\)。
\(g\)可以同理计算。
这样子已经可以通过本题,然而我们还有更为优秀的做法。
考虑容斥。(CTS2019 珍珠)
令\(h_i\)表示钦定至少\(i\)行为为奇数,其它任意。
列出答案的EGF。
则\(h_i=\sum C_i^j*g_j\)。
\(h_i=({e^x-e^{-x}\over 2})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}({e^x-e^{-x}})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}\frac{1}{e^{xi}}({e^{2x}-1})^i*e^{x(n-i)}[x^q]*q!*C_n^i\)
\(h_i=2^{-i}({e^{2x}-1})^i*e^{x(n-2i)}[x^q]*q!*C_n^i\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}e^{2j}(-1)^{i-j}e^{x(n-2i)}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2i+2j)}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}\sum_{j=0}^{i}(-1)^{i-j}e^{x(n-2(i-j))}[x^q]*q!C_{i}^j\)
\(h_i=C_n^i2^{-i}i!\sum_{j=0}^{i}(-1)^{i-j}((n-2(i-j))^q\frac{1}{(i-j)!j!}\)
令\(a_i=(-1)^i((n-2i))^q\frac{1}{i!},b_i=\frac{1}{i!}\)
则\(a*b=h\)
考虑二项式反演,\(g_i=\sum_{j\geq i}C_{j}^i(-1)^{j-i}h_j=\frac{1}{i!}\sum_{j\geq i}\frac{j!}{(j-i)!}(-1)^{j-i}h_j\)
令\(a_i=(-1)^i\frac{1}{i!},b_i=i!h_i\)
则\(a,b\)的减法卷积就是\(g\)。
计算答案考虑\(im+jn-ij\leq k\)
则\(j(n-2i)\leq k-im\)
根据\((n-2i)\)的正负性分类讨论,使用前缀和计算。
时间复杂度\(O(n\log_2n)\)
细节:
在卷积的时候注意把vector数组resize,以防后面的项错误产生贡献
#include<bits/stdc++.h>
using namespace std;
#define mo 998244353
#define N 500010
#define ll unsigned long long
#define int long long
#define pl vector<int>
int qp(int x,int y){
int r=1;
for(;y;y>>=1,x=1ll*x*x%mo)
if(y&1)r=1ll*r*x%mo;
return r;
}
int rev[N],v,le,w[N],p[N],ans[N];
void deb(pl x){
for(int i:x)cout<<i<<' ';
puts("");
}
void init(int n){
v=1;
le=0;
while(v<n)le++,v*=2;
for(signed i=0;i<v;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(le-1));
int g=qp(3,(mo-1)/v);
w[v/2]=1;
for(int i=v/2+1;i<v;i++)
w[i]=1ull*w[i-1]*g%mo;
for(signed i=v/2-1;~i;i--)
w[i]=w[i*2];
}
void fft(int v,pl &a,int t){
static unsigned long long b[N];
int s=le-__builtin_ctz(v);
for(int i=0;i<v;i++)
b[rev[i]>>s]=a[i];
int c=0;
w[0]=1;
for(signed i=1;i<v;i*=2,c++)
for(signed r=i*2,j=0;j<v;j+=r)
for(signed k=0;k<i;k++){
int tx=b[j+i+k]*w[k+i]%mo;
b[j+i+k]=b[j+k]+mo-tx;
b[j+k]+=tx;
}
for(int i=0;i<v;i++)
a[i]=b[i]%mo;
if(t==0)return;
int iv=qp(v,mo-2);
for(signed i=0;i<v;i++)
a[i]=1ull*a[i]*iv%mo;
a.resize(v);
reverse(a.begin()+1,a.end());
}
pl operator *(pl x,pl y){
int s=x.size()+y.size()-1;
if(x.size()<=30||y.size()<=30){
pl r;
r.resize(s);
for(int i=0;i<x.size();i++)
for(int j=0;j<y.size();j++)
r[i+j]=(r[i+j]+x[i]*y[j])%mo;
return r;
}
init(s);
x.resize(v);
y.resize(v);
fft(v,x,0);
fft(v,y,0);
//deb(x);
//deb(y);
for(int i=0;i<v;i++)
x[i]=x[i]*y[i]%mo;
fft(v,x,1);
x.resize(s);
return x;
}
void ad(pl &x,pl y,int l){
x.resize(max((int)x.size(),(int)y.size()+l));
for(int i=0;i<y.size();i++)
x[i+l]=(x[i+l]+y[i])%mo;
}
pl operator +(pl x,pl y){
ad(x,y,0);
return x;
}
int f[N],g[N],n,m,q,k,jc[N],ij[N],h[N],s[N];
int c(int y,int x){
if(y<0||x<0||y<x)
return 0;
return jc[y]*ij[x]%mo*ij[y-x]%mo;
}
void cal(int *f,int l){
pl x,y;
x.resize(l+1);
y.resize(l+1);
for(int i=0;i<=l;i++){
x[i]=qp(mo-1,i)*qp(l-2*i,q)%mo*ij[i]%mo;
y[i]=ij[i];
}
x=x*y;
for(int i=0;i<=l;i++)
h[i]=x[i]*qp(qp(2,mo-2),i)%mo*c(l,i)%mo*jc[i]%mo;
for(int i=0;i<=l;i++)
x[i]=jc[i]*h[i]%mo;
x.resize(l+1);
for(int i=0;i<=l;i++)
y[l-i]=qp(mo-1,i)*ij[i]%mo;
x=x*y;
for(int i=0;i<=l;i++)
f[i]=x[i+l]*ij[i]%mo;
}
signed main(){
int T;
jc[0]=1;
for(int i=1;i<N;i++)
jc[i]=jc[i-1]*i%mo;
ij[N-1]=qp(jc[N-1],mo-2);
for(int i=N-1;i;i--)
ij[i-1]=ij[i]*i%mo;
scanf("%lld",&T);
while(T--){
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
scanf("%lld%lld%lld%lld",&n,&m,&q,&k);
cal(f,n);
cal(g,m);
int va=0;
s[0]=g[0];
for(int i=1;i<=m;i++)
s[i]=(s[i-1]+g[i])%mo;
for(int i=0;i<=n;i++){
int p=n-2*i;
if(!p){
if(k-i*m>=0)
va=(va+f[i]*s[m])%mo;
}
if(p<0){
int v=ceil((long double)(k-i*m)/(long double)(n-2*i));
if(v<=0){
va=(va+s[m]*f[i]%mo)%mo;
}
else{
va=(va+(s[m]-s[v-1]+mo)%mo*f[i]%mo)%mo;
}
}
if(p>0){
int v=floor((long double)(k-i*m)/(long double)(n-2*i));
if(v>=0)
va=(va+f[i]*s[min(v,m)])%mo;
}
}
printf("%lld\n",va);
}
}

浙公网安备 33010602011771号