LOJ#3160. 「NOI2019」斗主地 打表+拉格朗日插值

裸做的话设一个 $p[i][j]$ 表示两个堆分别抽走 $i,j$ 个的概率.   

转移的话就枚举当前是第几个,然后再枚举左/右面由下向上第几个贡献.      

不在模意义下做,开 double 打表发现无论怎样洗牌,一次函数还是一次函数,二次函数还是二次函数.   

那么我们只需暴力维护出牌的前 3 项,然后后面的项用拉格朗日插值求出即可.  

code: 

#include <cstdio> 
#include <cstring>
#include <algorithm>    
#define N 500009  
#define ll long long  
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)  
using namespace std;    
int qpow(int x,int y) {    
	int tmp=1; 
	for(;y;y>>=1,x=(ll)x*x%mod) { 	
		if(y&1) tmp=(ll)tmp*x%mod;  
	} 
	return tmp;   
} 
int get_inv(int x) { 
	return qpow(x,mod-2); 
}
namespace Lagrange {  
	int x[5],y[5],dn[5];  
	void init() {     	
		for(int i=1;i<=3;++i) {  	  	
			dn[i]=1;  
			for(int j=1;j<=3;++j) { 
				if(i==j) continue;  
				dn[i]=(ll)(x[i]-x[j]+mod)%mod*dn[i]%mod;   
			}   
			dn[i]=get_inv(dn[i]);  
		}
	}
	int solve(int v) {  	
		int an=0; 
		for(int i=1;i<=3;++i) {   
			int up=1;  
			for(int j=1;j<=3;++j) { 	
				if(i==j) continue;  
				up=(ll)(v-x[j]+mod)%mod*up%mod;        
			}       
			(an+=(ll)y[i]*up%mod*dn[i]%mod)%=mod;  
		}  
		return an;  
	} 
}; 
int n,m,ty;  
int a[N],tmp[10000009],A[N],p[4][4],inv[10000008];        
void init() { 
	inv[1]=1; 
	for(int i=2;i<10000008;++i) { 
		inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 
	}  
	inv[0]=1;  
}
void calc(int tmp) {  
	memset(p,0,sizeof(p));      
	p[0][0]=1;    	
	int na=tmp,nb=n-tmp;         
	for(int i=0;i<=min(3,na);++i) {               
		for(int j=0;j<=min(3,nb);++j) {             
			if(!i&&!j) continue;     
			int tot=na-i+1+nb-j;        
			if(i) { 	          
				(p[i][j]+=(ll)p[i-1][j]*(na-i+1)%mod*inv[tot]%mod)%=mod;  
			}  
			if(j) { 
				(p[i][j]+=(ll)p[i][j-1]*(nb-j+1)%mod*inv[tot]%mod)%=mod;  
			}
		}
	}
}                    
int main() {  
	// setIO("landlords");                           
	scanf("%d%d%d",&n,&m,&ty);                                   
	for(int i=n-2;i<=n;++i) { 
		Lagrange::x[n-i+1]=i; 
		Lagrange::y[n-i+1]=(ty==1?i:(ll)i*i%mod);  
	}
	Lagrange::init();  
 	init();  
	for(int i=1;i<=m;++i) { 
		scanf("%d",&A[i]); 
	}               
	for(int i=1;i<=m;++i) { 
		calc(A[i]);
		for(int j=1;j<=3;++j) {  
			int cur=n-j+1,na=A[i],nb=n-A[i];     
			tmp[cur]=0;       
			for(int k=1;k<=min(na,j);++k) {                    
				if(j-k<=nb) 
					(tmp[cur]+=(ll)p[k-1][j-k]*(na-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(na-k+1)%mod)%=mod; 
			}
			for(int k=1;k<=min(nb,j);++k) { 	
				if(j-k<=na)              
					(tmp[cur]+=(ll)p[j-k][k-1]*(nb-k+1)%mod*inv[n-j+1]%mod*Lagrange::solve(n-k+1)%mod)%=mod;        
			}
		}
		for(int j=1;j<=3;++j) { 
			Lagrange::x[j]=n-j+1;   
			Lagrange::y[j]=tmp[n-j+1];       
		}
	}            
	int Q,x,y,z; 
	scanf("%d",&Q);  
	for(int i=1;i<=Q;++i) { 
		scanf("%d",&x); 
		printf("%d\n",Lagrange::solve(x));    
	}    
	return 0; 
} 

  

posted @ 2020-08-03 10:07  EM-LGH  阅读(139)  评论(0编辑  收藏  举报