学习笔记《FFT》

一点前置知识

\(e\) 是一个实数,其中 \(e^{\pi i}=-1\)

复数

指形如 \(a+bi\) 的数,其中 \(i=\sqrt{-1}\)\(a,b\) 为实数。

\[(a+bi)+(c+di)=(a+c)+(b+d)i\\ (a+bi)-(c+di)=(a-c)+(b-d)i\\ (a+bi)\times(c+di)=(ac-bd)+(ad+bc)i \]

单位根

如果有一个数 \(\omega_n\),满足 \(\omega_n^n=1\),称这个数为 \(n\) 次单位根。
根据代数基本定理,\(\omega_n\) 有且只有 \(n\) 个,而且显然各不相同。
如果一个单位根 \(x\),它的 \(1,2,\dots,n\) 次方分别为 \(n\) 个单位根,我们称 \(x\)\(n\) 次本原单位根。
\(e^{\frac {2\pi i}n}\) 是一个 \(n\) 次本原单位根。为了方便,下文中 \(\omega_n = e^{\frac {2i\pi}n}\)
单位根的一些性质(\(0 \le k < n\)

  • \(\omega_n^0=\omega_n^n=1\)
  • \(\omega_n^k=\omega_n^{k+n}=-\omega_n^{k+\frac n2}=\omega_{2n}^{2k}\)
  • \(\omega_n^k=\cos(\frac{2k\pi}{n})+\sin(\frac{2k\pi}{n})i\)
  • \(\omega_n^{-k}=\omega_n^{n-k}=\cos(\frac{2k\pi}{n})-\sin(\frac{2k\pi}{n})i\)

多项式的表示方式

系数表示法

对于一个 \(n\) 次多项式可以表示为 \(A(x)=\sum_{i=0}^{n}a_ix^i\)
注意,\(n\) 次多项式有 \(n+1\) 项。

点值表示法

对于一个集合 \(\{(x_0,A(x_0)),(x_1,A(x_1)),\dots, (x_n,A(x_n))\}\),可以确定一个 \(n\) 次多项式 \(A(x)\)
点值表示法虽然不适合给人看,但是可以加快乘法。
\(n\) 次多项式 \(G(x)\)\(m\) 次多项式 \(H(x)\)
我们设 \(F(x)=G(x)\cdot H(x)\),则 \((a, F(a)) = (a, G(a)\cdot H(a))\)
我们取 \(n+m+1\)\(a\) 就能得出 \(F(x)\)
而且 \(a\) 可以随便选,那么我们选有一定关系的 \(a\),就能加快运算了。

快速傅里叶变换

快速傅里叶变换(FFT)的大致流程:

  1. 将两个相乘的多项式转化成点值表示法。
  2. 将点值相乘得到答案的点值表示法。
  3. 还原出答案。

这里,我们称操作 \(1\) 为 DFT,操作 \(3\) 为 IDFT。
我们用分治的思想,把大问题变为小问题,将时间复杂度降到 \(O(n\log n)\)

DFT

我们有两种分割方式,DIT 和 DIF。

DIT(Decimation in Time,按时域抽取)

\(f(x)=\sum_{i=0}^{n-1}a_ix^i\),其中 \(n\)\(2\) 的整数次幂(原因等下会讲)。
我们根据 \(i\) 的奇偶性进行分组。

\[\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i\\ &=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i+1})\\ &=(\sum_{i=0}^{\frac n2-1}a_{2i}x^{2i})+x(\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{2i})\\ \end{aligned} \]

\(g(x) =\sum_{i=0}^{\frac n2-1}a_{2i}x^{i}, h(x)=\sum_{i=0}^{\frac n2-1}a_{2i+1}x^{i}\)\(f(x)=g(x^2)+x\times h(x^2)\)

\(\omega_n^k\)\(\omega_n^{k+\frac n2}\) 分别代入,由单位根的性质得:

\[\begin{aligned} f(\omega_n^k)=g(\omega_n^{2k})+\omega_n^kh(\omega_n^{2k}) &=g(\omega_{\frac n2}^{k})+\omega_n^kh(\omega_{\frac n2}^k)\\ f(\omega_n^{k+\frac n2})=g(\omega_n^{2k+n})+\omega_n^{k+\frac n2}h(\omega_n^{2k+n}) &=g(\omega_{\frac n2}^{k})-\omega_n^kh(\omega_{\frac n2}^k) \end{aligned} \]

这样只需要代入一半的单位根的幂就可以了。对于 \(h(x)\)\(g(x)\) 显然可以继续递归。
即我们每次将 \(a_i\)\(i\) 的奇偶性分组,\(\omega_n^k\) 前后分成两段。
因为每次都需要严格的将多项式分成相等长度的两部分,所以 \(n\) 必须为 \(2\) 的整次幂。

DIF(Decimation in Frequency,按频域抽取)

这次,我们对将 \(a_i\) 分成前后两段,\(\omega_n^k\) 按奇偶分组。

\[\begin{aligned} f(x)&=\sum_{i=0}^{n-1}a_ix^i\\ &=(\sum_{i=0}^{\frac n2-1}a_ix^i)+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^{i+\frac n2}\\ &=(\sum_{i=0}^{\frac n2-1}a_ix^i)+x^{\frac n2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}x^i \end{aligned} \]

\(2\mid k\),将 \(\omega_n^k\)\(\omega_n^{k+1}\) 带入。

\[\begin{aligned} f(\omega_n^k)&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\omega_n^{\frac {kn}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\ &=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{ik})+\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{ik}\\ &=\sum_{i=0}^{\frac n2-1}(a_i+a_{i+\frac n2})\omega_n^{ik}\\ f(\omega_n^{k+1})&=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})+\omega_n^{\frac {(k+1)n}2}\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\ &=(\sum_{i=0}^{\frac n2-1}a_i\omega_n^{i(k+1)})-\sum_{i=0}^{\frac n2-1}a_{i+\frac n2}\omega_n^{i(k+1)}\\ &=\sum_{i=0}^{\frac n2-1}\omega_n^i(a_i-a_{i+\frac n2})\omega_n^{ik}\\ \end{aligned} \]

我们同样只需要带入一半即可。

IDFT

考虑怎么变回去。
\(G(x)=\sum_{i=0}^{n-1}F(\omega_n^i)x^i\),其中 \(F(x)\) 为最终的答案。
结论一:对 \(G(x)\) 做 DFT,但用 \(\omega_n^{-k}\) 代替 \(\omega_n^k\),结果的每一项除以 \(n\) 后为 \(F(x)\)
结论二:对 \(G(x)\) 做 DFT,然后将后 \(n-1\) 项翻转,结果的每一项除以 \(n\) 后为 \(F(x)\)

我代码中均使用结论一。

结论一证明

\[\begin{aligned} G(\omega_n^{-k})&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ki}\\ &=\sum_{i=0}^{n-1}\omega_n^{-ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij-ki}\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i \end{aligned} \]

\(S(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i\)

\(a\equiv 0 \pmod n\) 时,\(S(\omega_n^a)=n\)
否则,我们错位相减

\[\begin{aligned} S(\omega_n^a)&=\sum_{i=0}^{n-1}(\omega_n^a)^i\\ \omega_n^aS(\omega_n^a)&=\sum_{i=1}^{n}(\omega_n^a)^i\\ S(\omega_n^a)&=\frac{(\omega_n^a)^n-(\omega_n^a)^0}{\omega_n^a-1}=0\\ \end{aligned} \]

也就是说

\[S(\omega_n^a)= \begin{cases} n,&{n\mid a}\\ 0,&{n\nmid a} \end{cases} \]

那么代回原式
\(G(\omega_n^{-k})=\sum_{j=0}^{n-1}a_jS(\omega_n^{j-k})=na_k\)

综上所述,将 \(\omega_n^k\) 换成 \(\omega_n^{-k}\)\(G(x)\) 跑一遍 DFT,然后除以 \(n\) 即可。

结论二证明

\[\begin{aligned} G(\omega_n^k)&=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{ki}\\ &=\sum_{i=0}^{n-1}\omega_n^{ki}\sum_{j=0}^{n-1}a_j\omega_n^{ij}\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{ij+ki}\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j+k})^i \end{aligned} \]

\(S(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i\)

同上得知

\[S(\omega_n^a)= \begin{cases} n,&{n\mid a}\\ 0,&{n\nmid a} \end{cases} \]

带回原式得:

\[G(\omega_n^k)=\sum_{j=0}^{n-1}a_jS(\omega_n^{j+k})= \begin{cases} na_0,&{k=0}\\ na_{n-k},&{k\neq0} \end{cases} \]

所以,直接对 \(G(x)\) 跑一遍 DFT,然后将后 \(n-1\) 位翻转,最后除以 \(n\) 即可。

实现

我们用 洛谷P3803 为示例。
先给出一个用递归的简单实现。
我们先以 DIT 为例。

代码
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}tmp[N], a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
    if(n == 1)return;
    cp *fl = f, *fr = f + n / 2;
    for(int i = 0; i < n; i++)tmp[i] = f[i];
    for(int i = 0; i < n / 2; i++)
        fl[i] = tmp[i * 2], fr[i] = tmp[i * 2 + 1];
    FFT(fl, n / 2, on), FFT(fr, n / 2, on);
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0);
    for (int i = 0; i < n / 2; i++)
        tmp[i] = fl[i] + w * fr[i],
        tmp[i + n / 2] = fl[i] - w * fr[i], w *= wn;
    for(int i = 0; i < n; i++)f[i] = tmp[i];
}
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    FFT(a, len, 1), FFT(b, len, 1);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    FFT(a, len, -1);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

每次我们都要反复复制,常数太大了。那么提前把每个元素放在最后的位置上,就好了。
我们打表可以发现,把每个位置的二进制翻转得到的数就是最后的下标。

for(int i = 0; i < n; i++)
   rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);
//递推预处理

另外 \(fl_k\) 复制在 \(f_k\)\(fr_k\) 复制在 \(f_{k+\frac n2}\),没有重叠,那么可以直接共用。
那么可以变成

递归 DIT
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
    if(n == 1)return;
    int half = n >> 1;
    FFT(f, half, on), FFT(f + half, half, on);
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
int rev[N];
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    for(int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
    for(int i = 0; i < len; i++)if(i < rev[i])swap(b[i], b[rev[i]]);
    FFT(a, len, 1), FFT(b, len, 1);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
    FFT(a, len, -1);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

同样的,我们可以改写为循环迭代。

迭代 DIT
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
int rev[N];
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
    for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
    for(int len, p = 2; len = p >> 1, p <= n; p <<= 1){//枚举区间长度 
        cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
        for(int l = 0; l < n; l += p){//枚举区间左端点 
            cp w = cp(1, 0), z;
            for(int i = l; i < l + len; i++)
                z = f[i + len] * w, f[i + len] = f[i] - z, f[i] = f[i] + z, w *= wn;
        }
    }
}
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    for(int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    FFT(a, len, -1);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

由于 DIF 是按 \(k\) 的奇偶性分组,所以我们要给答案做 rev。

递归 DIF
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
    if(n == 1)return;
    int half = n >> 1;
    cp wn(cos(Pi2 / n), on * sin(Pi2 / n)), w(1, 0), x, y;
    for (int i = 0; i < half; w *= wn, i++)
        x = f[i], y = f[i + half], f[i] = x + y, f[i + half] = (x - y) * w;
    FFT(f, half, on), FFT(f + half, half, on);
}
int rev[N];
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    for(int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
    for(int i = 0; i < len; i++)if(i < rev[i])swap(b[i], b[rev[i]]);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    FFT(a, len, -1);
    for(int i = 0; i < len; i++)if(i < rev[i])swap(a[i], a[rev[i]]);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}
迭代 DIF
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
int rev[N];
void FFT(cp*f, int n, int on){//on = 1 表示 dft,on = -1 为 idft
    for(int len, p = n; len = p >> 1, p > 1; p >>= 1){//从大到小枚举
        cp wn = cp(cos(Pi2 / p), on * sin(Pi2 / p));
        for(int l = 0; l < n; l += p){//枚举区间左端点 
            cp w = cp(1, 0), x, y;
            for(int i = l; i < l + len; w *= wn, i++)
                x = f[i], y = f[i + len], f[i] = x + y, f[i + len] = (x - y) * w;
        }
    }
    for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
}
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    for(int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    FFT(a, len, 1), FFT(b, len, 1);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    FFT(a, len, -1);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

优化

简单优化

众所周知 \((a+bi)^2 = a^2-b^2+2abi\),故我们可以将两个多项式分别放在同一多项式的实部和虚部,FFT 后再计算结果的平方,虚部的一半就是答案。

DIF 需要最后 rev,DIT 需要开始时 rev,我们用 DIF 做 DFT,DIT 做 IDFT。就可以避免用 rev。

递归的常数太大了,我们可以用模版递归,像这样。

模版递归
#include<bits/stdc++.h>
using namespace std;
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;
struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N], b[N];//实现复数类,x 代表实部,y 代表虚部。
template<const int n>void DFT(cp*f){//要分开写常数小
    const int half = n >> 1;
    DFT<half>(f), DFT<half>(f + half);
    cp wn(cos(Pi2 / n), sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template<>void DFT<1>(cp*f){}template<>void DFT<0>(cp*f){}//特化
#define Case(x, fu) case x: fu<x>(f);break;
#define Runfft(x) switch(n){\
Case(1<<1,x)Case(1<<2,x)Case(1<<3,x)Case(1<<4,x)\
Case(1<<5,x)Case(1<<6,x)Case(1<<7,x)Case(1<<8,x)\
Case(1<<9,x)Case(1<<10,x)Case(1<<11,x)Case(1<<12,x)\
Case(1<<13,x)Case(1<<14,x)Case(1<<15,x)Case(1<<16,x)\
Case(1<<17,x)Case(1<<18,x)Case(1<<19,x)Case(1<<20,x)Case(1<<21,x)}
int rev[N];
void rundft(cp *f,const int& n){
    for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
	Runfft(DFT);
}
template<const int n>void IDFT(cp*f){
    const int half = n >> 1;
    IDFT<half>(f), IDFT<half>(f + half);
    cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), z;
    for (int i = 0; i < half; i++)
        z = w * f[i + half], f[i + half] = f[i] - z, f[i] = f[i] + z, w *= wn;
}
template<>void IDFT<1>(cp*f){}template<>void IDFT<0>(cp*f){}//特化
void runidft(cp *f,const int& n){
    for(int i = 0; i < n; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
	Runfft(IDFT);
}
int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> b[i].x;
    int len = 1;
    while(len <= n + m)len <<= 1;
    for(int i = 0; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    rundft(a, len), rundft(b, len);
    for(int i = 0; i < len; i++)a[i] *= b[i];
    runidft(a, len);
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].x / len + 0.49) << " ";
}

分裂基

我们称每次一分为二为基2,相应的我们也有基4,基8。
这里只介绍 DIT 的做法,但两种代码都会给出,相信读者能想明白 DIF 的做法。
分裂基每次把基2的第 \(2\) 项用同样办法拆开,变成:

\[h(x) = h_1(x^2) + xh_2(x^2)\\ f(x) = g(x^2) + xh_1(x^4) + x^3h_2(x^4) \]

我们分别将 \(\omega_n^k,\omega_n^{k+\frac n4},\omega_n^{k+\frac n2},\omega_n^{k+\frac{3n}4}\) 带入。

\[\begin{aligned} f(\omega_n^k)&=g(\omega_n^{2k})+\omega_n^kh_1(\omega_n^{4k})+\omega_n^3kh_2(\omega_n^{4k})\\ f(\omega_n^{k+\frac n4})&=g(\omega_n^{2k+\frac n2})-(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i\\ f(\omega_n^{k+\frac n2})&=g(\omega_n^{2k})-\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h2(\omega_n^{4k})\\ f(\omega_n^{k+\frac {3n}4})&=g(\omega_n^{2k+\frac n2})+(\omega_n^kh_1(\omega_n^{4k})-\omega_n^{3k}h_2(\omega_n^{4k}))i\\ \end{aligned} \]

分裂基适用于序列长度为 \(2^n\) 的 FFT,并且是运算次数比基8更少,并且更灵活。
运算次数的比较见 ooura 的博客
由于迭代版本过于复杂,所以只给出递归版本。

代码
#include<bits/stdc++.h>
using namespace std;
#define Case(x, fu) case x: fu<x>(f);break;
#define Runfft(x) switch(n){\
Case(1<<1,x)Case(1<<2,x)Case(1<<3,x)Case(1<<4,x)\
Case(1<<5,x)Case(1<<6,x)Case(1<<7,x)Case(1<<8,x)\
Case(1<<9,x)Case(1<<10,x)Case(1<<11,x)Case(1<<12,x)\
Case(1<<13,x)Case(1<<14,x)Case(1<<15,x)Case(1<<16,x)\
Case(1<<17,x)Case(1<<18,x)Case(1<<19,x)Case(1<<20,x)Case(1<<21,x)}
const double Pi = acos(-1), Pi2 = 2 * Pi;
const int N = 4e6 + 10;

struct cp{
    double x, y;
    cp(double x = 0, double y = 0):x(x), y(y){}
    cp operator+(const cp&i)const{return cp(x + i.x, y + i.y);}
    cp operator-(const cp&i)const{return cp(x - i.x, y - i.y);}
    cp operator*(const cp&i)const{return cp(x * i.x - y * i.y, y * i.x + x * i.y);}
    cp&operator+=(const cp&i){return this->x += i.x, this->y += i.y, *this;}
    cp&operator-=(const cp&i){return this->x -= i.x, this->y -= i.y, *this;}
    cp&operator*=(const cp&i){return *this = *this * i;}
}a[N];//实现复数类,x 代表实部,y 代表虚部。

template<const int n>void DFT(cp*f){//DIF
    const int half = n >> 1, quar = n >> 2;
	cp w(1, 0), wn(cos(Pi2 / n), sin(Pi2 / n)), w3, x, y;
	cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
	for(int i = 0; i < quar; i++){
		w3 = w * w * w, x = *a1 - *a3, y = *a2 - *a4, y = cp(y.y, -y.x);
		*a1 += *a3, *a2 += *a4, *a3 = (x - y) * w, *a4 = (x + y) * w3;
		a1++, a2++, a3++, a4++, w *= wn; 
	}
	DFT<half>(f), DFT<quar>(f + half), DFT<quar>(f + half + quar);
}
template<>void DFT<2>(cp*f){cp x = f[0], y = f[1];f[0] = x + y, f[1] = x - y;}
template<>void DFT<1>(cp*f){}template<>void DFT<0>(cp*f){}//特化
void rundft(cp *f,const int& n){Runfft(DFT);}

template<const int n>void IDFT(cp*f){//DIT
    const int half = n >> 1, quar = n >> 2;
	IDFT<half>(f), IDFT<quar>(f + half), IDFT<quar>(f + half + quar);
    cp wn(cos(Pi2 / n), -sin(Pi2 / n)), w(1, 0), w3, tmp1, tmp2, x, y;
    cp *a1 = &f[0], *a2 = a1 + quar, *a3 = a2 + quar, *a4 = a3 + quar;
    for (int i = 0; i < quar; i++){
    	w3 = w * w * w, tmp1 = w * *a3, tmp2 = w3 * *a4;
    	x = tmp1 + tmp2, y = tmp1 - tmp2, y = cp(y.y, -y.x);
    	*a4 = *a2 - y, *a3 = *a1 - x, *a2 += y, *a1 += x;
    	a1++, a2++, a3++, a4++, w *= wn;
	}
}
template<>void IDFT<2>(cp*f){cp x = f[0], y = f[1];f[0] = x + y, f[1] = x - y;}
template<>void IDFT<1>(cp*f){}template<>void IDFT<0>(cp*f){}//特化 
void runidft(cp *f,const int& n){Runfft(IDFT);}

int main(){
    cin.tie(0)->sync_with_stdio(0);
    int n, m;
    cin >> n >> m;
    for(int i = 0; i <= n; i++)cin >> a[i].x;
    for(int i = 0; i <= m; i++)cin >> a[i].y;
    int len = 1;
    while(len <= n + m)len <<= 1;
    rundft(a, len);
    for(int i = 0; i < len; i++)a[i] *= a[i];
    runidft(a, len);
    double inv = 0.5 / len;
    for(int i = 0; i <= n + m; i++)cout << (long long)(a[i].y * inv + 0.49) << " ";
}

参考资料

https://oi-wiki.org/math/poly/fft/
https://www.bilibili.com/opus/785022478912061446
https://www.luogu.com.cn/article/rj58c2eb
https://charleswu.site/archives/3065

posted @ 2025-09-01 08:29  fush's_blog  阅读(15)  评论(0)    收藏  举报