FFT和NTT
\(\mathrm{FFT}\ \&\ \mathrm{NTT}\)
前置
多项式插值
\(n+1\)个横坐标互不相同的点可以唯一确定一个\(n\)次多项式。
证:
利用待定系数法可以得到一个\(n+1\)元一次方程组。
观察系数矩阵,可以发现其一定是满秩的。(关于矩阵的秩,见线性代数)
也就意味着消元时不可能消掉一个变量的同时消掉另一个变量,那当然就有一个唯一解。
复数
复数\(a+bi\)可看成复平面上的一个向量\((a,b)\),可观察出其运算的几何意义:
-
加法:向量加法。
-
乘法:模长相乘,辐角相加。
单位根
- 代数基本定理:任何复系数一元\(n\)次多项式方程在复数域上至少有一根。
- 推论:任何复系数一元\(n\)次多项式方程在复数域上恰好有\(n\)个根。
其中\(x^n-1=0\)的\(n\)个根分别记为\(\omega_n^1,\omega_n^2\dots\omega_n^{n-1}\),统称为\(n\)次单位根。
根据其几何意义,可知\(\omega_n^k=\cos\frac{2k\pi}{n}+i\sin\frac{2k\pi}{n}\)
并且\(\omega_n^0=1\)总是成立的。
单位根将单位圆\(n\)等分。
拓宽一下\(\omega_n^i\)中\(i\)的范围,结合几何意义,可以得到以下性质:
-
\(\omega_n^{i+n}=\omega_n^i\)
-
\(\omega_n^i\omega_n^j=\omega_n^{i+j}\)
-
\((\omega_n^i)^j=\omega_n^{ij}\)
-
\(\omega_n^i=\omega_{kn}^{ki}\)
-
\(\omega_n^{i+\frac{n}{2}}=-\omega_n^i\)
\(\mathrm{FFT}\)
我们现在要求\(H(x)=F(x)G(x)\)
根据多项式插值中的原理,我们可以将\(F(x)\)和\(G(x)\)进行多点求值,得到
\((x_1,F(x_1)),(x_2,F(x_2))\dots (x_n,F(x_n))\),以及
\((x_1,G(x_1)),(x_2,G(x_2))\dots (x_n,G(x_n))\)这两系列点。
(这一步将系数表示法转化为点值表示法。)
由\(H(x)=F(x)G(x)\),可知\(\forall_i H(x_i)=F(x_i)G(x_i)\),
所以将上文两系列点相乘即可得到\(H(x)\)的点值表示,再将点值表示转化为系数表示即可得到\(H(x)\)。(这也叫插值)
\(Question\ 1\)
普通单点求值是\(O(n)\)的,求\(n\)个点的值就\(O(n^2)\)了。
那么如何选取\(x_1,x_2\dots x_n\)使得多点求值能快速进行?
Sol:
选取\(x_i=\omega_n^i\)
由于单位根的一些性质,我们可以\(O(n\log n)\)进行多点求值。
定义操作\(\mathrm{DFT}:\{f_i\}\rightarrow \{F(\omega_n^i)\}\),即输入\(F(x)\)的系数列,得到\(F(x)\)在单位根处的点值表示。
考虑如何计算\(F(\omega_n^k)=\sum\limits_{i=0}^n f_i(\omega_n^k)^i\)。
若\(n=2^m\),我们按\(i\)的奇偶性分组:
于是可以递归,每次\(n\)减半,合并的过程是\(O(n)\)的,所以\(\mathrm{DFT}\)的复杂度为\(O(n\log n)\)。
\(Question\ 2\)
如何以优秀的时间复杂度将点值表示转化为系数表示?
Sol:
定义操作\(\mathrm{IDFT}:\{H(x_i)\}\rightarrow\{h_i\}\),即将点值表示转化为系数表示。
注意到一个重要的式子:
(即单位根反演)
证明:
\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i\)
发现这是等比数列求和的形式,讨论公比\(\omega_n^k\):
-
若\(n \mid k\),则\(\omega_n^k=1\),故\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i=n\)
-
若\(n\nmid k\),则\(\omega_n^k \ne 1\),故\(\sum\limits_{i=0}^{n-1}(\omega_n^i)^k=\sum\limits_{i=0}^{n-1}(\omega_n^k)^i=\frac{(\omega_n^k)^n-1}{w_n^k-1}=0\)
综上,原式得证。
于是又有另一个式子:\(nh_k=\sum\limits_{i=0}^{n-1}H(\omega_n^i)(\omega_n^{-k})^i\)
证明:
有了这个式子后,\(\mathrm{IDFT}\)的其余操作与\(\mathrm{DFT}\)的操作几乎相同,同样可以做到\(O(n\log n)\)。
至此,我们成功以\(O(n\log n)\)的时间复杂度解决了多项式乘法。
\(Question\ 3\)
如何代码实现?
Sol:
直接按上面的流程做,是递归版本的写法。
(摘自oiwiki)
#include <cmath>
#include <complex>
typedef std::complex<double> Comp; // STL complex
const Comp I(0, 1); // i
const int MAX_N = 1 << 20;
Comp tmp[MAX_N];
// rev=1,DFT; rev=-1,IDFT
void DFT(Comp* f, int n, int rev) {
if (n == 1) return;
for (int i = 0; i < n; ++i) tmp[i] = f[i];
// 偶数放左边,奇数放右边
for (int i = 0; i < n; ++i) {
if (i & 1)
f[n / 2 + i / 2] = tmp[i];
else
f[i / 2] = tmp[i];
}
Comp *g = f, *h = f + n / 2;
// 递归 DFT
DFT(g, n / 2, rev), DFT(h, n / 2, rev);
// cur 是当前单位复根,对于 k = 0 而言,它对应的单位复根 omega^0_n = 1。
// step 是两个单位复根的差,即满足 omega^k_n = step*omega^{k-1}*n,
// 定义等价于 exp(I*(2*M_PI/n*rev))
Comp cur(1, 0), step(cos(2 * M_PI / n), sin(2 * M_PI * rev / n));
for (int k = 0; k < n / 2;++k) {
// F(omega^k_n) = G(omega^k*{n/2}) + omega^k*n\*H(omega^k*{n/2})
tmp[k] = g[k] + cur * h[k];
// F(omega^{k+n/2}*n) = G(omega^k*{n/2}) - omega^k_n*H(omega^k\_{n/2})
tmp[k + n / 2] = g[k] - cur * h[k];
cur *= step;
}
for (int i = 0; i < n; ++i) f[i] = tmp[i];
}
当然,我们还可以进一步压缩常数,就是将递归版改成非递归版。
事实上从上往下递归的过程中将\(f\)数组的下标对应的二进制数翻转了,例如\(f_{(1011)_2}\)最后被放到了\(f_{(1101)_2}\)上。
因此我们可以非递归地完成这一步,然后向上合并即可。
交换位置的操作通过预处理可做到\(O(n)\)。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1<<22;
const double pi=acos(-1.0);
struct com{
double x,y;
com(double a=0,double b=0):x(a),y(b){}
com operator+(const com &a) const{
return com(x+a.x,y+a.y);
}
com operator-(const com &a) const{
return com(x-a.x,y-a.y);
}
com operator*(const com &a) const{
return com(x*a.x-y*a.y,x*a.y+y*a.x);
}
}a[maxn],b[maxn];
int n,m,len,rev[maxn];
void init(int len){
for(int i=0;i<len;++i){
rev[i]=rev[i>>1]>>1;
if(i&1) rev[i]|=len>>1;
}
}
void DFT(com f[],int len,int rv){
for(int i=0;i<len;++i){
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
for(int h=2;h<=len;h<<=1){
com wn(cos(2.0*pi/h),sin(2.0*rv*pi/h));
for(int i=0;i<len;i+=h){
com w(1.0,0);
for(int k=i;k<i+h/2;++k){
com u=f[k];
com v=w*f[k+h/2];
f[k]=u+v;
f[k+h/2]=u-v;
w=w*wn;
}
}
}
if(rv==-1){
for(int i=0;i<len;++i){
f[i].x/=len;
}
}
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;++i) cin>>a[i].x;
for(int i=0;i<=m;++i) cin>>b[i].x;
len=1;
while(len<=n+m) len<<=1;
init(len);
DFT(a,len,1),DFT(b,len,1);
for(int i=0;i<len;++i) a[i]=a[i]*b[i];
DFT(a,len,-1);
for(int i=0;i<=n+m;++i) printf("%d ",(int)(a[i].x+0.5));
return 0;
}
注意非递归版最适合实战。
一些小细节:
-
从代码中可以看到,向上合并的过程是倍增的。所以多项式的长度要拓展到\(2^k\)并且大于相乘的两个多项式\(F(x),G(x)\)的长度之和。
-
这样做是不会对答案有影响的,因为多出来的位上系数默认为\(0\)。最后输出时去掉多出来的位数即可。
-
\(\mathrm{DFT}\)之后别忘了\(\mathrm{IDFT}\)。
-
别混淆要处理的多项式长度\(len\)与要保留的多项式长度\(n\)。
\(\mathrm{NTT}\)
我们已知复数域中有性质良好的单位根\(\omega\)来加速\(\mathrm{DFT}\)和\(\mathrm{IDFT}\),那么我们在模质数\(p\)的有限域下能否找到这样的数?
我们发现,当\(g\)是\(p\)的一个原根时,\(\omega_n^i=g^{i\frac{\varphi(p)}{n}}\)具有同复数域中的单位根一样的性质(两个重要的式子也成立)。
但是这里要求\(n\mid \varphi(p)\),对于\(n\)有很大限制,故对于任意模数\(p\),我们不能直接将它套入\(\mathrm{FFT}\)中。
我们发现对于质数\(998244353\),\(\varphi(998244353)=998244352=119\times 2^{23}\),因此当\(n=2^k\le 2^{23}\)时,程序都能跑,因此\(998244353\)成为知名常用模数(尽管有时题目与多项式无关)。
代码如下:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4e6+5,mo=998244353;
ll n,m,a[maxn],b[maxn],rev[maxn];
void init(int len){
for(int i=1;i<=len;++i){
rev[i]=rev[i>>1]>>1;
if(i&1) rev[i]|=len>>1;
}
}
ll ksm(ll a,ll b){
ll res=1;
a%=mo;
while(b){
if(b&1) res=res*a%mo;
a=a*a%mo;
b>>=1;
}
return res;
}
void NTT(ll f[],int len,int rv){
for(int i=0;i<len;++i){
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
for(int h=2;h<=len;h<<=1){
int wn=ksm(3,(mo-1)/h);
for(int i=0;i<len;i+=h){
ll w=1;
for(int k=i;k<i+h/2;++k){
int u=f[k],v=w*f[k+h/2]%mo;
f[k]=(u+v)%mo;
f[k+h/2]=(u-v+mo)%mo;
w=w*wn%mo;
}
}
}
if(rv==-1){
reverse(f+1,f+len);
int inv=ksm(len,mo-2);
for(int i=0;i<len;++i) f[i]=f[i]*inv%mo;
}
}
int main(){
scanf("%lld%lld",&n,&m);
for(int i=0;i<=n;++i) scanf("%lld",&a[i]);
for(int i=0;i<=m;++i) scanf("%lld",&b[i]);
int len=1;
while(len<=n+m) len<<=1;
init(len);
NTT(a,len,1),NTT(b,len,1);
for(int i=0;i<len;++i) a[i]=a[i]*b[i]%mo;
NTT(a,len,-1);
for(int i=0;i<=n+m;++i) printf("%lld ",a[i]%mo);
return 0;
}

浙公网安备 33010602011771号