FFT学习小结

关键词

多项式乘法,系数表示,点值表示,单位根

FFT基本思路

  1. 系数表示->点值多项式
  2. 点值下直接相乘,时间复杂度O(n)
  3. 点值多项式->系数表示

系数表示->点值多项式

  • 分治思想,奇偶分开,单位根
  • 假定\(f(x)=\sum_{i=0}^{n-1}a_ix^i\),其中n为2的幂次
  • 对于一个有n个系数的多项式,点值表示需要n个不同的点,
  • 那么考虑使用单位根\(x^n=1\)的n个解(\(\omega_{n}^0,\omega_{n}^1...\omega_{n}^{n-1}\)),来作为这n个点
  • 那么我们只需求出\(f(\omega_{n}^0),f(\omega_{n}^1)...f(\omega_{n}^{n-1})\)就得到了这个多项式的点值表示
  • 具体做法就是利用一点单位根的性质,我们将奇偶项分开

\(A(x)=a_0+a_2x^2+a_4x^4+...\)
\(B(x)=a1+a_3x^3+a_5x^5+...\)
\(f(x)=A(x^2)+xB(x^2)\)
\(f(\omega_{n}^{k})=A(\omega_{n}^{2k})+\omega_{n}^kB(\omega_{n}^{2k})\)

\(f(\omega_{n}^{k})=A(\omega_{\frac{n}{2}}^{k})+\omega_{n}^kB(\omega_{\frac{n}{2}}^{k})\)
\(f(\omega_{n}^{k+n/2})=A(\omega_{\frac{n}{2}}^{k})-\omega_{n}^kB(\omega_{\frac{n}{2}}^{k})\)
那么我们直接递归下去就行

op=1

void fft(cp* a, int n, int op) {
    if (n == 1) return;
    cp a1[n / 2], a2[n / 2];
    for (int i = 0;i * 2 < n;++i) {
        a1[i] = a[2 * i];
        a2[i] = a[2 * i + 1];
    }

    fft(a1, n / 2, op);
    fft(a2, n / 2, op);

    cp wn = (cp){ cos(2 * pi / n), op*sin(2 * pi / n) }; 
    cp w = (cp){ 1,0 };
    for (int i = 0;i < n / 2;++i) {
        a[i] = a1[i] + w * a2[i];
        a[i + n / 2] = a1[i] - w * a2[i];
        w = w * wn;
    }
}

乘法

假如我们将两个多项式都使用点值表示,并且是相同的n个点,那么我们直接对应相乘,就得到了乘积多项式的点值表示

点值->系数表示

  • 考虑使用拉格朗日插值将点值表示还原到系数表示

  • \(f(x)=\sum_{i=0}^{n-1}f(\omega_{n}^{i})L_i(x)\)

  • \(L_i(x)=\prod_{k\neq i}\frac{x-\omega_{n}^k}{\omega_{n}^i-\omega_{n}^k}\)

  • \(L_i(x)\)可以直接硬求,下面贴一个LLM的做法
    在这里插入图片描述
    在这里插入图片描述

  • 稍微简单一点的做法,利用单位根的正交性质

  • \(\sum_{k=0}^{n-1} \omega_{n}^{k(i-j)}=n\delta_{i,j}\)\(\delta_{i,j}\)为克罗内克符号,\(\delta_{i,j}\)为1当且仅当\(i=j\)

  • 那么我们要构造的\(L_i(x)\)本质上就是要让\(L_i(\omega_{n}^{j})=\delta_{i,j}\)

  • \(L_i(x)=\sum_{k=0}^{n-1} c_{i,k}x^k\)

  • \(\omega_{n}^{j}\)代入\(L_i(\omega_{n}^{j})=\sum_{k=0}^{n-1} c_{i,k}\omega_{n}^{j}\),那么我们对比一下它的正交性质的式子,只需令\(c_{i,k}=\frac{\omega_{n}^{-ik}}{n}\)就能搞定

  • 因此有\(f(x)=\frac{1}{n}\sum_{i=0}^{n-1} f(\omega_{n}^{i})\sum_{k=0}^{n-1}\omega_{n}^{-ki}x^k\)

  • \(f(x)=\frac{1}{n}\sum_{k=0}^{n-1}x^k \sum_{i=0}^{n-1}f(\omega_{n}^i)\omega_{n}^{-ki}\)

  • \(\frac{1}{n}\sum_{i=0}^{n-1}f(\omega_{n}^i)\omega_{n}^{-ki}\)其实就是\(a_k\)

  • 那么\(a_0,a_1...a_{n-1}\) 我们可以看作是求\(g(x)=\frac{1}{n}\sum_{i=0}^{n-1}f(\omega_{n}^i)x^i\)这个多项式在\(\omega_{n}^{-0},\omega_{n}^{-1}...\omega_{n}^{-(n-1)}\)的值

  • 而我们第一部分求的是\(f(x)=\sum_{i=0}^{n-1}a_ix^i\)\(\omega_{n}^{0},\omega_{n}^{1}...\omega_{n}^{(n-1)}\)的值,因此代码是可以复用的

蝶形优化

  • 蝶形优化其实就是自底向上计算,那么首先需要求得每个数最后在哪里?
  • 经过观察可以发现就是将它的二进制位进行一个翻转,比如n=8时(001->100,110->011)
  • 那么将每个数放到最后一层的正确位置后,自底向上计算即可
#include<bits/stdc++.h>
#define lc (o<<1)
#define rc ((o<<1)|1) 
using namespace std;
typedef long long ll;
typedef double db;
constexpr int N = 1 << 22;
constexpr ll inf = 1ll << 60;
const db pi = acos(-1);
struct cp {
    db x = 0, y = 0;
    cp(db x = 0, db y = 0) : x(x), y(y) {}
};
cp operator + (const cp& a, const cp& b) {
    return (cp) { a.x + b.x, a.y + b.y };
}
cp operator - (const cp& a, const cp& b) {
    return (cp) { a.x - b.x, a.y - b.y };
}
cp operator * (const cp& a, const cp& b) {
    return (cp) { a.x* b.x - a.y * b.y, a.x* b.y + a.y * b.x };
}
int n, m, r[N];
cp a[N], b[N], c[N];
void fft(cp* a, int n, int op) {
    for (int i = 0;i < n;++i) if (i < r[i]) swap(a[i], a[r[i]]);

    for (int i = 1;i < n;i *= 2) {
        cp wn = (cp){ cos(pi / i), sin(pi / i) * op };
        for (int j = 0;j < n;j += i << 1) {
            cp w = (cp){ 1,0 }, x, y;
            for (int k = 0;k < i;++k) {
                x = a[j + k];
                y = a[j + k + i];
                a[j + k] = x + w * y;
                a[j + k + i] = x - w * y;
                w = w * wn;
            }
        }
    }
}
void R(int& x) {
    int t = 0; char ch;
    for (ch = getchar();!('0' <= ch && ch <= '9');ch = getchar());
    for (;('0' <= ch && ch <= '9');ch = getchar()) t = t * 10 + ch - '0';
    x = t;
}
int main() {
#ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("data.out", "w", stdout);
#endif

    cin >> n >> m;
    n++;
    m++;

    int t;
    for (int i = 0;i < n;++i) R(t), a[i].x = t;
    for (int i = 0;i < m;++i) R(t), b[i].x = t;

    int lim = 1;
    while (lim < n + m) lim <<= 1;

    for (int i = 0;i < lim;++i) {
        r[i] = r[i >> 1] >> 1;
        if (i & 1) r[i] += lim / 2;
    }

    fft(a, lim, 1);
    fft(b, lim, 1);

    for (int i = 0;i < lim;++i) c[i] = a[i] * b[i];
    fft(c, lim, -1);

    for (int i = 0;i <= n + m - 2;++i) printf("%d ", (int)(c[i].x / lim + 0.5));

    return 0;
}
posted @ 2025-10-18 23:45  gan_coder  阅读(6)  评论(0)    收藏  举报