FFT 学习笔记
FFT 学习笔记
前置知识:
多项式的点值表达和系数表达
系数表达:最常见的形式。例子:\(x^3-2x^2+1\)。
点值表达:对于一个 \(x\) 次多项式,我们可以用 \(x+1\) 个互不相同的点表示。(详见拉格朗日插值法相关知识)
复数
首先你需要知道虚数 \(i\) 。
然后复数就是形如 \(a+bi\) 的数。
一个复数的模就是其在复平面上对应的点 \((a,b)\) 到原点的距离。
接下来是复数的四则运算:
-
加法:\((a+bi)+(c+di)=(a+c)+(b+d)i\)
-
减法:\((a+bi)-(c+di)=(a-c)+(b-d)i\)
-
乘法:\((a+bi)(c+di)=ac+adi+bci-bd=(ac-bd)+(ad+bc)i\)
-
除法:这里我们不关心。
关于复数相乘的几何意义:在复平面上,复数相乘,模长相乘,幅角相加。(幅角的始边是 \(x\) 轴正半轴,终边是原点到实数对应点构成的射线)
一些符号和约定
我们用大写字母表示多项式,例如 \(F\)。
我们用 \(F(x)\) 表示将 \(x\) 带入多项式得的值。
我们用 \(F[x]\) 表示 \(F\) 中 \(x\) 次项的系数。
FFT 是谁?
FFT 是一种在 \(O(nlogn)\) 内将一个多项式由系数表达转换为点值表达的工具。
它还有一个孪生兄弟:IFFT,也就是将一个多项式由点值表达转换为系数表达的工具。
那为什么我们要干这个事情呢?
-
假如你正在打高精度乘法,但是值域上界为 \(10^{100000}\) 时,你就不能 \(O(n^2)\) 暴力乘,需要更优秀的算法。
-
假如你很无聊,你突发奇想,想把两个多项式相乘。
-
假如你正在用 LGV 引理或者用什么神奇知识时,你可以得到一个点值表达,此时你迫切地想要求出某一项的系数。(假设你并不会拉格朗日插值法求多项式系数。)
但是 FFT 和多项式乘法的关系是什么?点值表达。
系数表达乘法是 \(O(n^2)\) 的,但是点值表达乘法是 \(O(n)\) 的。
因为假如你在做 \(F*G=W\) 时(这三都是多项式):
\(F\) 中有一个点 \((x,y_1)\) ,\(G\) 中有一个点 \((x,y_2)\) ,那么 \(W\) 中的点就应该是 \((x,y_1*y_2)\)。
有一个细节:\(W\) 的次数不一定等于 \(F\) 和 \(G\),那么它需要的点数会比这两个多,因此你需要多用几个点做乘法。
FFT 怎么实现
暴力
首先,最简单的,对于一个 \(n\) 次多项式 \(F\),你可以任意选 \(n+1\) 个数,然后暴力带入,然后你就有一个 \(O(n^2)\) 的算法,但这并不优秀。
那应该朝什么方向优化呢?选数是优化不了了,但是我们可以把暴力带入优化掉。
推广至复平面
为了进行这一步优化,我们要把带入的值的值域从实数推广至复数,并且引入一个概念:单位根。
首先,你需要画一个复平面;然后,画一个半径为 \(1\) 的圆。显然,这个圆上的每一个点对应的实数的模都是 \(1\)。这个圆叫单位圆。
然后我们定义 \(w_n^k\) 为一个 \(n\) 次单位根,当且仅当它在单位圆上且它的幅角为 \(\frac{k}{n}\) 份周角。
它有一些性质:
- 0:\(w_n^k=(w^1_n)^k\)
显然,利用复数乘法的几何意义可以得证。
- 1:\(w^k_n = w^{k-n}_n\)
显然,因为转了一周等价于没转。
- 2:\(w_n^k*w_n^j=w^{j+k}_n\)
利用第 \(0\) 个性质转化后,底数相同,指数相加。
- 3:\(w_n^k=-w^{k-n/2}_n(n \bmod 2 =0)\)
转了半周,那就是关于原点对称,也就是改变正负性。
- 4:\(w^k_n=w^{2k}_{2n}\)
根据单位根的定义,幅角不变。
其中, \(w^1_n=(cos(\frac{2k\pi}{n}),sin(\frac{2k\pi}{n}))\)。(依据单位根的几何意义,读者不难自证)
推式子
说了这么多,那单位根和 FFT 有什么关系呢?
假如我们有一个 \(n-1\) 次多项式 \(F\),我们根据次数的奇偶性把 \(n\) 项分为两份,即
我们设 \(LF\) 和 \(RF\),分别满足 \(LF[i]=F[i*2]\) 和 \(RF[i]=F[i*2+1]\),那么有:
此时将 \(w_n^k(0 \le k < n/2)\) 带入 \(x\):
再将 \(w_n^{k+n/2}(0 \le k < n/2)\) 带入 \(x\):
然后发现这是一个分治的过程,到最后都会带入的值变成 \(w^k_1=w^0_1=1\),然后再向上传递。
如果就这么实现,你就会得到递归版的 FFT,也叫做 DFT。
非递归优化
因为在递归版本中,我们大量使用了数组复制之类的操作,这常数很大。
我们来观察一下每一层的下标变化:(摘自 command_block 的blog)
原来的递归版(数组下标,先偶后奇,从0开始):
0 1 2 3 4 5 6 7 第1层
0 2 4 6|1 3 5 7 第2层
0 4|2 6|1 5|3 7 第3层
0|4|2|6|1|5|3|7 第4层
观察一下最开始的下标和最后的下标之间的联系。如果你注意力惊人,你会发现最后的下表就是开始的下表的二进制翻转。
二进制翻转是很好预处理的:设 \(tr_i\) 为 \(i\) 的二进制翻转,那么有 tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0);
。
那么只需要在开始时,交换 \(F[i]\) 和 \(F[tr_i]\) ,你就可以非递归地实现了,这就是 FFT。
关于 IFFT
直接上结论:只需要把所有 \(w_n^k\) 换成 \(w_n^{-k}\),再将结果除以 \(n\),就这是 IFFT
然OIer不需要证明
参考实现
(例题是 P1919 【模板】高精度乘法 | A*B Problem 升级版)
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define Inf (1ll<<60)
#define For(i,s,t) for(int i=s;i<=t;++i)
#define Down(i,s,t) for(int i=s;i>=t;--i)
#define ls (i<<1)
#define rs (i<<1|1)
#define bmod(x) ((x)>=mod?(x)-mod:(x))
#define lowbit(x) ((x)&(-(x)))
#define End {printf("NO\n");exit(0);}
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
inline void ckmx(int &x,int y){x=(x>y)?x:y;}
inline void ckmn(int &x,int y){x=(x<y)?x:y;}
inline void ckmx(ll &x,ll y){x=(x>y)?x:y;}
inline void ckmn(ll &x,ll y){x=(x<y)?x:y;}
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline ll min(ll x,ll y){return x<y?x:y;}
inline ll max(ll x,ll y){return x>y?x:y;}
inline int read(){
register int x=0,f=1;
char c=getchar();
while(c<'0' || '9'<c) f=(c=='-')?-1:1,c=getchar();
while('0'<=c && c<='9') x=(x<<1)+(x<<3)+c-'0',c=getchar();
return x*f;
}
void write(int x){
if(x>=10) write(x/10);
putchar(x%10+'0');
}
const int N=4e6+100;
const double Pi=acos(-1);
struct CP{
double x,y;
CP(double _x=0,double _y=0){
x=_x,y=_y;
}
CP operator +(const CP a) const{return CP(x+a.x,y+a.y);}
CP operator -(const CP a) const{return CP(x-a.x,y-a.y);}
CP operator *(const CP a) const{return CP(x*a.x-y*a.y,y*a.x+x*a.y);}
}f[N],g[N];
int sz1,sz2,n,m,tr[N],res[N];
char s1[N],s2[N];
void fft(CP *f,bool flag){//1:FFT 0:IFFT
For(i,0,n-1)
if(i<tr[i])
swap(f[i],f[tr[i]]);
for(int len=2;len<=n;len<<=1){
CP unit=CP(cos(2*Pi/len),sin(2*Pi/len));
if(!flag) unit.y*=-1;
for(int i=0;i<n;i+=len){
CP nw=CP(1,0);
for(int j=i;j<i+len/2;++j){
CP tt=nw*f[j+len/2];
//F(i)=FL(w^{k/2}_{n/2})-w^k_nFR(w^{k/2}_{n/2})
f[j+len/2]=f[j]-tt;
f[j]=f[j]+tt;
nw=nw*unit;
}
}
}
}
int main()
{
#if !ONLINE_JUDGE
freopen("test.in","r",stdin);
freopen("test.out","w",stdout);
#endif
scanf("%s%s",s1,s2);
sz1=strlen(s1),sz2=strlen(s2);
For(i,0,sz1-1) f[i].x=s1[sz1-1-i]-'0';
For(i,0,sz2-1) g[i].x=s2[sz2-1-i]-'0';
m=sz1+sz2;
n=1;
while(n<m) n<<=1;
For(i,0,n-1)
tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
fft(f,1),fft(g,1);
For(i,0,n-1) f[i]=f[i]*g[i];
fft(f,0);
For(i,0,n-1) res[i]=(f[i].x/n+0.49);
For(i,0,n-1) res[i+1]+=res[i]/10,res[i]=res[i]%10;
++n;
while(!res[n]) --n;
Down(i,n,0) printf("%d",res[i]);
return 0;
}
参考资料
大佬讲得很详细,觉得本文简陋的可以去看一下他的blog。