FFT 和 NTT 学习笔记

参考了 Alex_Wei 的博客,虽说fjj觉得他写得很乱,但我认为这写得很好。

感谢htc大佬的讲解。

参考了一份题解的代码。


FFT

作用

FFT 干了一件什么事呢?首先要理解多项式乘法和卷积的关系。你有两个多项式 \(A,B\),它们的系数数组是 \(a,b\)\(A\times B\) 的系数数组就是 \(a*b\)\(*\) 表示卷积)。这一点应该是比较好理解的。

先对数组 \(a,b\) 作一遍FFT,得到多项式 \(A,B\) 的点值表示法(也就是 \((x_i,y_i)\) 形式,\(n\) 个点确定一个 \(n-1\) 次函数),接下来就可以把 \(FFT[a]\cdot FFT[b]\),点乘得到 \(FFT[a*b]\),最后执行 IFFT,就能还原出 \(a*b\) 的数值,也就是快速进行了多项式乘法。FFT 的过程是 \(O(n\log n)\) 的,所以比 \(O(n^2)\) 的暴力要快。这个方法也适用于高精度乘法之类。

过程

注:下文 \(n\) 统一是 \(2^k\) 的形式。

FFT 具体是怎样实现的呢?首先我们定义单位根 \(w_n\),表示一个复数满足 \((w_n)^n=1\),显然这样的数有 \(n\) 个,但我们取任意一个。(注意,这里的下标不是序号!是幂次!

我们定义 \(F_a(x)=\sum\limits_{i=0}^{n-1}a_ix^i\)。我们定义 \(FFT[a]_i=F_a(w_n^i)\),也就是所谓的“点值表示法”。这个东西要怎么快速计算呢?

这里就要参考Alex_Wei 的博客中的讲解了。简单来说,按照奇次项和偶次项分成两个函数,再把 \(w_n\) 代换为 \(w_\frac{n}{2}\),根据一个定理(或是感性理解),\(w_\frac{n}{2}=w_n^2\),所以就能把问题规模缩小。再根据奇函数和偶函数的性质:

\[\begin{equation} \begin{cases} \ f(x)=f'_e(x^2)+xf'_o(x^2)\\ f(-x)=f'_e(x^2)-xf'_o(x^2)\\ \end{cases} \end{equation} \]

这是他博客中的式子。记住这里 \(f_o\) 系数的 \(x\) 也是变量。

根据 \(w_n^k=-w_n^{k+\frac{n}{2}}\) ,上述式子就可以代入到 FFT 中,作为分治的理论基础。

分治的过程可以使用蝴蝶变换来优化。详见Alex_Wei 的博客。但是这种优化会使得代码的理解难度暴增。

逆变换:和FFT差不多,只是 \(IFFT[a]_i=\frac{1}{n}f_a(w_n^{-i})\)。证明不会欸。

NTT

原根

比起 FFT,唯一的区别就是用原根代替了单位根。

原根可以被理解为模意义下的单位根,原根 \(g\) 满足 \(g^{\varphi(p)}\equiv 1\pmod p\),而且 \(g^i\) 在成为 \(1\) 之前一直都是不同的,也即其循环节长度为 \(\varphi(p)\)。例如 \(998244353\) 的一个原根是 \(3\),而且 \(\varphi(998244353)=998244352\)\(2^{23}\) 的倍数,我们计算 \(w_{\frac{n}{2^k}}\) 时就直接用 \(3^\frac{998244352}{2^k}\) 就行了。

这个东西的好处是,不用写邪恶的复数,且速度更快,不会有精度问题。坏处是只能针对特定模数,否则就要用 MTT。

代码

借用了某篇题解,部分使用了AI注释。(byd AI 给代码里乱加东西)

#include<cstdio>
// 快速读入字符的宏定义,用于提高输入效率
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// 交换两个变量的值,使用异或运算
#define swap(x,y) x ^= y, y ^= x, x ^= y
// 定义长整型别名LL
#define LL long long 
// 常量定义,MAXN为数组最大长度,P为模数,G为原根,Gi为G的逆元
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118; 
// 缓冲区定义,用于快速读入
char buf[1<<21], *p1 = buf, *p2 = buf;

// 快速读入整数的函数
inline int read() { 
    char c = getchar(); int x = 0, f = 1;
    // 跳过非数字字符,判断负号
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    // 读取数字字符并转换为整数
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

// 输入的多项式的次数N和M,limit为NTT变换的长度,L为log(limit),r为位反转数组
int N, M, limit = 1, L, r[MAXN];
// 存储多项式系数的数组a和b
LL a[MAXN], b[MAXN];

// 快速幂函数,计算a^k % P
inline LL fastpow(LL a, LL k) {
    LL base = 1;
    while(k) {
        if(k & 1) base = (base * a ) % P;
        a = (a * a) % P;
        k >>= 1;
    }
    return base % P;
}

// NTT变换函数,type为1时表示正变换,type为-1时表示逆变换
inline void NTT(LL *A, int type) {
    // 位反转重排
    for(int i = 0; i < limit; i++) 
        if(i < r[i]) swap(A[i], A[r[i]]);
    // 进行蝶形操作
    for(int mid = 1; mid < limit; mid <<= 1) {	//分治的层级
        LL Wn = fastpow( type == 1 ? G : Gi , (P - 1) / (mid << 1));//计算对应幂次的单位根
        for(int j = 0; j < limit; j += (mid << 1)) {
            LL w = 1;
            for(int k = 0; k < mid; k++, w = (w * Wn) % P) {//这里的w就是上面公式中的自变量x,随着下标变化而变化
                int x = A[j + k], y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P,
                A[j + k + mid] = (x - y + P) % P;
				// 根据之前的公式
            }
        }
    }
}

int main() {
    // 读取多项式的次数N和M
    N = read(); M = read();
    // 读取多项式a的系数
    for(int i = 0; i <= N; i++) a[i] = (read() + P) % P;
    // 读取多项式b的系数
    for(int i = 0; i <= M; i++) b[i] = (read() + P) % P;
    // 计算NTT变换的长度limit,使其为大于等于N+M的最小2的幂次,并计算L=log(limit)
    while(limit <= N + M) limit <<= 1, L++;
    // 计算位反转数组r
    for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));	
    // 对多项式a和b进行NTT正变换
    NTT(a, 1); NTT(b, 1);	
    // 点乘操作,得到卷积后的系数
    for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
    // 对卷积后的系数进行NTT逆变换
    NTT(a, -1);	
    LL inv = fastpow(limit, P - 2);
    for(int i = 0; i <= N + M; i++)
        printf("%d ", (a[i] * inv) % P);//乘上逆变换系数
    return 0;
}
posted @ 2025-09-12 11:10  Luke_li  阅读(6)  评论(0)    收藏  举报