FFT(快速傅里叶变换)与NTT(快速数论变换)及其蝴蝶迭代

1. FFT

FFT即快速傅里叶变换,是一种高效实现离散傅里叶变换的算法。核心思想是利用拉格朗日插值,将n-1次多项式用n个点及其取值表示,进一步利用到多项式乘法里面来。
为了将\(O(n^2)\)的多项式乘法的复杂度降低到\(O(nlogn)\),这里的n个点取的非常特殊,即为\(w^0_n,w^1_n,...,w^{n-1}_n\),其中\(w^i_n = e^{\frac{-2i\pi}{n}}\)
当n为2的幂次时,我们分成奇数次和偶数次两组,例如\(F(x)=\sum_0^{3}a_ix^i=(a_0 + a_2x^2)+(a_1x+a_3x^3)\)
\(G(x)=a_0+a_2x,H(x)=a_1+a_3x\),则\(F(x)=G(x^2)+xH(x^2)\)
同时\(F(\omega^{k}_n)=G(\omega^{k}_{\frac{n}{2}})+\omega^k_nH(\omega^k_{\frac{n}{2}}),F(\omega^{k+\frac{n}{2}}_n)=G(\omega^{k}_{\frac{n}{2}})-\omega^k_nH(\omega^k_{\frac{n}{2}})\)
这时我们就有了第一种想法,递归处理。
而FFT本身相当于进行了一次矩阵运算,计算

\[\begin{bmatrix} (w_n^0)^0 & (w_n^0)^1 & \cdots & (w_n^0)^{n-1} \\ (w_n^1)^0 & (w_n^1)^1 & \cdots & (w_n^1)^{n-1} \\ \vdots & \vdots & \ddots & \vdots \\ (w_n^{n-1})^0 & (w_n^{n-1})^1 & \cdots & (w_n^{n-1})^{n-1} \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ \vdots \\ a_{n-1} \end{bmatrix} = \begin{bmatrix} A(w_n^0) \\ A(w_n^1) \\ \vdots \\ A(w_n^{n-1}) \end{bmatrix} \]

然后神之一手注意到,

\[\frac{1}{n} \begin{bmatrix} (w_n^{-0})^0 & (w_n^{-1})^0 & \cdots & (w_n^{-(n-1)})^0 \\ (w_n^{-0})^1 & (w_n^{-1})^1 & \cdots & (w_n^{-(n-1)})^1 \\ \vdots & \vdots & \ddots & \vdots \\ (w_n^{-0})^{n-1} & (w_n^{-1})^{n-1} & \cdots & (w_n^{-(n-1)})^{n-1} \end{bmatrix} \begin{bmatrix} A(w_n^0) \\ A(w_n^1) \\ \vdots \\ A(w_n^{n-1}) \end{bmatrix} = \begin{bmatrix} a_0 \\ a_1 \\ \vdots \\ a_{n-1} \end{bmatrix} \]

于是,逆变换本质就是在计算过程中替换你的系数即可。
这里的代码是我自己写的,封装的比较丑,建议参考后面的迭代版本代码,是学习的大神优化后的版本。

点击查看代码
struct C {//complex复数
    double x, y;
    C(double xx = 0, double yy = 0) {
        x = xx;
        y = yy;
    }
    C operator + (C other) {
        return C(x + other.x, y + other.y);
    }
    C operator - (C other) {
        return C(x - other.x, y - other.y);
    }
    C operator * (C other) {
        return C(x * other.x - y * other.y, x * other.y + y * other.x);
    }
};

struct FFT {
    int limit = 1;
    vector<C> a, b, c;
    vector<int> o;//output, 乘积系数
    const double Pi = acos(-1);
    FFT(int n, vector<int>& aa, int m, vector<int>& bb) {
        while (limit <= n + m) {
            limit <<= 1;
        }
        a.assign(limit + 2, {});
        b.assign(limit + 2, {});
        c.assign(limit + 2, {});
        o.assign(limit + 2, 0);
        for (int i = 0; i <= n; ++i) {
            a[i].x = aa[i];
        }
        for (int i = 0; i <= m; ++i) {
            b[i].x = bb[i];
        }
    }
    void fft(int limit, vector<C>& x, int type) {//1为fft,-1为逆运算
        if (limit == 1) {
            return;
        }
        int len = limit >> 1;
        vector<C> a1(len + 2), a2(len + 2);
        for (int i = 0; i <= limit; i += 2) {
            a1[i >> 1] = x[i];
            a2[i >> 1] = x[i + 1];
        }
        fft(len, a1, type);
        fft(len, a2, type);
        C Wn = C(cos(2.0 * Pi / limit), type * sin(2.0 * Pi / limit));
        C w = C(1, 0);//单位1
        for (int i = 0; i < len; ++i, w = w * Wn) {
            x[i] = a1[i] + w * a2[i];
            x[i + len] = a1[i] - w * a2[i];
        }
    }
    void mul() {
        fft(limit, a, 1);
        fft(limit, b, 1);
        for (int i = 0; i <= limit; ++i) {
            c[i] = a[i] * b[i];
        }
        fft(limit, c, -1);
        for (int i = 0; i <= limit; ++i) {
            o[i] = (int)(c[i].x / limit + 0.5);
        }
    }
};

int main() {
    int n, m;
    cin >> n >> m;
    vector<int> a(n + 1), b(m + 1);
    for (int i = 0; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 0; i <= m; ++i) {
        cin >> b[i];
    }
    FFT ans(n, a, m, b);
    ans.mul();
    for (int i = 0; i <= n + m; ++i) {
        cout << ans.o[i] << " \n"[i == n + m];
    }
    return;
}

2. NTT

感觉看其他博客的时候讲的比较复杂,不怎么友好,我就建立在FFT基础上来讲。
NTT其实就是取模意义下的FFT,有兴趣的可以了解一下什么叫定义在环上面的运算,这就涉及到数论内容了。
这也就意味着,我们其实是利用一个模意义下的整数环,来计算多项式乘积,对于上面提到的\(w^k_n\),我们用一个整数来表示,这就要求这个模数有原根。

原根

听起来是一个非常高大上的说法,但本质就是模拟上述的复数运算,使得\(w^k_n\)互不相同且\(w^n_n\equiv 1,w^{\frac{n}{2}}\equiv -1\),记模数为\(p\),原根为\(g\),我们实际上就是利用\(g^{\frac{k(p-1)}{n}}\)来模拟\(w^k_n\),这就要求\(m\)的2的幂次,且\(p=c2^n+1\)
而原根实际上就是这个模数满足该性质的最小值,一般是暴力计算,这里需要记下结论。
回到NTT本身,我们就只需要预处理好需要的\(w^k_n\)即可,或者每次快速幂暴力计算。

3. FFT与NTT的蝴蝶迭代

本质就是从递归变成了递归,其余部分没有区别。
我们考虑递归的过程中,每次合并的是哪些下标对应的系数,这就需要一定的注意力了,最后在oiwiki上称之为“位逆序置换”,自己思考的话,建议从二进制拆分入手,你会发现每次末位为1或是为0会影响递归过程的下一次合并对象。
这里放的是NTT版本的,如果需要FFT的话,自行修改即可。

点击查看代码
inline int quick_pow(int x, int p, int mod) {
    int ret = 1;
    while (p) {
        if (p & 1) {
            ret = (1ll * ret * x) % mod;
        }
        x = (1ll * x * x) % mod;
        p >>= 1;
    }
    return ret;
}

//优化过的ntt,使用时记得初始化。998244353,1004535809,469762049,985661441,167772161. g = 3.950009857,g=7
const int maxn = 2e5 + 10;
const int mod = 998244353,g = 3;
int w[maxn << 2], inv[maxn << 2], r[maxn << 2], last = -1;

void init() {
    int lim = maxn << 1;//最长数组的两倍
    inv[1] = 1;
    for(int i = 2; i <= lim; i++)
        inv[i] = mod - (mod / i) * inv[mod % i] % mod;
    for(int i = 1; i < lim; i <<= 1)
    {
        int wn = quick_pow(g, (mod - 1) / (i << 1), mod);
        for(int j = 0, ww = 1; j < i; j++, ww = 1ll * ww * wn % mod)
            w[i + j] = ww;
    }
}

void ntt(vector<int>& g, vector<int>& f, int n, int op)
{
    if(last != n) {
        for(int i = 1; i < n; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
        last = n;
    }
    for (int i = 0; i < n; ++i)
    f[i] = g[i];
    for(int i = 1; i < n; i++) {
        if(i < r[i])
        swap(f[i], f[r[i]]);
    }
    for(int i = 1; i < n; i <<= 1)
    for(int j = 0; j < n; j += (i << 1))
    for(int k = 0; k < i; k++) {
        int x = f[j + k];
        int y = 1ll * f[i + j + k] * w[i + k] % mod;
        f[j + k] = (x + y) % mod;
        f[i + j + k] = (x - y + mod) % mod;
    }
    if(op == -1) {
        reverse(&f[1], &f[n]);
        for(int i = 0;i < n; i++)
        f[i] = 1ll * f[i] * inv[n] % mod;
    }
}
        
inline void mul(int tmp_n,vector<int>& f,vector<int>& g,vector<int>& res) {
    int n = 1;
    while (n < tmp_n) {
        n <<= 1;
    }
    vector<int> d1(n), d2(n);
    ntt(f, d1, n, 1);
    ntt(g, d2, n, 1);
    for(int i = 0 ; i < n; i++)
        d1[i] = 1ll * d1[i] * d2[i] % mod;
    ntt(d1, res, n, -1);
    for(int i = 0; i < n; i++)
        d1[i] = d2[i] = 0;
}
posted @ 2025-08-27 17:08  WE-R  阅读(127)  评论(0)    收藏  举报