NTT 学习笔记

前言

ntt和fft一样,都是用来处理卷积的,但用处不一样

fft因为浮点数的性质,系数的大小没有限制,但是会丢失精度

ntt是通过整数运算在剩余系下计算卷积,卷积后的系数不能超过整形的范围,但是速度较快,而且不掉精

如果系数不大,一般用ntt

如果系数大大,且不能取模,则用fft

理论

原根:一个数g为p的原根,当且仅当$g^{ \phi (p)} \equiv 1 (mod \ p)$

发现原根也满足fft中单位根的四条性质,所以可以用原根代替单位根

我们设$G^i_n = g^{\frac{p-1}{n}*i}$

可以将$w^i_n$ 替换为$G^i_n$

ntt要满足$p=g*2^x+1$,这样$(p-1)/n$才是整数

然后就是fft的过程了

代码

#include <iostream>
#include <cstdio>
#include <cmath>
#define N 4000001
using namespace std;
#define mod 998244353
#define int long long
int lim,rev[N],len;
int inv[N],a[N],b[N];
int read()
{
    char c=getchar();
	int x=0,f=1;
    while(c<'0'||c>'9')
	{
		if(c=='-')f=-1;c=getchar();
	}
    while(c>='0'&&c<='9')
	{
		x=x*10+c-'0';c=getchar();
	}
    return x*f;
}
int qpow(int base,int index)
{
	int ans=1;
	while(index)
	{
		if(index&1) ans*=base,ans%=mod;
		base*=base,base%=mod;
		index>>=1;
	}
	return ans;
}
void ntt(int arr[],int gen)
{
	for(int i=0;i<lim;i++) if(rev[i]>i) swap(arr[i],arr[rev[i]]);
	for(int i=1;i<lim;i*=2)//枚举区间长度的一般(方便合并) 
	{
		int val=qpow(gen,(mod-1)/(i<<1));//相当于计算G(1,i*2) 即相邻根之间的增量 
		for(int j=0;j<lim;j+=(i<<1))//枚举每个区间 
		{
			int val2=1; //每个区间的根要从头开始代入 
			for(int k=0;k<i;k++,val2*=val,val2%=mod)//计算
			{
				int t=arr[j+k],t2=val2*arr[j+k+i]%mod;//蝴蝶变换 
				arr[j+k]=(t+t2)%mod;
				arr[j+k+i]=(t-t2+mod)%mod;
			}
		}
	}
}
signed main()
{
	int n,m;
	cin>>n>>m;
	for(int i=0;i<=n;i++) a[i]=read();
	for(int i=0;i<=m;i++) b[i]=read();
	lim=1;
	while(lim<=n+m) len++,lim<<=1;
	for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
	ntt(a,3);
	ntt(b,3);
	for(int i=0;i<lim;i++) a[i]=a[i]*b[i]%mod;
	ntt(a,qpow(3,mod-2));//idft带入的是单位根的逆元,这里也相应地带入3的逆元 
	for(int i=0;i<=n+m;i++) printf("%lld ", a[i]*qpow(lim,mod-2)%mod);//idft最后要除以项数 
}

  

posted @ 2020-09-05 08:58  linzhuohang  阅读(402)  评论(0编辑  收藏  举报