FFT 和 NTT 学习笔记
参考了 Alex_Wei 的博客,虽说fjj觉得他写得很乱,但我认为这写得很好。
感谢htc大佬的讲解。
参考了一份题解的代码。
FFT
作用
FFT 干了一件什么事呢?首先要理解多项式乘法和卷积的关系。你有两个多项式 \(A,B\),它们的系数数组是 \(a,b\)。\(A\times B\) 的系数数组就是 \(a*b\)(\(*\) 表示卷积)。这一点应该是比较好理解的。
先对数组 \(a,b\) 作一遍FFT,得到多项式 \(A,B\) 的点值表示法(也就是 \((x_i,y_i)\) 形式,\(n\) 个点确定一个 \(n-1\) 次函数),接下来就可以把 \(FFT[a]\cdot FFT[b]\),点乘得到 \(FFT[a*b]\),最后执行 IFFT,就能还原出 \(a*b\) 的数值,也就是快速进行了多项式乘法。FFT 的过程是 \(O(n\log n)\) 的,所以比 \(O(n^2)\) 的暴力要快。这个方法也适用于高精度乘法之类。
过程
注:下文 \(n\) 统一是 \(2^k\) 的形式。
FFT 具体是怎样实现的呢?首先我们定义单位根 \(w_n\),表示一个复数满足 \((w_n)^n=1\),显然这样的数有 \(n\) 个,但我们取任意一个。(注意,这里的下标不是序号!是幂次!)
我们定义 \(F_a(x)=\sum\limits_{i=0}^{n-1}a_ix^i\)。我们定义 \(FFT[a]_i=F_a(w_n^i)\),也就是所谓的“点值表示法”。这个东西要怎么快速计算呢?
这里就要参考Alex_Wei 的博客中的讲解了。简单来说,按照奇次项和偶次项分成两个函数,再把 \(w_n\) 代换为 \(w_\frac{n}{2}\),根据一个定理(或是感性理解),\(w_\frac{n}{2}=w_n^2\),所以就能把问题规模缩小。再根据奇函数和偶函数的性质:
这是他博客中的式子。记住这里 \(f_o\) 系数的 \(x\) 也是变量。
根据 \(w_n^k=-w_n^{k+\frac{n}{2}}\) ,上述式子就可以代入到 FFT 中,作为分治的理论基础。
分治的过程可以使用蝴蝶变换来优化。详见Alex_Wei 的博客。但是这种优化会使得代码的理解难度暴增。
逆变换:和FFT差不多,只是 \(IFFT[a]_i=\frac{1}{n}f_a(w_n^{-i})\)。证明不会欸。
NTT
原根
比起 FFT,唯一的区别就是用原根代替了单位根。
原根可以被理解为模意义下的单位根,原根 \(g\) 满足 \(g^{\varphi(p)}\equiv 1\pmod p\),而且 \(g^i\) 在成为 \(1\) 之前一直都是不同的,也即其循环节长度为 \(\varphi(p)\)。例如 \(998244353\) 的一个原根是 \(3\),而且 \(\varphi(998244353)=998244352\) 是 \(2^{23}\) 的倍数,我们计算 \(w_{\frac{n}{2^k}}\) 时就直接用 \(3^\frac{998244352}{2^k}\) 就行了。
这个东西的好处是,不用写邪恶的复数,且速度更快,不会有精度问题。坏处是只能针对特定模数,否则就要用 MTT。
代码
借用了某篇题解,部分使用了AI注释。(byd AI 给代码里乱加东西)
#include<cstdio>
// 快速读入字符的宏定义,用于提高输入效率
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// 交换两个变量的值,使用异或运算
#define swap(x,y) x ^= y, y ^= x, x ^= y
// 定义长整型别名LL
#define LL long long
// 常量定义,MAXN为数组最大长度,P为模数,G为原根,Gi为G的逆元
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
// 缓冲区定义,用于快速读入
char buf[1<<21], *p1 = buf, *p2 = buf;
// 快速读入整数的函数
inline int read() {
char c = getchar(); int x = 0, f = 1;
// 跳过非数字字符,判断负号
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
// 读取数字字符并转换为整数
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
// 输入的多项式的次数N和M,limit为NTT变换的长度,L为log(limit),r为位反转数组
int N, M, limit = 1, L, r[MAXN];
// 存储多项式系数的数组a和b
LL a[MAXN], b[MAXN];
// 快速幂函数,计算a^k % P
inline LL fastpow(LL a, LL k) {
LL base = 1;
while(k) {
if(k & 1) base = (base * a ) % P;
a = (a * a) % P;
k >>= 1;
}
return base % P;
}
// NTT变换函数,type为1时表示正变换,type为-1时表示逆变换
inline void NTT(LL *A, int type) {
// 位反转重排
for(int i = 0; i < limit; i++)
if(i < r[i]) swap(A[i], A[r[i]]);
// 进行蝶形操作
for(int mid = 1; mid < limit; mid <<= 1) { //分治的层级
LL Wn = fastpow( type == 1 ? G : Gi , (P - 1) / (mid << 1));//计算对应幂次的单位根
for(int j = 0; j < limit; j += (mid << 1)) {
LL w = 1;
for(int k = 0; k < mid; k++, w = (w * Wn) % P) {//这里的w就是上面公式中的自变量x,随着下标变化而变化
int x = A[j + k], y = w * A[j + k + mid] % P;
A[j + k] = (x + y) % P,
A[j + k + mid] = (x - y + P) % P;
// 根据之前的公式
}
}
}
}
int main() {
// 读取多项式的次数N和M
N = read(); M = read();
// 读取多项式a的系数
for(int i = 0; i <= N; i++) a[i] = (read() + P) % P;
// 读取多项式b的系数
for(int i = 0; i <= M; i++) b[i] = (read() + P) % P;
// 计算NTT变换的长度limit,使其为大于等于N+M的最小2的幂次,并计算L=log(limit)
while(limit <= N + M) limit <<= 1, L++;
// 计算位反转数组r
for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
// 对多项式a和b进行NTT正变换
NTT(a, 1); NTT(b, 1);
// 点乘操作,得到卷积后的系数
for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
// 对卷积后的系数进行NTT逆变换
NTT(a, -1);
LL inv = fastpow(limit, P - 2);
for(int i = 0; i <= N + M; i++)
printf("%d ", (a[i] * inv) % P);//乘上逆变换系数
return 0;
}

浙公网安备 33010602011771号