多项式乘法——FFT 快速傅里叶变换
log
25-2-26 更新了多项式乘法的实现(前面光写 fft 去了)
25-2-26 更新三次变二次优化 和 迭代 fft
问题引入
现在有两个多项式: \(A(x) = \sum _{i = 0}^{n - 1} a_i x^i\) 和 \(B(x) = \sum _{i = 0}^{m - 1} b_i x^i\),我们的目的是要求这两个多项式相乘后得到的多项式 \(C(x)\), \(C(x) = \sum _{i = 0}^{n + m - 2} c_i x^i\)。
初步分析
要求解 \(C(x)\),就是要求得 \(C(x)\) 每一项的系数 \(c_i\)。
若暴力的计算 \(c_i\),那么 \(c_i = \sum _{j = 0}^i a_j b_{i - j}\),总的时间复杂度是 \(O((n + m)^2)\) 的。
我们知道,直角坐标系上 \(n\) 个不同点可以唯一确定一个过这 \(n\) 个点的 \(n - 1\) 次多项式。也就是说对于 \(C(x)\) 来讲,只要确定了它经过的 \(n + m - 1\) 个不同的点的坐标,我们就有办法确定 \(C(x)\) 这个多项式。于是我们取 \(n + m - 1\) 个不同的点 \(\{ (x_0, C(x_0)), (x_1, C(x_1)), \dots , (x_{n + m - 2}, C(x_{n + m - 2})) \}\),其中 \(C(x_i) = A(x_i)B(x_i)\),即我们要在 \(A(x)\) 和 \(B(x)\) 上分别取 \(n + m - 1\) 个点。如果暴力地取点时间复杂度还是 \(O((n + m)^2)\),而且我们还不知道当取完点后如何通过这些点确定 \(C(x)\)。
若果有方法能够快速地取点(DFT)并快速地根据点来确定多项式(IDFT),那我们就解决了问题。
取点——DFT 离散傅里叶变换
此处用“取点”表示其他博文中的“将系数表达式映射到点值表达式”(因为我感觉这个过程就是在取点)
DFT 的定义
(因为光看FFT代码搞不清FFT与傅里叶变换的关系,于是有了这一部分)
离散傅里叶变换将长度为 \(n\) 的时域信号 \(x[t]\) 转换为频域信号 \(X[k]\),其数学表达式为:
其中:
- \(x[t]\) 表示时域信号的第 \(t\) 个采样点
- \(X[k]\) 表示频域信号的第 \(k\) 个频率分量
那这坨公式和我们的多项式有什么关系呢?
抛开定义中的物理意义,假设有个多项式 \(P(u) = \sum_{t = 0} ^{n - 1} x[t] u^t\)(也就是说把定义中的时域信号看作是多项式的系数),于是 \(P(e^{-i{2 \pi \over n}k}) = \sum_{t = 0} ^{n - 1} x[t] (e^{-i{2 \pi \over n}k})^t = X[k]\)。此时对多项式取点的操作就与 DFT 产生了联系:对于一个 \(n - 1\) 次多项式,我们可以对其进行 DFT 以求得这个多项式在 \(e^{-i{2 \pi \over n}0},e^{-i{2 \pi \over n}1},e^{-i{2 \pi \over n}2}, \dots ,e^{-i{2 \pi \over n}(n - 1)}\) 这 \(n\) 个位置上的值,这 \(n\) 个点其实就是单位根
但是如果按照定义直接每个点单独求值,时间复杂度还是 \(O(n^2)\) 的,并没有什么优化,于是就有 FFT。
(PS:在问题分析中我们提到要在 \(A(x)\) 和 \(B(x)\) 上分别取 \(n + m - 1\) 个点,为了解决这个问题,我们可以在两个多项式后面添项,如:\(A(x) = (\sum _{i = 0}^{n - 1} a_i x^i) + 0 \times x^n + \dots + 0 \times x^{n + m - 2}\))
FFT——DFT 的高效实现
(观看视频 BV1za411F76U 可以比较清楚地知道 FFT 的实现。)
FFT 是 DFT 数学公式的高效实现,通过分治和对称性减少计算量。
我们将原本的多项式(这里是添项过后的,而且项数 \(n\) 是 2 的幂) \(A(x) = \sum _{i = 0}^{n - 1} a_i x^i\) 的系数分成奇偶两部分,分别构成两个新的多项式 \(A_e(x) = \sum _{i = 0}^{\frac{n}{2} - 1} a_{2i} x^i\) 和 \(A_o(x) = \sum _{i = 0}^{\frac{n}{2} - 1} a_{2i + 1} x^i\)。(角标的含义是 even 和 odd)。
则 \(A(x)\) 还可以这样表示:
那么代入单位根 \(w_n^k\) 得到:
而且我们不需要从 0 到 \(n - 1\) 取枚举 \(k\),因为对于 \(k < {n \over 2}\),有 \(w_n^{k + {n \over 2}} = -w_n^k\),即:
所以我们只需要枚举前一半就好了。
对于 \(A_e(w_{n \over 2}^k)\) 和 \(A_o(w_{n \over 2}^k)\),同样用 FFT,递归的求就可以了,当项数为 1 时直接返回系数或者什么都不做就好了。
于是就可以写出递归版的 FFT (代码等说了 IDFT 再给出)
确定多项式——IDFT 离散傅里叶逆变换
在 DFT 中我们实际是解了一个这样的方程:
即求解了 \(A(w_n^i)\)
那么在 IDFT 中我们知道了 \(A(w_n^i)\) 要求 \(a_i\)。在方程两边同时左乘第一个矩阵的逆矩阵(求逆矩阵直接搜范德蒙矩阵的逆矩阵就好了,虽然有点区别,但直接类比过来就好了)得到:
这几乎是和 DFT 一样的问题,只不过 IDFT 中的单位根要和 DFT 中的单位根共轭,并且最后还要除 \(n\)。
于是就可以得到递归版的代码:
// len表示项数,并且一定是 2 的幂
void fft(int len, std::vector<std::complex<double>> &A, bool inv) {
if (len == 1) {
return;
}
int ll = len >> 1;
std::vector<std::complex<double>> Ae(ll), Ao(ll);
for (int i = 0; i + i < len; i++) {
Ae[i] = A[i << 1];
Ao[i] = A[i << 1 | 1];
}
fft(ll, Ae, inv);
fft(ll, Ao, inv);
double angle = 2.0 * M_PI / double(len);
std::complex<double> unit(cos(angle), (inv ? 1 : -1) * sin(angle)), w(1.0, 0.0);
for (int i = 0; i < ll; i++, w *= unit) {
std::complex<double> tmp = w * Ao[i];
A[i] = Ae[i] + tmp;
A[i + ll] = Ae[i] - tmp;
}
// 当使用 IDFT 时,别忘了在函数外边另外除项数。
return;
}
有了最基础递归版的 FFT,就可以将 P3803 【模板】多项式乘法(FFT) 解决了。(题解里面有挺多很好的题解的)
多项式乘法的实现
和一开始分析的思路一样,先在给定的两个多项式中取点(DFT),取完点后再根据点确定要求的多项式的系数(IDFT),于是有如下代码:
void solve()
{
// 最高次数
int n = 0, m = 0;
std::cin >> n >> m;
std::vector<double> a(n + 1), b(m + 1);
for (auto &i : a) {
std::cin >> i;
}
for (auto &i : b) {
std::cin >> i;
}
int len = 1;
while (len <= m + n) {
len <<= 1;
}
std::vector<std::complex<double>> A(len), B(len), C(len);
// 由于 fft 涉及到复数运算,这里就把实数系数转变为了复数系数,同时也完成了添项的操作
for (int i = 0; i < len; i++) {
A[i] = std::complex<double>((i <= n ? a[i] : 0.0), 0.0);
B[i] = std::complex<double>((i <= m ? b[i] : 0.0), 0.0);
}
fft(len, A, false);
fft(len, B, false);
for (int i = 0; i < len; i++) {
C[i] = A[i] * B[i];
}
fft(len, C, true);
for (int i = 0; i <= m + n; i++) {
// +0.5 就相当于是四舍五入了
std::cout << i64((C[i].real())/ double(len) + 0.5) << ' ';
}
std::cout << '\n';
return;
}
三次变二次优化
这里的次数是指在多项式乘法的过程进行了多少次 DFT,显然上面的版本是进行了三次:对多项式 \(A(x)\) 和 \(B(x)\) 各进行了一次正变换,为了得到 \(C(x)\) 的系数又将正变换的结果相乘进行逆变换。可以通过把 \(B(x)\) 的系数放在 \(A(x)\) 的虚部从而只进行一次正变换。然后对变换结果的平方进行逆变换,最后虚部的一半就是 \(C(x)\) 的系数,即:
// code ...
std::vector<std::complex<double>> A(len);
for (int i = 0; i < len; i++) {
A[i] = std::complex<double>(i <= la ? a[i] : 0.0, i <= lb ? b[i] : 0.0);
}
fft(len, A, false);
for (int i = 0; i < len; i++) {
A[i] = A[i] * A[i];
}
fft(len, A, true);
for (int i = 0; i <= la + lb; i++) {
// 这里的 / len 是逆变换的最后一步,如果在 fft 中已经除过了这里就不用再除了。
std::cout << i64(A[i].imag() / len / 2.0 + 0.5) << ' ';
}
std::cout << '\n';
// code ...
解释如下(优化的方法是没问题的,但以下解释没有严格证明,只相当于自己说服自己):
首先 \((a + bi) \times Z = aZ + bZi\) (\(i\) 是虚数单位, \(Z\) 是一个复数),也就是说对一系列复数 DFT 就相当于对这一系列复数的实部和虚部分别做 DFT 然后虚部的结果乘个虚数单位再和实部的结果加起来,反之亦然。
然后 \((a + bi) ^ 2 = (a^2 - b^2) + 2abi\),所以平方后做逆变换,虚部就是我们要的结果的两倍。
迭代 FFT
迭代优化是基于一个神奇的规律。
例如当对长度为 8 的序列进行 FFT 时,在递归版中我们要不断依据下标的奇偶将系数分为两部分。于是观察分组:
第一层: 0 1 2 3 4 5 6 7
第二层: 0 2 4 6,1 3 5 7
第三层: 0 4,2 6,1 5,3 7
第四层: 0,4,2,6,1,5,3,7
当我们将最后一层下标的二进制倒过来,我们就会发现这个规律:
000,100,010,110,001,101,011,111 ->
000,001,010,011,100,101,110,111
0 ,1 ,2 ,3 ,4 ,5 ,6 ,7
于是我们要做的就是将序列按照最后一层的下标排列,然后模拟每层的计算就好了。
那么如何获取最后一层的下标呢?
先看如下代码做了什么:
// N 是 2 的幂
void fun1(int N) {
std::vector idx(N, 0);
for (int l = 1, r = 1, bit = 1; l < N; l <<= 1, bit <<= 1) {
for (int i = 0; i < l; i++) {
idx[r++] = idx[i] | bit;
}
}
}
是的,就是 \(O(N)\) 地生成了从 \(0\) 到 \(N - 1\) 的有序序列。这么做的意义是什么呢?意义在于,只要我们稍微改一下上述代码,就可以 \(O(N)\) 地生成最后一层的下标:
// N 是 2 的幂
void fun2(int N) {
std::vector idx(N, 0);
for (int l = 1, r = 1, bit = (N >> 1); l < N; l <<= 1, bit >>= 1) {
for (int i = 0; i < l; i++) {
idx[r++] = idx[i] | bit;
}
}
}
获取了最后一层的下标后,我们进行 FFT 时第一步就是调整原序列,之后的模拟递归看代码:
// 要先处理出 idx
std::vector<int> idx(N << 2);
void fft(int len, std::vector<std::complex<double>> &A, int inv) {
// 调整原序列,使得其与递归最下层的下标相对应
for (int i = 1; i < len; i++) {
if (i < idx[i]) {
std::swap(A[i], A[idx[i]]);
}
}
for (int mid = 1; mid < len; mid <<= 1) {
std::complex<double> unit(cos(M_PI / mid), inv * sin(M_PI / mid));
for (int l = 0; l < len; l += (mid << 1)) {
std::complex<double> w(1.0, 0.0);
for (int i = 0; i < mid; i++, w *= unit) {
std::complex<double> x(A[i | l]), y(w * A[i | l | mid]);
A[i | l] = x + y;
A[i | l | mid] = x - y;
}
}
}
for (int i = 0; inv == -1 && i < len; i++) {
A[i] /= len;
}
return;
}
于是最终版的多项式乘法代码如下:
CODE
// M 多项式最大的长度
std::vector<int> idx(M << 2);
void fft(int len, std::vector<std::complex<double>> &A, int inv) {
for (int i = 1; i < len; i++) {
if (i < idx[i]) {
std::swap(A[i], A[idx[i]]);
}
}
for (int mid = 1; mid < len; mid <<= 1) {
std::complex<double> unit(cos(M_PI / mid), inv * sin(M_PI / mid));
for (int l = 0; l < len; l += (mid << 1)) {
std::complex<double> w(1.0, 0.0);
for (int i = 0; i < mid; i++, w *= unit) {
std::complex<double> x(A[i | l]), y(w * A[i | l | mid]);
A[i | l] = x + y;
A[i | l | mid] = x - y;
}
}
}
for (int i = 0; inv == -1 && i < len; i++) {
A[i] /= len;
}
return;
}
// 传两个多项式返回这两个多项式相乘(卷积)的结果
std::vector<i64> converlution(std::vector<i64> &a, std::vector<i64> &b) {
int la = a.size() - 1, lb = b.size() - 1;
int len = 1;
// 长度是 2 的幂且必须大于最高次数
while (len <= la + lb) {
len <<= 1;
}
for (int l = 1, r = 1, bit = len >> 1; l < len; l <<= 1, bit >>= 1) {
for (int i = 0; i < l; i++) {
idx[r++] = idx[i] | bit;
}
}
std::vector<std::complex<double>> A(len);
for (int i = 0; i < len; i++) {
A[i] = std::complex<double>(i <= la ? a[i] : 0.0, i <= lb ? b[i] : 0.0);
}
fft(len, A, 1);
for (int i = 0; i < len; i++) {
A[i] = A[i] * A[i];
}
fft(len, A, -1);
std::vector res(la + lb + 1, 0ll);
for (int i = 0; i <= la + lb; i++) {
res[i] = i64(A[i].imag() / 2.0 + 0.5);
}
return res;
}
浙公网安备 33010602011771号