【模板】FFT与NTT高精度乘法模板
FFT高精度乘法 与NTT高精度乘法 搞得不甚明了,不予解释
FFT:
#include<complex> #include<cstdio> #include<cmath> #include<cstring> using namespace std; const int maxn = 300000; const double pi = 3.1415926535; typedef complex<double> cd; char s1[maxn],s2[maxn]; cd aa[maxn],bb[maxn]; int out[maxn],rev[maxn]; void getrev(int n) { for(int i=0,j=0;i<n;i++) { rev[i]=j; for(int k=n>>1;(j^=k)<k;k>>=1); } }//来自sparrow的神奇方法 void fft(cd *a,int n,int dft) { for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]); for(int st=1;st<n;st<<=1) { cd dwfg=exp(cd(0,dft*pi/st));//单位复根 for(int i=0;i<n;i+=(st<<1)) { cd nfg=1; for(int j=i;j<i+st;j++) { cd x=a[j],y=a[j+st]*nfg; nfg*=dwfg; a[j]=x+y; a[j+st]=x-y; } } } if(dft==-1) for(int i=0;i<n;i++) a[i]/=n;//逆变换要/n } int main() { scanf("%s%s",s1,s2); int l1=strlen(s1),l2=strlen(s2); for(int i=0;i<l1;i++) aa[l1-i-1]=s1[i]-'0'; for(int j=0;j<l2;j++) bb[l2-j-1]=s2[j]-'0'; int s=2; for(;s<l1+l2-1;s<<=1); getrev(s); fft(aa,s,1); fft(bb,s,1); for(int i=0;i<s;i++) aa[i]*=bb[i]; fft(aa,s,-1);//FFT逆变换 for(int i=0;i<s;i++) { out[i]+=(int)(aa[i].real()+0.5); out[i+1]+=out[i]/10; out[i]%=10; } int now=s; for(;(!out[now])&&now>=0;now--); if(now==-1) printf("0"); for(int i=now;i>=0;i--) printf("%d",out[i]); }NTT: 这里是将FFT的单位复根e^(dft*pi/2n)改成了 g^(dft*phi(p)/n),由于我们知道g,phi(p)长度一个循环,于是-1*phi(p)/n就可以改成phi(p)-phi(p)/n。 模板考虑的是质数的情况,即phi(p)=p-1。
#include<complex> #include<cstdio> #include<cmath> #include<cstring> typedef long long ll; using namespace std; const ll maxn = 300000; const ll mod = (479<<21)+1; const ll g=3; char s1[maxn],s2[maxn]; ll aa[maxn],bb[maxn]; ll out[maxn],rev[maxn]; void getrev(ll n) { for(ll i=0,j=0;i<n;i++) { rev[i]=j; for(ll k=n>>1;(j^=k)<k;k>>=1); } } ll ksm(ll a,ll b) { ll ans=1; a%=mod; while(b) { if(b&1) ans=ans*a%mod; a=a*a%mod; b>>=1; } return ans; } void ntt(ll *a,ll n,ll dft) { for(ll i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]); for(ll st=1;st<n;st<<=1) { ll dwg; if(dft==1) dwg = ksm(g,(mod-1)/(st<<1)); else dwg = ksm(g, (mod-1) - (mod-1)/(st<<1) ); for(ll i=0;i<n;i+=(st<<1)) { ll ng=1; for(ll j=i;j<i+st;j++) { ll x=a[j]%mod,y=a[j+st]%mod*ng%mod; ng=dwg*ng%mod; a[j]=(x+y)%mod; a[j+st]=((x-y)%mod+mod)%mod; } } } if(dft==1) return; ll inv = ksm(n,mod-2); if(dft==-1) for(ll i=0;i<n;i++) a[i]=a[i]*inv%mod; } int main() { scanf("%s%s",s1,s2); ll l1=strlen(s1),l2=strlen(s2); for(ll i=0;i<l1;i++) aa[l1-i-1]=s1[i]-'0'; for(ll j=0;j<l2;j++) bb[l2-j-1]=s2[j]-'0'; ll s=2; for(;s<l1+l2-1;s<<=1); getrev(s); ntt(aa,s,1); ntt(bb,s,1); for(ll i=0;i<s;i++) aa[i]=aa[i]*bb[i]%mod; ntt(aa,s,-1); for(ll i=0;i<s;i++) { out[i]+=aa[i]; out[i+1]+=out[i]/10; out[i]%=10; } ll now=s; for(;(!out[now])&&now>=0;now--); if(now==-1) printf("0"); for(ll i=now;i>=0;i--) printf("%lld",out[i]); }