LOJ#3160. 「NOI2019」斗主地 打表+拉格朗日插值
裸做的话设一个 $p[i][j]$ 表示两个堆分别抽走 $i,j$ 个的概率.
转移的话就枚举当前是第几个,然后再枚举左/右面由下向上第几个贡献.
不在模意义下做,开 double 打表发现无论怎样洗牌,一次函数还是一次函数,二次函数还是二次函数.
那么我们只需暴力维护出牌的前 3 项,然后后面的项用拉格朗日插值求出即可.
code:
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 500009
#define ll long long
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
using namespace std;
int qpow(int x,int y) {
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) {
if(y&1) tmp=(ll)tmp*x%mod;
}
return tmp;
}
int get_inv(int x) {
return qpow(x,mod-2);
}
namespace Lagrange {
int x[5],y[5],dn[5];
void init() {
for(int i=1;i<=3;++i) {
dn[i]=1;
for(int j=1;j<=3;++j) {
if(i==j) continue;
dn[i]=(ll)(x[i]-x[j]+mod)%mod*dn[i]%mod;
}
dn[i]=get_inv(dn[i]);
}
}
int solve(int v) {
int an=0;
for(int i=1;i<=3;++i) {
int up=1;
for(int j=1;j<=3;++j) {
if(i==j) continue;
up=(ll)(v-x[j]+mod)%mod*up%mod;
}
(an+=(ll)y[i]*up%mod*dn[i]%mod)%=mod;
}
return an;
}
};
int n,m,ty;
int a[N],tmp[10000009],A[N],p[4][4],inv[10000008];
void init() {
inv[1]=1;
for(int i=2;i<10000008;++i) {
inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
}
inv[0]=1;
}
void calc(int tmp) {
memset(p,0,sizeof(p));
p[0][0]=1;
int na=tmp,nb=n-tmp;
for(int i=0;i<=min(3,na);++i) {
for(int j=0;j<=min(3,nb);++j) {
if(!i&&!j) continue;
int tot=na-i+1+nb-j;
if(i) {
(p[i][j]+=(ll)p[i-1][j]*(na-i+1)%mod*inv[tot]%mod)%=mod;
}
if(j) {
(p[i][j]+=(ll)p[i][j-1]*(nb-j+1)%mod*inv[tot]%mod)%=mod;
}
}
}
}
int main() {
// setIO("landlords");
scanf("%d%d%d",&n,&m,&ty);
for(int i=n-2;i<=n;++i) {
Lagrange::x[n-i+1]=i;
Lagrange::y[n-i+1]=(ty==1?i:(ll)i*i%mod);
}
Lagrange::init();
init();
for(int i=1;i<=m;++i) {
scanf("%d",&A[i]);
}
for(int i=1;i<=m;++i) {
calc(A[i]);
for(int j=1;j<=3;++j) {
int cur=n-j+1,na=A[i],nb=n-A[i];
tmp[cur]=0;
for(int k=1;k<=min(na,j);++k) {
if(j-k<=nb)
(tmp[cur]+=(ll)p[k-1][j-k]*(na-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(na-k+1)%mod)%=mod;
}
for(int k=1;k<=min(nb,j);++k) {
if(j-k<=na)
(tmp[cur]+=(ll)p[j-k][k-1]*(nb-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(n-k+1)%mod)%=mod;
}
}
for(int j=1;j<=3;++j) {
Lagrange::x[j]=n-j+1;
Lagrange::y[j]=tmp[n-j+1];
}
}
int Q,x,y,z;
scanf("%d",&Q);
for(int i=1;i<=Q;++i) {
scanf("%d",&x);
printf("%d\n",Lagrange::solve(x));
}
return 0;
}

浙公网安备 33010602011771号