多项式乘法——FFT 快速傅里叶变换

log

25-2-26 更新了多项式乘法的实现(前面光写 fft 去了)

25-2-26 更新三次变二次优化 和 迭代 fft

问题引入

现在有两个多项式: \(A(x) = \sum _{i = 0}^{n - 1} a_i x^i\)\(B(x) = \sum _{i = 0}^{m - 1} b_i x^i\),我们的目的是要求这两个多项式相乘后得到的多项式 \(C(x)\)\(C(x) = \sum _{i = 0}^{n + m - 2} c_i x^i\)

初步分析

要求解 \(C(x)\),就是要求得 \(C(x)\) 每一项的系数 \(c_i\)

若暴力的计算 \(c_i\),那么 \(c_i = \sum _{j = 0}^i a_j b_{i - j}\),总的时间复杂度是 \(O((n + m)^2)\) 的。

我们知道,直角坐标系上 \(n\) 个不同点可以唯一确定一个过这 \(n\) 个点的 \(n - 1\) 次多项式。也就是说对于 \(C(x)\) 来讲,只要确定了它经过的 \(n + m - 1\) 个不同的点的坐标,我们就有办法确定 \(C(x)\) 这个多项式。于是我们取 \(n + m - 1\) 个不同的点 \(\{ (x_0, C(x_0)), (x_1, C(x_1)), \dots , (x_{n + m - 2}, C(x_{n + m - 2})) \}\),其中 \(C(x_i) = A(x_i)B(x_i)\),即我们要在 \(A(x)\)\(B(x)\) 上分别取 \(n + m - 1\) 个点。如果暴力地取点时间复杂度还是 \(O((n + m)^2)\),而且我们还不知道当取完点后如何通过这些点确定 \(C(x)\)

若果有方法能够快速地取点(DFT)并快速地根据点来确定多项式(IDFT),那我们就解决了问题。

取点——DFT 离散傅里叶变换

此处用“取点”表示其他博文中的“将系数表达式映射到点值表达式”(因为我感觉这个过程就是在取点)

DFT 的定义

因为光看FFT代码搞不清FFT与傅里叶变换的关系,于是有了这一部分

离散傅里叶变换将长度为 \(n\) 的时域信号 \(x[t]\) 转换为频域信号 \(X[k]\),其数学表达式为:

\[X[k] = \sum_{t = 0} ^{n - 1} x[t] e^{-i{2 \pi \over n}kt} \]

其中:

  • \(x[t]\) 表示时域信号的第 \(t\) 个采样点
  • \(X[k]\) 表示频域信号的第 \(k\) 个频率分量

那这坨公式和我们的多项式有什么关系呢?

抛开定义中的物理意义,假设有个多项式 \(P(u) = \sum_{t = 0} ^{n - 1} x[t] u^t\)(也就是说把定义中的时域信号看作是多项式的系数),于是 \(P(e^{-i{2 \pi \over n}k}) = \sum_{t = 0} ^{n - 1} x[t] (e^{-i{2 \pi \over n}k})^t = X[k]\)。此时对多项式取点的操作就与 DFT 产生了联系:对于一个 \(n - 1\) 次多项式,我们可以对其进行 DFT 以求得这个多项式在 \(e^{-i{2 \pi \over n}0},e^{-i{2 \pi \over n}1},e^{-i{2 \pi \over n}2}, \dots ,e^{-i{2 \pi \over n}(n - 1)}\)\(n\) 个位置上的值,这 \(n\) 个点其实就是单位根

但是如果按照定义直接每个点单独求值,时间复杂度还是 \(O(n^2)\) 的,并没有什么优化,于是就有 FFT。

(PS:在问题分析中我们提到要在 \(A(x)\)\(B(x)\) 上分别取 \(n + m - 1\) 个点,为了解决这个问题,我们可以在两个多项式后面添项,如:\(A(x) = (\sum _{i = 0}^{n - 1} a_i x^i) + 0 \times x^n + \dots + 0 \times x^{n + m - 2}\)

FFT——DFT 的高效实现

(观看视频 BV1za411F76U 可以比较清楚地知道 FFT 的实现。)

FFT 是 DFT 数学公式的高效实现,通过分治和对称性减少计算量。

我们将原本的多项式(这里是添项过后的,而且项数 \(n\) 是 2 的幂) \(A(x) = \sum _{i = 0}^{n - 1} a_i x^i\) 的系数分成奇偶两部分,分别构成两个新的多项式 \(A_e(x) = \sum _{i = 0}^{\frac{n}{2} - 1} a_{2i} x^i\)\(A_o(x) = \sum _{i = 0}^{\frac{n}{2} - 1} a_{2i + 1} x^i\)。(角标的含义是 even 和 odd)。

\(A(x)\) 还可以这样表示:

\[A(x) = A_e(x^2) + x A_o(x^2) \]

那么代入单位根 \(w_n^k\) 得到:

\[A(w_n^k) = A_e((w_n^k)^2) + w_n^k A_o((w_n^k)^2) \]

\[A(w_n^k) = A_e(w_{n \over 2}^k) + w_n^k A_o(w_{n \over 2}^k) \]

而且我们不需要从 0 到 \(n - 1\) 取枚举 \(k\),因为对于 \(k < {n \over 2}\),有 \(w_n^{k + {n \over 2}} = -w_n^k\),即:

\[A(w_n^{k + {n \over 2}}) = A(-w_n^k) = A_e(w_{n \over 2}^k) - w_n^k A_o(w_{n \over 2}^k) \]

所以我们只需要枚举前一半就好了。

对于 \(A_e(w_{n \over 2}^k)\)\(A_o(w_{n \over 2}^k)\),同样用 FFT,递归的求就可以了,当项数为 1 时直接返回系数或者什么都不做就好了。

于是就可以写出递归版的 FFT (代码等说了 IDFT 再给出)

确定多项式——IDFT 离散傅里叶逆变换

在 DFT 中我们实际是解了一个这样的方程:

\[\begin{bmatrix} (w_n^0)^0&(w_n^0)^1&\dots&(w_n^0)^{n - 1}\\ (w_n^1)^0&(w_n^1)^1&\dots&(w_n^1)^{n - 1}\\ \vdots&\vdots& &\vdots\\ (w_n^{n - 1})^0&(w_n^{n - 1})^1&\dots&(w_n^{n - 1})^{n - 1}\\ \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ \vdots\\ a_{n - 1} \end{bmatrix} = \begin{bmatrix} A(w_n^0)\\ A(w_n^1)\\ \vdots\\ A(w_n^{n - 1})\\ \end{bmatrix} \]

即求解了 \(A(w_n^i)\)

那么在 IDFT 中我们知道了 \(A(w_n^i)\) 要求 \(a_i\)。在方程两边同时左乘第一个矩阵的逆矩阵(求逆矩阵直接搜范德蒙矩阵的逆矩阵就好了,虽然有点区别,但直接类比过来就好了)得到:

\[\begin{bmatrix} a_0\\ a_1\\ \vdots\\ a_{n - 1} \end{bmatrix} = {1 \over n}\begin{bmatrix} (w_n^0)^0&(w_n^0)^{-1}&\dots&(w_n^0)^{-(n - 1)}\\ (w_n^1)^0&(w_n^1)^{-1}&\dots&(w_n^1)^{-(n - 1)}\\ \vdots&\vdots& &\vdots\\ (w_n^{n - 1})^0&(w_n^{n - 1})^{-1}&\dots&(w_n^{-(n - 1)})^{-(n - 1)}\\ \end{bmatrix} \begin{bmatrix} A(w_n^0)\\ A(w_n^1)\\ \vdots\\ A(w_n^{n - 1})\\ \end{bmatrix} \]

这几乎是和 DFT 一样的问题,只不过 IDFT 中的单位根要和 DFT 中的单位根共轭,并且最后还要除 \(n\)

于是就可以得到递归版的代码:

// len表示项数,并且一定是 2 的幂
void fft(int len, std::vector<std::complex<double>> &A, bool inv) {
    if (len == 1) {
        return;
    }

    int ll = len >> 1;
    std::vector<std::complex<double>> Ae(ll), Ao(ll);
    for (int i = 0; i + i < len; i++) {
        Ae[i] = A[i << 1];
        Ao[i] = A[i << 1 | 1];
    }
    fft(ll, Ae, inv);
    fft(ll, Ao, inv);

    double angle = 2.0 * M_PI / double(len);
    std::complex<double> unit(cos(angle), (inv ? 1 : -1) * sin(angle)), w(1.0, 0.0);
    for (int i = 0; i < ll; i++, w *= unit) {
        std::complex<double> tmp = w * Ao[i];
        A[i] = Ae[i] + tmp;
        A[i + ll] = Ae[i] - tmp;
    }
    // 当使用 IDFT 时,别忘了在函数外边另外除项数。
    return;
}

有了最基础递归版的 FFT,就可以将 P3803 【模板】多项式乘法(FFT) 解决了。(题解里面有挺多很好的题解的)

多项式乘法的实现

和一开始分析的思路一样,先在给定的两个多项式中取点(DFT),取完点后再根据点确定要求的多项式的系数(IDFT),于是有如下代码:

void solve()
{
    // 最高次数
    int n = 0, m = 0;
    std::cin >> n >> m;
    std::vector<double> a(n + 1), b(m + 1);
    for (auto &i : a) {
        std::cin >> i;
    }
    for (auto &i : b) {
        std::cin >> i;
    }

    int len = 1;
    while (len <= m + n) {
        len <<= 1;
    }

    std::vector<std::complex<double>> A(len), B(len), C(len);
    // 由于 fft 涉及到复数运算,这里就把实数系数转变为了复数系数,同时也完成了添项的操作
    for (int i = 0; i < len; i++) {
        A[i] = std::complex<double>((i <= n ? a[i] : 0.0), 0.0);
        B[i] = std::complex<double>((i <= m ? b[i] : 0.0), 0.0);
    }
    fft(len, A, false);
    fft(len, B, false);
    for (int i = 0; i < len; i++) {
        C[i] = A[i] * B[i];
    }
    fft(len, C, true);
    for (int i = 0; i <= m + n; i++) {
        // +0.5 就相当于是四舍五入了
        std::cout << i64((C[i].real())/ double(len) + 0.5) << ' ';
    }
    std::cout << '\n';
    return;
}

三次变二次优化

这里的次数是指在多项式乘法的过程进行了多少次 DFT,显然上面的版本是进行了三次:对多项式 \(A(x)\)\(B(x)\) 各进行了一次正变换,为了得到 \(C(x)\) 的系数又将正变换的结果相乘进行逆变换。可以通过把 \(B(x)\) 的系数放在 \(A(x)\) 的虚部从而只进行一次正变换。然后对变换结果的平方进行逆变换,最后虚部的一半就是 \(C(x)\) 的系数,即:

    // code ...
    std::vector<std::complex<double>> A(len);
    for (int i = 0; i < len; i++) {
        A[i] = std::complex<double>(i <= la ? a[i] : 0.0, i <= lb ? b[i] : 0.0);
    }
    fft(len, A, false);
    for (int i = 0; i < len; i++) {
        A[i] = A[i] * A[i];
    }
    fft(len, A, true);

    for (int i = 0; i <= la + lb; i++) {
        // 这里的 / len 是逆变换的最后一步,如果在 fft 中已经除过了这里就不用再除了。
        std::cout << i64(A[i].imag() / len / 2.0 + 0.5) << ' ';
    }
    std::cout << '\n';
    // code ...

解释如下(优化的方法是没问题的,但以下解释没有严格证明,只相当于自己说服自己):

首先 \((a + bi) \times Z = aZ + bZi\) (\(i\) 是虚数单位, \(Z\) 是一个复数),也就是说对一系列复数 DFT 就相当于对这一系列复数的实部和虚部分别做 DFT 然后虚部的结果乘个虚数单位再和实部的结果加起来,反之亦然。

然后 \((a + bi) ^ 2 = (a^2 - b^2) + 2abi\),所以平方后做逆变换,虚部就是我们要的结果的两倍。

迭代 FFT

迭代优化是基于一个神奇的规律。

例如当对长度为 8 的序列进行 FFT 时,在递归版中我们要不断依据下标的奇偶将系数分为两部分。于是观察分组:

第一层: 0 1 2 3 4 5 6 7
第二层: 0 2 4 6,1 3 5 7
第三层: 0 4,2 6,1 5,3 7
第四层: 0,4,2,6,1,5,3,7

当我们将最后一层下标的二进制倒过来,我们就会发现这个规律:

000,100,010,110,001,101,011,111 ->
000,001,010,011,100,101,110,111
0  ,1  ,2  ,3  ,4  ,5  ,6  ,7

于是我们要做的就是将序列按照最后一层的下标排列,然后模拟每层的计算就好了。

那么如何获取最后一层的下标呢?

先看如下代码做了什么:

// N 是 2 的幂
void fun1(int N) {
    std::vector idx(N, 0);
    for (int l = 1, r = 1, bit = 1; l < N; l <<= 1, bit <<= 1) {
        for (int i = 0; i < l; i++) {
            idx[r++] = idx[i] | bit;
        }
    }
}

是的,就是 \(O(N)\) 地生成了从 \(0\)\(N - 1\) 的有序序列。这么做的意义是什么呢?意义在于,只要我们稍微改一下上述代码,就可以 \(O(N)\) 地生成最后一层的下标:

// N 是 2 的幂
void fun2(int N) {
    std::vector idx(N, 0);
    for (int l = 1, r = 1, bit = (N >> 1); l < N; l <<= 1, bit >>= 1) {
        for (int i = 0; i < l; i++) {
            idx[r++] = idx[i] | bit;
        }
    }
}

获取了最后一层的下标后,我们进行 FFT 时第一步就是调整原序列,之后的模拟递归看代码:

// 要先处理出 idx
std::vector<int> idx(N << 2);
void fft(int len, std::vector<std::complex<double>> &A, int inv) {
    // 调整原序列,使得其与递归最下层的下标相对应
    for (int i = 1; i < len; i++) {
        if (i < idx[i]) {
            std::swap(A[i], A[idx[i]]);
        }
    }

    for (int mid = 1; mid < len; mid <<= 1) {
        std::complex<double> unit(cos(M_PI / mid), inv * sin(M_PI / mid));
        for (int l = 0; l < len; l += (mid << 1)) {
            std::complex<double> w(1.0, 0.0);
            for (int i = 0; i < mid; i++, w *= unit) {
                std::complex<double> x(A[i | l]), y(w * A[i | l | mid]);
                A[i | l] = x + y;
                A[i | l | mid] = x - y;
            }
        }
    }

    for (int i = 0; inv == -1 && i < len; i++) {
        A[i] /= len;
    }
    return;
}

于是最终版的多项式乘法代码如下:

CODE
// M 多项式最大的长度
std::vector<int> idx(M << 2);
void fft(int len, std::vector<std::complex<double>> &A, int inv) {
    for (int i = 1; i < len; i++) {
        if (i < idx[i]) {
            std::swap(A[i], A[idx[i]]);
        }
    }

    for (int mid = 1; mid < len; mid <<= 1) {
        std::complex<double> unit(cos(M_PI / mid), inv * sin(M_PI / mid));
        for (int l = 0; l < len; l += (mid << 1)) {
            std::complex<double> w(1.0, 0.0);
            for (int i = 0; i < mid; i++, w *= unit) {
                std::complex<double> x(A[i | l]), y(w * A[i | l | mid]);
                A[i | l] = x + y;
                A[i | l | mid] = x - y;
            }
        }
    }

    for (int i = 0; inv == -1 && i < len; i++) {
        A[i] /= len;
    }
    return;
}

// 传两个多项式返回这两个多项式相乘(卷积)的结果
std::vector<i64> converlution(std::vector<i64> &a, std::vector<i64> &b) {
    int la = a.size() - 1, lb = b.size() - 1;
    int len = 1;
    // 长度是 2 的幂且必须大于最高次数
    while (len <= la + lb) {
        len <<= 1;
    }
    
    for (int l = 1, r = 1, bit = len >> 1; l < len; l <<= 1, bit >>= 1) {
        for (int i = 0; i < l; i++) {
            idx[r++] = idx[i] | bit;
        }
    }

    std::vector<std::complex<double>> A(len);
    for (int i = 0; i < len; i++) {
        A[i] = std::complex<double>(i <= la ? a[i] : 0.0, i <= lb ? b[i] : 0.0);
    }
    fft(len, A, 1);
    for (int i = 0; i < len; i++) {
        A[i] = A[i] * A[i];
    }
    fft(len, A, -1);

    std::vector res(la + lb + 1, 0ll);
    for (int i = 0; i <= la + lb; i++) {
        res[i] = i64(A[i].imag() / 2.0 + 0.5);
    }
    return res;
}
posted @ 2025-02-23 22:27  Young_Cloud  阅读(119)  评论(0)    收藏  举报