FFT 学习笔记
FFT 学习笔记
要求多项式 \(f(x) \times g(x)\) 的乘积。
可以把 \(f(x)\) 和 \(g(x)\) 看做一个 \(n\) 次方程,并且过点 \((x_0, f(x_0)),(x_1, f(x_1)) \dots (x_n, f(x_n))\)。
现在已知 \(f, g\) 的系数表示法。
我们知道 \(f(x) \times g(x)\) 的点值表示法是 \((x_0, f(x_0) \times g(x_0)), (x_1, f(x_1) \times g(x_1)) \dots (x_n, f(x_n) \times g(x_n))\)。
所以 FFT 的核心思想就是将 \(f, g\) 转换成点值表示法,再把乘积的点值表示法转换成系数表示法。
单位根
考虑在复平面上的一个单位圆,进行 \(n\) 等分,得到 \(n\) 个复数。定义 \(\omega_n\) 表示幅角(\(n\) 等分的单位大小),有
定义 \(\omega_n^k\) 表示从第一个点开始,逆时针的第 \(n\) 个点,那么
单位根具有以下性质:
对于第一个式子,拆开即可证明。对于第二个式子,容易想到 \(\frac{n}{2}\) 是一个半圆,且 \(\omega_n^k\) 和 \(-\omega_n^k\) 关于原点对称,即可证明。
对于第三个式子的证明:
我们知道欧拉公式:
所以
证毕。
快速傅里叶变换
对于多项式
我们设 \(n = 2^s\),一般 \(2^s \ge n\) 但是我们可以让高次项系数化为 \(0\)。所以令 \(n = 2^s\)
所以我们可以将 \(f(x)\) 按照系数的奇偶性分成两个部分。令
我们发现
所以
所以
令 \(i = \omega_n^k\),所以
因为
所以
因为平方的存在所以 \(f_1, f_2\) 括号中间的值不变。
前面说过,\(n = 2^s\)。根据上式,要求 \(f(\omega_n^k)\),就要求 \(f_1(\omega_{\frac{n}{2}}^k), f_1(\omega_{\frac{n}{2}}^k)\),对于子任务 \(f_1(\omega_{\frac{n}{2}}^k)\),我们把 \(f_1\) 看做 \(f\),继续做递归,直到 \(f_1(\omega_1^k) = f_2(\omega_1^k) = 1\),最后回溯 + 更新答案。
然后我们知道了所有的 \(\omega_n^k\) 和 \(f(\omega_n^k)\),就得到了 \(f\) 的点值表示。
快速傅里叶逆变换
啊这个东西看不懂啊……不管了,上代码!
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 1 << 22;
const double eps = 1e-6, pi = acos(-1.0);
complex<double> a[N], b[N];// 分别表示两个多项式的系数表达法
int n, m;// 分别表示两个多项式的次数
void FFT(complex<double> *f, int n, int inv){
// 分别表示要处理的 f,f 的长度,inv = 1 表示系数转点,inv = -1 表示点转系数
if(n == 1) return;
int mid = n >> 1;
complex<double> f1[mid + 1], f2[mid + 1];
for(int i = 0;i <= n;i += 2){// 拆分多项式
f1[i >> 1] = f[i];
f2[i >> 1] = f[i + 1];
}
FFT(f1, mid, inv);// 分治处理
FFT(f2, mid, inv);
complex<double> w0(1, 0), wn(cos(2 * pi / n), inv * sin(2 * pi / n));
// 分别表示 w_n^k 和 w_n(单位根)
for(int i = 0;i < mid;i++, w0 *= wn){
f[i] = f1[i] + w0 * f2[i];
f[i + n/2] = f1[i] - w0 * f2[i];
}
}
signed main(){
// ios::sync_with_stdio(false);
// cin.tie(nullptr);
scanf("%d%d", &n, &m);
for(int i = 0;i <= n;i++){
double x; scanf("%lf", &x);
a[i].real(x);// 表示将实部赋值为 x
}
for(int i = 0;i <= m;i++){
double x; scanf("%lf", &x);
b[i].real(x);
}
int len = 1 << max(int(ceil(log2(n + m))), 1);// FFT 的长度
FFT(a, len, 1);// 将 a 转为点值表达
FFT(b, len, 1);// 将 b 转为点值表达
for(int i = 0;i <= len;i++) a[i] = a[i] * b[i];// 点值乘积
FFT(a, len, -1);// 将 a 转为系数表达
for(int i = 0;i <= n + m;i++){
printf("%.0f ", a[i].real() / len + eps);
// cout<<int(a[i].real() / len + eps)<<" ";
}
return 0;
}

浙公网安备 33010602011771号