[NTT] Luogu P4245 任意模数NTT
题解
- 用三模数NTT做,有点小细节,其他都是模板了
代码
1 #include <cstdio> 2 #include <iostream> 3 #define ll long long 4 using namespace std; 5 int const N=(1<<19); 6 int n,m,lim,l,rev[N],a[5][N],b[5][N],p[5]={0,469762049,998244353,1004535809}; 7 ll mo; 8 ll pd(ll x,int mo) { while (x>=mo) x-=mo; while (x<0) x+=mo; return x; } 9 ll mul(ll a,ll b,int mo) 10 { 11 ll r=0; a%=mo,b%=mo; 12 if (a<0) a+=mo; if (b<0) b+=mo; 13 for (;b;b>>=1ll,a=(a+a)%mo) if (b&1) r=(r+a)%mo; 14 return r; 15 } 16 ll ksm(ll a,ll b,int mo) { ll r=1; for (;b;b>>=1,a=mul(a,a,mo)) if (b&1) r=mul(r,a,mo); return r; } 17 void ntt(int *a,int f,int p) 18 { 19 for (int i=0;i<lim;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); 20 for (int mid=1;mid<lim;mid<<=1) 21 { 22 int q=ksm(3,(p-1)/(mid<<1),p); 23 if (f==-1) q=ksm(q,p-2,p); 24 for (int j=0,len=(mid<<1);j<lim;j+=len) 25 { 26 int w=1; 27 for (int k=0,x,y;k<mid;k++,w=(ll)w*q%p) x=a[j+k],y=(ll)w*a[j+mid+k]%p,a[j+k]=pd(x+y,p),a[j+mid+k]=pd(x-y,p); 28 } 29 } 30 if (f==1) return; 31 int inv=ksm(lim,p-2,p); 32 for (int i=0;i<lim;i++) a[i]=(ll)a[i]*inv%p; 33 } 34 ll uni(ll r1,ll r2,ll m1,ll m2,int f,int v) 35 { 36 ll k=mul(r2-r1,v,m2); 37 if (!f) return (r1+k*m1)%(m1*m2); 38 return pd((r1+mul(k,m1,mo))%mo,mo); 39 } 40 int main() 41 { 42 freopen("data.in","r",stdin),scanf("%d%d%lld",&n,&m,&mo); 43 for (int i=0;i<=n;i++) scanf("%d",&a[1][i]),a[2][i]=a[3][i]=a[1][i]; 44 for (int i=0;i<=m;i++) scanf("%d",&b[1][i]),b[2][i]=b[3][i]=b[1][i]; 45 lim=1; while (lim<=n+m+2) lim*=2,l++; 46 for (int i=0;i<lim;i++) rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1))); 47 for (int i=1;i<=3;i++) 48 { 49 ntt(a[i],1,p[i]),ntt(b[i],1,p[i]); 50 for (int j=0;j<lim;j++) a[i][j]=(ll)a[i][j]*b[i][j]%p[i]; 51 ntt(a[i],-1,p[i]); 52 } 53 int inv1=ksm(p[1],p[2]-2,p[2]),inv2=ksm((ll)p[1]*p[2],p[3]-2,p[3]); 54 for (int i=0;i<=n+m;i++) 55 { 56 ll ans=uni(a[1][i],a[2][i],p[1],p[2],0,inv1); 57 ans=uni(ans,a[3][i],(ll)p[1]*p[2],p[3],1,inv2),printf("%lld ",ans); 58 } 59 }