模板 - 快速傅里叶变换
#include <bits/stdc++.h>
using namespace std;
// pi, evaluated once at static-initialization time.
const double PI = std::acos(-1.0);
// Capacity of the global coefficient arrays A and B; it must cover the
// padded (power-of-two) transform length used by convolution().
const int MAXN = 4000000;
typedef long long ll;
// Minimal complex number used by the FFT: x is the real part, y the
// imaginary part. The default constructor leaves both uninitialized; the two
// global arrays declared with it have static storage and are therefore
// zero-initialized.
struct Complex {
    double x, y;
    Complex() {}
    Complex(double x, double y): x(x), y(y) {}
    friend Complex operator+(const Complex &lhs, const Complex &rhs) {
        Complex sum(lhs.x + rhs.x, lhs.y + rhs.y);
        return sum;
    }
    friend Complex operator-(const Complex &lhs, const Complex &rhs) {
        Complex diff(lhs.x - rhs.x, lhs.y - rhs.y);
        return diff;
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    friend Complex operator*(const Complex &lhs, const Complex &rhs) {
        Complex prod(lhs.x * rhs.x - lhs.y * rhs.y, lhs.x * rhs.y + lhs.y * rhs.x);
        return prod;
    }
} A[MAXN + 5], B[MAXN + 5];
// In-place iterative radix-2 FFT on a[0..n-1]; n must be a power of two.
// op = 1: forward transform (roots of unity with positive exponent).
// op = -1: inverse transform; afterwards each real part is divided by n and
// rounded to the nearest integer (this template multiplies integer
// polynomials) and each imaginary part is cleared.
void FFT(Complex a[], int n, int op) {
    // Bit-reversal permutation: j is kept equal to the bit-reversal of i by
    // adding one in reversed bit order (clear leading ones, set next bit).
    for(int i = 1, j = n >> 1; i < n - 1; ++i) {
        if(i < j)
            swap(a[i], a[j]);
        int k = n >> 1;
        while(k <= j) {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
    for(int len = 2; len <= n; len <<= 1) {
        // Principal len-th root of unity; the sign of its angle follows op.
        Complex wn(cos(2.0 * PI / len), sin(2.0 * PI / len) * op);
        for(int i = 0; i < n; i += len) {
            Complex w(1.0, 0.0);
            for(int j = i; j < i + (len >> 1); ++j) {
                Complex u = a[j], t = a[j + (len >> 1)] * w;
                a[j] = u + t, a[j + (len >> 1)] = u - t;
                // The twiddle is built by repeated multiplication, so its
                // rounding error grows with len.
                w = w * wn;
            }
        }
    }
    if(op == -1) {
        for(int i = 0; i < n; ++i) {
            // floor(v + 0.5) rounds negative results correctly and cannot
            // overflow, unlike a truncating (int) cast (e.g. -2.3 -> -1).
            a[i].x = floor(a[i].x / n + 0.5);
            a[i].y = 0;
        }
    }
}
// Smallest power of two that is >= x (1 when x <= 1).
int pow2(int x) {
    int p = 1;
    while(p < x)
        p += p;
    return p;
}
// Multiplies two real polynomials in place. On input A[0..Asize-1].x and
// B[0..Bsize-1].x hold the coefficients; on return A[k].x holds coefficient k
// of the product (rounded by the inverse FFT) and B holds its forward
// transform. Both arrays must have room for pow2(Asize + Bsize - 1) entries.
void convolution(Complex A[], Complex B[], int Asize, int Bsize) {
    const int len = pow2(Asize + Bsize - 1);
    // Clear all imaginary parts and zero-pad both inputs to the transform length.
    for(int i = 0; i < len; ++i) {
        A[i].y = 0.0;
        B[i].y = 0.0;
        if(i >= Asize)
            A[i].x = 0;
        if(i >= Bsize)
            B[i].x = 0;
    }
    FFT(A, len, 1);
    FFT(B, len, 1);
    // Convolution becomes a pointwise product in the frequency domain.
    for(int i = 0; i < len; ++i)
        A[i] = A[i] * B[i];
    FFT(A, len, -1);
}
int main() {
#ifdef Yinku
freopen("Yinku.in", "r", stdin);
#endif // Yinku
int n, m;
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; ++i) {
scanf("%lf", &A[i].x);
}
for(int i = 0; i <= m; ++i) {
scanf("%lf", &B[i].x);
}
convolution(A, B, n + 1, m + 1);
for(int i = 0; i <= n + m; i++) {
printf("%d%c", (int)A[i].x, " \n"[i == n + m]);
}
return 0;
}
另有一个精度更高（多次调用时速度也更快）的版本，它需要预处理单位根，因此多占用一些空间：
#include <bits/stdc++.h>
using namespace std;
// Upper bound on polynomial degree this template is sized for; the twiddle
// and bit-reversal tables below are allocated as multiples of it.
const int MAXN = 1000000;
// pi, evaluated once at static-initialization time.
const double PI = std::acos(-1.0);
typedef long long ll;
// Minimal complex number: x is the real part, y the imaginary part.
// Default-constructs to 0 + 0i.
struct Complex {
    double x, y;
    Complex(): x(0), y(0) {}
    Complex(double x, double y): x(x), y(y) {}
    friend Complex operator+(const Complex &lhs, const Complex &rhs) {
        Complex sum(lhs.x + rhs.x, lhs.y + rhs.y);
        return sum;
    }
    friend Complex operator-(const Complex &lhs, const Complex &rhs) {
        Complex diff(lhs.x - rhs.x, lhs.y - rhs.y);
        return diff;
    }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    friend Complex operator*(const Complex &lhs, const Complex &rhs) {
        Complex prod(lhs.x * rhs.x - lhs.y * rhs.y, lhs.x * rhs.y + lhs.y * rhs.x);
        return prod;
    }
};
// A polynomial / FFT work buffer stored as a vector of complex values.
typedef vector<Complex> Poly;
// Twiddle tables, filled lazily by pow2(): for each half-length m = 1, 2, 4, ...
// w[m][1][k] = exp(+i*pi*k/m) (forward FFT) and w[m][0][k] = exp(-i*pi*k/m)
// (inverse FFT). Only power-of-two indices are ever populated.
// NOTE(review): the array holds 2 * (2 * MAXN + 5) vector objects (roughly
// 96 MB if a vector is 24 bytes, as on typical 64-bit builds) although only
// one entry per power of two is used; indexing by log2(m) would shrink it but
// requires changing FFT() and pow2() together.
Poly w[2 * MAXN + 5][2];
// Bit-reversal permutation for the current transform size, set by convolution().
int rev[4 * MAXN + 5];
// In-place iterative radix-2 FFT on a[0..n-1] (n a power of two), using the
// precomputed rev[] permutation and w[][] twiddle tables, which must already
// be set up for this n (convolution() and pow2() do that).
// op == 1 selects the forward transform (tables w[m][1]); any other value
// uses w[m][0]. op == -1 additionally divides by n, rounds each real part to
// the nearest integer and clears each imaginary part.
inline void FFT(Poly &a, int n, int op) {
    for(int i = 0; i < n; ++i) {
        if(i < rev[i])
            swap(a[i], a[rev[i]]);
    }
    for(int len = 2; len <= n; len <<= 1) {
        // Plain int: the 'register' storage class is ill-formed since C++17.
        int m = len >> 1;
        for(int i = 0; i < m; ++i) {
            // Twiddle read from the table rather than built by repeated
            // multiplication, which keeps its rounding error small.
            const Complex &tw = w[m][op == 1][i];
            for(int j = i; j < n; j += len) {
                Complex u = a[j], t = a[j + m] * tw;
                a[j] = u + t, a[j + m] = u - t;
            }
        }
    }
    if(op == -1) {
        for(int i = 0; i < n; ++i) {
            // floor(v + 0.5) rounds negative coefficients correctly, unlike a
            // truncating cast of (v + 0.5) (e.g. -2.3 would become -1).
            a[i].x = floor(a[i].x / n + 0.5);
            a[i].y = 0;
        }
    }
}
inline int pow2(int x, int &lgn) {
int res = 1;
lgn = 0;
while(res < x) {
if(!w[res][0].size()) {
w[res][0].resize(res);
w[res][1].resize(res);
for(int i = 0; i < res; ++i) {
//0是逆变换需要的
w[res][0][i] = Complex(cos(-PI * i / res), sin(-PI * i / res));
w[res][1][i] = Complex(cos(PI * i / res), sin(PI * i / res));
}
}
res <<= 1;
++lgn;
}
return res;
}
// Multiplies the polynomials given by the first Asize entries of A and the
// first Bsize entries of B (real parts are the coefficients).
// On return:
//   - A is resized to the transform length (a power of two) and A[k].x holds
//     coefficient k of the product, rounded to an integer;
//   - B holds its forward transform.
// Returns the product's length with trailing zero coefficients removed
// (0 if every coefficient is zero).
inline int convolution(Poly &A, Poly &B, int Asize, int Bsize) {
    int lgn, n = pow2(Asize + Bsize - 1, lgn);
    // rev[i] reverses the low lgn bits of i. The top bit is written as n >> 1
    // (== 1 << (lgn - 1)), which avoids shifting by -1 when n == 1 (lgn == 0).
    rev[0] = 0;
    for(int i = 1; i < n; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    // Truncate first so every entry at index >= Asize / Bsize becomes zero
    // padding rather than stale data, then pad up to the transform length.
    A.resize(Asize);
    A.resize(n);
    B.resize(Bsize);
    B.resize(n);
    FFT(A, n, 1);
    FFT(B, n, 1);
    for(int i = 0; i < n; ++i)
        A[i] = A[i] * B[i];
    FFT(A, n, -1);
    while(n && (A[n - 1].x == 0))
        --n;
    return n;
}
// NOTE(review): not referenced anywhere in the code shown here; it is
// MAXN + 5 vector objects with static storage (destroyed at exit), presumably
// scratch space for other polynomial routines. Confirm before removing.
Poly poly[MAXN + 5];
int main() {
#ifdef Yinku
    freopen("Yinku.in", "r", stdin);
#endif // Yinku
    // Reads n and m, then n + 1 integer coefficients of A followed by m + 1
    // of B, and prints the n + m + 1 coefficients of their product.
    int n, m;
    if(scanf("%d%d", &n, &m) != 2)
        return 0;
    Poly A, B;
    for(int i = 0; i <= n; ++i) {
        int x;
        scanf("%d", &x);
        A.push_back(Complex(x, 0));
    }
    for(int i = 0; i <= m; ++i) {
        int x;
        scanf("%d", &x);
        B.push_back(Complex(x, 0));
    }
    // The returned trimmed length is not needed: all n + m + 1 coefficients
    // (including trailing zeros) are printed, and A is at least that long.
    convolution(A, B, n + 1, m + 1);
    for(int i = 0; i <= n + m; ++i) {
        printf("%lld%c", (ll)A[i].x, " \n"[i == n + m]);
    }
    return 0;
}