FFT和NTT

\(\mathrm{FFT}\ \&\ \mathrm{NTT}\)

前置

多项式插值

\(n+1\)个横坐标互不相同的点可以唯一确定一个\(n\)次多项式。

证:

利用待定系数法可以得到一个\(n+1\)元一次方程组。

观察系数矩阵,可以发现其一定是满秩的。(关于矩阵的秩,见线性代数

也就意味着消元时不可能消掉一个变量的同时消掉另一个变量,那当然就有一个唯一解。

复数

复数\(a+bi\)可看成复平面上的一个向量\((a,b)\),可观察出其运算的几何意义:

  • 加法:向量加法。

  • 乘法:模长相乘,辐角相加。

单位根

  • 代数基本定理:任何复系数一元\(n\)次多项式方程在复数域上至少有一根。
  • 推论:任何复系数一元\(n\)次多项式方程在复数域上恰好有\(n\)个根。

其中\(x^n-1=0\)\(n\)个根分别记为\(\omega_n^1,\omega_n^2\dots\omega_n^{n-1}\),统称为\(n\)次单位根。

根据其几何意义,可知\(\omega_n^k=\cos\frac{2k\pi}{n}+i\sin\frac{2k\pi}{n}\)

并且\(\omega_n^0=1\)总是成立的。

单位根将单位圆\(n\)等分。

拓宽一下\(\omega_n^i\)\(i\)的范围,结合几何意义,可以得到以下性质:

  • \(\omega_n^{i+n}=\omega_n^i\)

  • \(\omega_n^i\omega_n^j=\omega_n^{i+j}\)

  • \((\omega_n^i)^j=\omega_n^{ij}\)

  • \(\omega_n^i=\omega_{kn}^{ki}\)

  • \(\omega_n^{i+\frac{n}{2}}=-\omega_n^i\)

\(\mathrm{FFT}\)

我们现在要求\(H(x)=F(x)G(x)\)

根据多项式插值中的原理,我们可以将\(F(x)\)\(G(x)\)进行多点求值,得到

\((x_1,F(x_1)),(x_2,F(x_2))\dots (x_n,F(x_n))\),以及

\((x_1,G(x_1)),(x_2,G(x_2))\dots (x_n,G(x_n))\)这两系列点。

(这一步将系数表示法转化为点值表示法。)

\(H(x)=F(x)G(x)\),可知\(\forall_i H(x_i)=F(x_i)G(x_i)\)

所以将上文两系列点相乘即可得到\(H(x)\)的点值表示,再将点值表示转化为系数表示即可得到\(H(x)\)。(这也叫插值)

\(Question\ 1\)

普通单点求值是\(O(n)\)的,求\(n\)个点的值就\(O(n^2)\)了。

那么如何选取\(x_1,x_2\dots x_n\)使得多点求值能快速进行?

Sol:

选取\(x_i=\omega_n^i\)

由于单位根的一些性质,我们可以\(O(n\log n)\)进行多点求值。

定义操作\(\mathrm{DFT}:\{f_i\}\rightarrow \{F(\omega_n^i)\}\),即输入\(F(x)\)的系数列,得到\(F(x)\)在单位根处的点值表示。

考虑如何计算\(F(\omega_n^k)=\sum\limits_{i=0}^n f_i(\omega_n^k)^i\)

\(n=2^m\),我们按\(i\)的奇偶性分组:

\[\begin{aligned} F(\omega_n^k)&=\sum\limits_{i=0}^n f_i(\omega_n^k)^i\\ &=\sum\limits_{i=0}^{\frac{n}{2}} f_{2i}(\omega_n^k)^{2i}+\sum\limits_{i=0}^{\frac{n}{2}} f_{2i+1}(\omega_n^k)^{2i+1}\\ &=\sum\limits_{i=0}^{\frac{n}{2}} f_{2i}(\omega_n^k)^{2i}+\omega_n^k\sum\limits_{i=0}^{\frac{n}{2}} f_{2i+1}(\omega_n^k)^{2i}\\ &=\sum\limits_{i=0}^{\frac{n}{2}} f_{2i}(\omega_{\frac{n}{2}}^k)^i+\omega_n^k\sum\limits_{i=0}^{\frac{n}{2}} f_{2i+1}(\omega_{\frac{n}{2}}^k)^i\\ &=\mathrm{DFT}(\{f_{2i}\})_k+\omega_n^k\mathrm{DFT}(\{f_{2i+1}\})_k \end{aligned} \]

于是可以递归,每次\(n\)减半,合并的过程是\(O(n)\)的,所以\(\mathrm{DFT}\)的复杂度为\(O(n\log n)\)

\(Question\ 2\)

如何以优秀的时间复杂度将点值表示转化为系数表示?

Sol:

定义操作\(\mathrm{IDFT}:\{H(x_i)\}\rightarrow\{h_i\}\),即将点值表示转化为系数表示。

注意到一个重要的式子:

\[\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=n[n\mid k] \]

(即单位根反演)

证明:

\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i\)

发现这是等比数列求和的形式,讨论公比\(\omega_n^k\)

  • \(n \mid k\),则\(\omega_n^k=1\),故\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i=n\)

  • \(n\nmid k\),则\(\omega_n^k \ne 1\),故\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i=\frac{(\omega_n^k)^n-1}{w_n^k-1}=0\)

综上,原式得证。

于是又有另一个式子:\(nh_k=\sum\limits_{i=0}^{n-1}H(\omega_n^i)(\omega_n^{-k})^i\)

证明:

\[\begin{aligned} nh_k&=\sum\limits_{i=0}^{n-1}H(\omega_n^i)(\omega_n^{-k})^i\\ &=\sum\limits_{i=0}^{n-1}(\omega_n^{-k})^i\sum\limits_{j=0}^{n-1}h_j(\omega_n^i)^j\\ &=\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{n-1}(\omega_n^{-k})^i(\omega_n^i)^jh_j\\ &=\sum\limits_{j=0}^{n-1}h_j\sum\limits_{i=0}^{n-1}(\omega_n^i)^{j-k}\\ &=\sum\limits_{j=0}^{n-1}nh_j[n \mid j-k]\\ &=nh_k \end{aligned} \]

有了这个式子后,\(\mathrm{IDFT}\)的其余操作与\(\mathrm{DFT}\)的操作几乎相同,同样可以做到\(O(n\log n)\)

至此,我们成功以\(O(n\log n)\)的时间复杂度解决了多项式乘法。

\(Question\ 3\)

如何代码实现?

Sol:

直接按上面的流程做,是递归版本的写法。

(摘自oiwiki

#include <cmath>
#include <complex>

typedef std::complex<double> Comp;  // STL complex

const Comp I(0, 1);  // i
const int MAX_N = 1 << 20;

Comp tmp[MAX_N];

// rev=1,DFT; rev=-1,IDFT
void DFT(Comp* f, int n, int rev) {
  if (n == 1) return;
  for (int i = 0; i < n; ++i) tmp[i] = f[i];
  // 偶数放左边,奇数放右边
  for (int i = 0; i < n; ++i) {
    if (i & 1)
      f[n / 2 + i / 2] = tmp[i];
    else
      f[i / 2] = tmp[i];
  }
  Comp *g = f, *h = f + n / 2;
  // 递归 DFT
  DFT(g, n / 2, rev), DFT(h, n / 2, rev);
  // cur 是当前单位复根,对于 k = 0 而言,它对应的单位复根 omega^0_n = 1。
  // step 是两个单位复根的差,即满足 omega^k_n = step*omega^{k-1}*n,
  // 定义等价于 exp(I*(2*M_PI/n*rev))
  Comp cur(1, 0), step(cos(2 * M_PI / n), sin(2 * M_PI * rev / n));
  for (int k = 0; k < n / 2;++k) {  
    // F(omega^k_n) = G(omega^k*{n/2}) + omega^k*n\*H(omega^k*{n/2})
    tmp[k] = g[k] + cur * h[k];
    // F(omega^{k+n/2}*n) = G(omega^k*{n/2}) - omega^k_n*H(omega^k\_{n/2})
    tmp[k + n / 2] = g[k] - cur * h[k];
    cur *= step;
  }
  for (int i = 0; i < n; ++i) f[i] = tmp[i];
}

当然,我们还可以进一步压缩常数,就是将递归版改成非递归版。

事实上从上往下递归的过程中将\(f\)数组的下标对应的二进制数翻转了,例如\(f_{(1011)_2}\)最后被放到了\(f_{(1101)_2}\)上。

因此我们可以非递归地完成这一步,然后向上合并即可。

交换位置的操作通过预处理可做到\(O(n)\)

代码如下:

#include<bits/stdc++.h>

using namespace std;

const int maxn=1<<22;
const double pi=acos(-1.0);
struct com{
	double x,y;
	
	com(double a=0,double b=0):x(a),y(b){}
	
	com operator+(const com &a) const{
		return com(x+a.x,y+a.y);
	}
	
	com operator-(const com &a) const{
		return com(x-a.x,y-a.y);
	}
	
	com operator*(const com &a) const{
		return com(x*a.x-y*a.y,x*a.y+y*a.x);
	}
}a[maxn],b[maxn];
int n,m,len,rev[maxn];

void init(int len){
	for(int i=0;i<len;++i){
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=len>>1;
	}
}

void DFT(com f[],int len,int rv){
	for(int i=0;i<len;++i){
		if(i<rev[i]) swap(f[i],f[rev[i]]);
	}
	for(int h=2;h<=len;h<<=1){
		com wn(cos(2.0*pi/h),sin(2.0*rv*pi/h));
		for(int i=0;i<len;i+=h){
			com w(1.0,0);
			for(int k=i;k<i+h/2;++k){
				com u=f[k];
				com v=w*f[k+h/2];
				f[k]=u+v;
				f[k+h/2]=u-v;
				w=w*wn;
			}
		}
	}
	if(rv==-1){
		for(int i=0;i<len;++i){
			f[i].x/=len;
		}
	}
}

int main(){
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;++i) cin>>a[i].x;
	for(int i=0;i<=m;++i) cin>>b[i].x;
	len=1;
	while(len<=n+m) len<<=1;
	init(len);
	DFT(a,len,1),DFT(b,len,1);
	for(int i=0;i<len;++i) a[i]=a[i]*b[i];
	DFT(a,len,-1);
	for(int i=0;i<=n+m;++i) printf("%d ",(int)(a[i].x+0.5));
	return 0;
} 

注意非递归版最适合实战。

一些小细节:

  • 从代码中可以看到,向上合并的过程是倍增的。所以多项式的长度要拓展到\(2^k\)并且大于相乘的两个多项式\(F(x),G(x)\)的长度之和。

  • 这样做是不会对答案有影响的,因为多出来的位上系数默认为\(0\)。最后输出时去掉多出来的位数即可。

  • \(\mathrm{DFT}\)之后别忘了\(\mathrm{IDFT}\)

  • 别混淆要处理的多项式长度\(len\)与要保留的多项式长度\(n\)

\(\mathrm{NTT}\)

我们已知复数域中有性质良好的单位根\(\omega\)来加速\(\mathrm{DFT}\)\(\mathrm{IDFT}\),那么我们在模质数\(p\)的有限域下能否找到这样的数?

我们发现,当\(g\)\(p\)的一个原根时,\(\omega_n^i=g^{i\frac{\varphi(p)}{n}}\)具有同复数域中的单位根一样的性质(两个重要的式子也成立)。

但是这里要求\(n\mid \varphi(p)\),对于\(n\)有很大限制,故对于任意模数\(p\),我们不能直接将它套入\(\mathrm{FFT}\)中。

我们发现对于质数\(998244353\)\(\varphi(998244353)=998244352=119\times 2^{23}\),因此当\(n=2^k\le 2^{23}\)时,程序都能跑,因此\(998244353\)成为知名常用模数(尽管有时题目与多项式无关)。

代码如下:

#include<bits/stdc++.h>

using namespace std;
typedef long long ll;

const int maxn=4e6+5,mo=998244353;
ll n,m,a[maxn],b[maxn],rev[maxn];

void init(int len){
	for(int i=1;i<=len;++i){
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=len>>1;
	}
}

ll ksm(ll a,ll b){
	ll res=1;
	a%=mo;
	while(b){
		if(b&1) res=res*a%mo;
		a=a*a%mo;
		b>>=1;
	}
	return res;
}

void NTT(ll f[],int len,int rv){
	for(int i=0;i<len;++i){
		if(i<rev[i]) swap(f[i],f[rev[i]]);
	}
	for(int h=2;h<=len;h<<=1){
		int wn=ksm(3,(mo-1)/h);
		for(int i=0;i<len;i+=h){
			ll w=1;
			for(int k=i;k<i+h/2;++k){
				int u=f[k],v=w*f[k+h/2]%mo;
				f[k]=(u+v)%mo;
				f[k+h/2]=(u-v+mo)%mo;
				w=w*wn%mo;
			}
		}
	}
	if(rv==-1){
		reverse(f+1,f+len);
		int inv=ksm(len,mo-2);
		for(int i=0;i<len;++i) f[i]=f[i]*inv%mo;
	}
}

int main(){
	scanf("%lld%lld",&n,&m);
	for(int i=0;i<=n;++i) scanf("%lld",&a[i]);
	for(int i=0;i<=m;++i) scanf("%lld",&b[i]);
	int len=1;
	while(len<=n+m) len<<=1;
	init(len);
	NTT(a,len,1),NTT(b,len,1);
	for(int i=0;i<len;++i) a[i]=a[i]*b[i]%mo;
	NTT(a,len,-1);
	for(int i=0;i<=n+m;++i) printf("%lld ",a[i]%mo);
	
	return 0;
}
posted @ 2025-03-26 09:56  RandomShuffle  阅读(49)  评论(0)    收藏  举报