Loading

多项式 - 快速傅里叶变换

解决问题

求解两个多项式相乘。
即对于两个 \(n\) 项多项式

\[f(x) = \sum_{i=0}^{n-1} a_i\times x^{i} \\ g(x) = \sum_{i=0}^{n-1} b_i\times x^{i} \]

求出

\[h(x) = \sum_{i=0}^{n-1} c_i\times x^{i} \]

其中

\[c_i=\sum_{j=0}^{i}a_j \times b_{i-j} \]

以上, \(n\) 为一个足够大的值。

前置知识

-复数

请自行学习复数相关知识。

- 单位根

将复平面单位圆 \(n\) 等分,规定 \((1,0)\)\(0\) 次单位根,逆时针依次标号为 \(\left [0, n\right )\)

如图:

可以发现 \(\omega_{n}^{k}\) 即为 \(e^{i\frac{2\pi }{n}}\)
由此可以证明以下=几条性质(当然也可以在单位圆上数形结合):

  • \(\omega_n^{n+k} = \omega_n^k\)

  • \(\omega_{dn}^{dk} = \omega_n^k\)

  • \(\omega_n^{k+\frac{n}{2}} = -\omega_n^k\)

  • \(\omega_n^{-k}\) \(\texttt{与}\) \(\omega_n^k\) \(\texttt{共轭}\)

单位根反演

\[\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{v\times i} = \left[v \bmod n = 0 \right] \]

证:
\(v \bmod n = 0\) 时 , \(v = kn, k\in Z\)

\[\begin{align*} \frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{v\times i} &=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{ik\times n} \\ &=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^0 =\frac{1}{n}\sum_{i=0}^{n-1}1 \\ &=\frac{1}{n}\times n = 1 \end{align*} \]

\(v \bmod n \not= 0\) 时 , \(v = kn + r, k,r \in Z, r < n\)

\[\begin{align*} \frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{v\times i} &=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{(kn+r)\times i} \\ &=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{ki\times n + i\times r} =\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{i\times r} \end{align*} \]

又根据等比数列求和公式:

\[\sum_{i=0}^{n-1}a_0\times q^i = a_0\frac{1-q^n}{1-q} \]

\[\begin{align*} \frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{i\times r} &=\frac{1}{n}\sum_{i=0}^{n-1}(\omega_n^r)^i \\ &=\frac{1}{n}\times \frac{1-\omega_n^{nr}}{1-\omega_n^r} =\frac{1}{n}\times \frac{1-\omega_n^0}{1-\omega_n^r} \\ &=\frac{1}{n}\times 0 = 0 \end{align*} \]

两种情况综合,得证。

韦达定理

\(n\) 个点确定一个 \(n\) 项多项式。

过程

如果我们得知 \(n\) 个形如

\[\left (x_i, f\left(x_i\right) \right) \]

的点值,便可以得出 \(h\) 的点值表示:

\[\left ( x_i, f\left(x_i \right) \times g\left(x_i \right) \right) \]

然后我们便可以确认 \(h\) 的系数。
将系数表示法转化为点值表示法的过程称为 \(DFT\) ,逆过程称为 \(IDFT\)

DFT

对于上述一个多项式,我们将它拆开来看:

\[f\left(x\right) = a_0 + a_1x + a_2x^2 + a_3x^3 + \cdots + a_{n-1}x^{n-1} \]

将每一项按照奇偶性分离:

\[f\left(x\right) = a_0 + a_2x^2 + \cdots + a_{n-2}x^{n-2} + a_1x + a_3x^3 + \cdots + a_{n-1}x^{n-1} \]

假定我们有另外两个多项式:

\[f_1\left(x\right)=a_0 + a_2x + a_4x^2 + \cdots + a_{n-2}x^{\frac{n-2}{2}} \\ f_2\left(x\right)=a_1 + a_3x + a_5x^2 + \cdots + a_{n-1}x^{\frac{n-2}{2}} \]

可以发现:

\[f\left(x\right)=f_1\left(x^2 \right) + x f_2\left(x^2\right) \]

好的,我们将它推成了很适合分治的形式,但这还不够。
接下来我们将 \(\omega_n^k\) 代入看看能得到什么:
\(k < \frac{n}{2}\),代入 \(\omega_n^k\)\(\omega_n^{k + \frac{n}{2}}\) 得到:

\[f\left(\omega_n^k\right) = f_1\left(\omega_n^{2k}\right)+\omega_n^kf_2\left(\omega_n^{2k}\right) \\ f\left(\omega_n^{k+\frac{n}{2}}\right) = f_1\left(\omega_n^{2k+n}\right)+\omega_n^{k+\frac{n}{2}}f_2\left(\omega_n^{2k+n}\right) \\ =f_1\left(\omega_n^{2k}\right)-\omega_n^kf_2{\omega_n^{2k}} \]

我们惊喜地发现,只需要将 \(\omega^{k}_{\frac{n}{2}}\) 分别代入 \(f_1\)\(f_2\),便可以得出 \(f\) 所有的点值。
每次问题的规模都被缩小一半,复杂度满足

\[T\left(n\right) = 2T\left(\frac{n}{2}\right) + O\left(n\right) \]

\(O\left(n\log n\right)\)
另外对于具体实现,发现我们需要保证 \(n\)\(2\) 的整次幂。

IDFT

对于 \(h\) 的其中一项 \(c_i\),有:

\[\begin{align*} c_i&=\sum_{j=0}^{i}a_j \times b_{i-j} \\ &=\sum_{j=0}\sum_{k=0} a_j\times b_k \times \left[ j + k = i\right] \\ &=\sum_{j=0}\sum_{k=0} a_j\times b_k \times \left[ j + k - i \bmod n = 0\right] \\ &=\sum_{j=0}\sum_{k=0} a_j\times b_k \times \frac{1}{n}\sum_{l=0}\omega_n^{\left(j + k - i \right)\times l} \\ &=\sum_{j=0}\sum_{k=0} a_j\times b_k \times \frac{1}{n}\sum_{l=0}\omega_n^{\left(j + k - i \right)\times l} \end{align*} \]

则:

\[\begin{align*} nc_i&=\sum_{j=0}\sum_{k=0} a_j\times b_k \times \sum_{l=0}\omega_n^{jl}\times \omega_n^{kl} \times \omega_n^{-il} \\ &=\sum_{l=0}\omega_n^{-il} \times \sum_{j=0}a_j \omega_n^{jl} \times \sum_{k=0}b_k\omega_n^{kl} \\ &=\sum_{l=0}\omega_n^{i\times \left(-l\right)} f\left(\omega_n^l\right) g\left(\omega_n^l\right) \\ &=\sum_{l=0} h\left(\omega_n^l\right) \omega_n^{i\times \left(-l\right)} \end{align*} \]

\(DFT\) 得到的 \(h\)\(g\) 的点值相乘作为系数,在分别代入 \(\omega_n^{-k}\),得到的结果在除以 \(n\),便得到了对应的系数。
如何代入 \(\omega_n^{-k}\),详情见代码:


#include <bits/stdc++.h>
#define lep(i, a, b) for (int i = a; i <= b; ++i)
#define rep(i, a, b) for (int i = a; i >= b; --i)

typedef std::complex<double> cp;
typedef long long ll;
const int _ = 4e6 + 7;
const double PI = std::acos(-1);

int n, m;
cp A[_], B[_], b[_];

void FFT(cp a[], int n, int opt) {
    if (n == 1) return; int mid = (n >> 1);
    lep(i, 0, mid - 1) b[i] = a[i << 1], b[i + mid] = a[i << 1 | 1];
    lep(i, 0, n - 1) a[i] = b[i];
    FFT(a, mid, opt), FFT(a + mid, mid, opt);
    cp wk = cp(1.0, 0.0), w1 = cp(std::cos(1.0 * PI / mid), opt * std::sin(1.0 * PI / mid));//共轭,虚部取负
    lep(i, 0, mid - 1) {
        b[i] = a[i] + wk * a[i + mid],
        b[i + mid] = a[i] - wk * a[i + mid];
        wk *= w1;
    }
    lep(i, 0, n - 1) a[i] = b[i];
}

int main() {
    std::ios::sync_with_stdio(false),
    std::cin.tie(nullptr), std::cout.tie(nullptr);
    std::cin >> n >> m;
    lep(i, 0, n) std::cin >> A[i];
    lep(i, 0, m) std::cin >> B[i];
    
    m = n + m, n = 1;
    while (n <= m) n <<= 1;
    
    FFT(A, n, 1), FFT(B, n, 1);
    lep(i, 0, n - 1) A[i] *= B[i];
    
    FFT(A, n, -1);
    lep(i, 0, m) std::cout << (int)(A[i].real() / n + 0.5) << ' ';
    return 0;
}

蝴蝶变换优化

上述是常见的 \(FFT\) 递归写法,但其实,我们可以将其写成常数更小的非递归写法。

对于一个 \(n\) 项多项式,递归过程中的系数变化如下:


s: a_0 a_1   a_2 a_3   a_4 a_5   a_6 a_7
1: 000 001   010 011   100 101   110 111
2: 000 010   100 110 | 001 011   101 111
3: 000 100 | 010 110 | 001 101 | 011 111
t: a_0 a_4   a_2 a_6   a_1 a_5   a_3 a_7

可以发现,递归分奇偶的过程就是将 \(a_i\)\(a_j\) 交换。
其中 \(i\) 的二进制表示为 \(j\) 的二进制表示的翻转(\(n\) 位表示)。

所以我们可以直接处理处每一个系数最后会递归到哪一个位置,然后从下至上递推答案。

定义 \(f\left[i\right]\)\(i\) 二进制表示翻转后的十进制表示,有递推式如下:

\[f\left[ i \right] = \left( f \left[ i >> 1 \right] >> 1 \right) \| \left(\left( i \& 1\right) << \left(l - 1\right)\right) \]

(想一想,为什么)

迭代版代码如下:


void Init() {
    m = n + m, n = 1;
    while (n <= m) n <<= 1, ++l;
    lep(i, 1, n - 1) f[i] = (f[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
void FFT(cp a[], int n, int opt) {
    lep(i, 1, n - 1) if (i < f[i]) std::swap(a[i], a[f[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        for (int j = 0; j < n; j += (mid) << 1) {
            cp wk = cp(1, 0), w1 = cp(std::cos(PI / mid), opt * std::sin(PI / mid));
            lep(k, 0, mid - 1) {
                b[j + k] = a[j + k] + wk * a[j + k + mid],
                b[j + k + mid] = a[j + k] - wk * a[j + k + mid];
                wk *= w1;
            }
        }
        lep(i, 0, n - 1) a[i] = b[i];
    }
}
posted @ 2025-01-26 07:31  qkhm  阅读(82)  评论(2)    收藏  举报