【模板】FFT与NTT高精度乘法模板

FFT高精度乘法 与NTT高精度乘法 搞得不甚明了,不予解释 FFT:
#include<complex>
#include<cstdio>
#include<cmath>
#include<cstring>
using namespace std;
const int maxn = 300000;
const double pi = 3.1415926535;
typedef complex<double> cd;
char s1[maxn],s2[maxn];
cd aa[maxn],bb[maxn];
int out[maxn],rev[maxn];
void getrev(int n)
{
	for(int i=0,j=0;i<n;i++)
	{
		rev[i]=j;
		for(int k=n>>1;(j^=k)<k;k>>=1);
	}
}//来自sparrow的神奇方法
void fft(cd *a,int n,int dft)
{
	for(int i=0;i<n;i++)
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int st=1;st<n;st<<=1)
	{
		cd dwfg=exp(cd(0,dft*pi/st));//单位复根
		for(int i=0;i<n;i+=(st<<1))
		{
			cd nfg=1;
			for(int j=i;j<i+st;j++)
			{
				cd x=a[j],y=a[j+st]*nfg;
				nfg*=dwfg;
				a[j]=x+y; a[j+st]=x-y;
			}
		}
	}
	if(dft==-1) for(int i=0;i<n;i++) a[i]/=n;//逆变换要/n
}
int main()
{
	scanf("%s%s",s1,s2);
	int l1=strlen(s1),l2=strlen(s2);
	for(int i=0;i<l1;i++) aa[l1-i-1]=s1[i]-'0';
	for(int j=0;j<l2;j++) bb[l2-j-1]=s2[j]-'0';
	int s=2;
	for(;s<l1+l2-1;s<<=1);
	getrev(s);
	fft(aa,s,1);
	fft(bb,s,1);
	for(int i=0;i<s;i++) aa[i]*=bb[i];
	fft(aa,s,-1);//FFT逆变换
	for(int i=0;i<s;i++)
	{
		out[i]+=(int)(aa[i].real()+0.5);
		out[i+1]+=out[i]/10;
		out[i]%=10;
	}
	int now=s;
	for(;(!out[now])&&now>=0;now--);
	if(now==-1) printf("0");
	for(int i=now;i>=0;i--) printf("%d",out[i]);
}
NTT: 这里是将FFT的单位复根e^(dft*pi/2n)改成了 g^(dft*phi(p)/n),由于我们知道g,phi(p)长度一个循环,于是-1*phi(p)/n就可以改成phi(p)-phi(p)/n。 模板考虑的是质数的情况,即phi(p)=p-1。
#include<complex>
#include<cstdio>
#include<cmath>
#include<cstring>
typedef long long ll;
using namespace std;
const ll maxn = 300000;
const ll mod = (479<<21)+1;
const ll g=3;
char s1[maxn],s2[maxn];
ll aa[maxn],bb[maxn];
ll out[maxn],rev[maxn];
void getrev(ll n)
{
	for(ll i=0,j=0;i<n;i++)
	{
		rev[i]=j;
		for(ll k=n>>1;(j^=k)<k;k>>=1);
	}
}
ll ksm(ll a,ll b)
{
	ll ans=1;
	a%=mod;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod; 
		b>>=1;
	}
	return ans;
}
void ntt(ll *a,ll n,ll dft)
{
	for(ll i=0;i<n;i++)
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(ll st=1;st<n;st<<=1)
	{
		ll dwg;
		if(dft==1) dwg = ksm(g,(mod-1)/(st<<1));
		else  dwg = ksm(g, (mod-1) - (mod-1)/(st<<1) );
		for(ll i=0;i<n;i+=(st<<1))
		{
			ll ng=1;
			for(ll j=i;j<i+st;j++)
			{
				ll x=a[j]%mod,y=a[j+st]%mod*ng%mod;
				ng=dwg*ng%mod;
				a[j]=(x+y)%mod; a[j+st]=((x-y)%mod+mod)%mod;
			}
		}
	}
        if(dft==1) return;
	ll inv = ksm(n,mod-2);
	if(dft==-1) for(ll i=0;i<n;i++) a[i]=a[i]*inv%mod;
}
int main()
{
	scanf("%s%s",s1,s2);
	ll l1=strlen(s1),l2=strlen(s2);
	for(ll i=0;i<l1;i++) aa[l1-i-1]=s1[i]-'0';
	for(ll j=0;j<l2;j++) bb[l2-j-1]=s2[j]-'0';
	ll s=2;
	for(;s<l1+l2-1;s<<=1);
	getrev(s);
	ntt(aa,s,1);
	ntt(bb,s,1);
	for(ll i=0;i<s;i++) aa[i]=aa[i]*bb[i]%mod;
	ntt(aa,s,-1);
	for(ll i=0;i<s;i++)
	{
		out[i]+=aa[i];
		out[i+1]+=out[i]/10;
		out[i]%=10;
	}
	ll now=s;
	for(;(!out[now])&&now>=0;now--);
	if(now==-1) printf("0");
	for(ll i=now;i>=0;i--) printf("%lld",out[i]);
}
 
posted @ 2018-04-18 15:23  Newuser233  阅读(17)  评论(0)    收藏  举报