快速数论变换(NTT)
一、引入
FFT 跑得很快,但它涉及三角函数运算,常数大且容易有精度误差。
考虑用一个东西等效替代单位根,使得它也能完成 FFT 做的事情,而不会用到浮点数运算。
那么观察一下单位根的性质:
- \(\omega_n^0 = \omega_n^n = 1\)
- \(\omega_n^i = \omega_n^{i+n}\)
- \(\omega_n^i \omega_n^j = \omega_n^{i+j}\)
- \(\omega_{dn}^{di} = \omega_{n}^i\)
- \(\omega_n^i = -\omega_n^{i+\frac{n}{2}}\)
并不难发现,我们很难想到用原根代替单位根进行运算。
而这就是快速数论变换(NTT)的本质思想。
二、前置知识
- 阶、原根相关(该贴第五、六大点)
- 快速傅里叶变换(FFT)
三、与 FFT 的区别
首先需要选择一个合适的质数。
因为 FFT 的本质思想是不断往下分治成子问题,这要求多项式的次数 \(n\) 必须是 \(2\) 的次幂。
回到本题,设找到的质数为 \(p\),它的一个原根是 \(g\),那么选择的“单位根”就是 \(\omega_n = g^{\frac{p-1}{n}}\)(因为 \(\varphi(p) = p-1\))。
这说明 \(p-1\) 是 \(2\) 的次幂的倍数。
摘自 OI-Wiki:常用 NTT(快速数论变换)模数及其原根:
- \(p = 167772161 = 5 \times 2^{25} + 1 ~ (g = 3)\)。
- \(p = 469762049 = 7 \times 2^{26} + 1 ~ (g = 3)\)。
- \(p = 754974721 = 3^2 \times 5 \times 2^{24} + 1 ~ (g = 11)\)。
- \(\color{red}{p = 998244353 = 7 \times 17 \times 2^{23} + 1 ~ (g = 3)}\)。
- \(p = 1004535809 = 479 \times 2^{21} + 1 ~ (g = 3)\)。
定义了 \(\omega\) 之后,检验它是否满足 FFT 中单位根需要的条件。
- \(\omega_n^0 = \omega_n^n = 1\)
- 前者显然,后者 \(=\omega_n^{\frac{p-1}{n} \times n} = \omega_n^{p-1} = 1\)(原根定义)。
- 顺便证明了 \(\omega_n^i = \omega_n^{i+n}\) 也是成立的。
- \(\omega_n^i \omega_n^j = \omega_n^{i+j}\)
- 这是废话。
- \(\omega_{dn}^{di} = \omega_{n}^i\)
- 也是废话,\(n,i\) 同乘 \(d\) 指数上分别是分子和分母,消掉了。
- \(\omega_n^i = -\omega_n^{i+\frac{n}{2}}\)
- 由于 \(\omega_n^n = 1\),所以 \(| \omega_n^{\frac{n}{2}} | = 1\)(平方)。
- 又由于 \(\{g^0, g^1, \dots, g^{p-2}\}\) 互不相同(原根性质),所以 \(\omega_n^{\frac{n}{2}} = -1\)。
- 因此得证。
四、NTT 的局限性
首先由于使用了模意义下的运算,它只能处理系数为整数的多项式。
同时模数 \(p=k \times 2^{t} + 1\),所以它对模数的选择条件较为苛刻。
五、代码实现
这里就只写蝴蝶变换了。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 5, mod = 998244353;
int n, m, lim, invlim, inv3, a[N], b[N];
inline int qmi(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = res * 1ll * a % mod;
a = a * 1ll * a % mod, k >>= 1;
} return res;
}
inline int inv(int x) { return qmi(x, mod - 2); }
int id[N];
inline int rev(int x, int lim) {
int y = 0; while (lim) y = (y << 1) | (x & 1), x >>= 1, lim >>= 1;
return y;
}
inline void init(int n) { invlim = inv(n), inv3 = inv(3); for (int i = 0; i < n; i++) id[i] = rev(i, n - 1); }
void NTT(int lim, int *f, int opt) {
for (int i = 0; i < lim; i++)
if (i < id[i]) swap(f[i], f[id[i]]);
for (int len = 2; len <= lim; len <<= 1) {
int mid = len >> 1, omega = qmi((opt == 1) ? 3 : inv3, (mod - 1) / len);
for (int i = 0; i < lim; i += len) {
int now = 1;
for (int j = i; j < i + mid; j++, now = now * 1ll * omega % mod ) {
int f0 = f[j], f1 = f[j + mid] * 1ll * now % mod;
f[j] = (f0 + f1) % mod;
f[j + mid] = (f0 - f1 + mod) % mod;
}
}
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i <= n; i++) scanf("%d", &a[i]);
for (int i = 0; i <= m; i++) scanf("%d", &b[i]);
lim = 1; while (lim <= n + m) lim <<= 1;
init(lim);
NTT(lim, a, 1), NTT(lim, b, 1);
for (int i = 0; i < lim; i++) a[i] = a[i] * 1ll * b[i] % mod;
NTT(lim, a, -1);
for (int i = 0; i <= n + m; i++) printf("%d ", a[i] * 1ll * invlim % mod);
return 0;
}

浙公网安备 33010602011771号