FFT与NTT

讲解：http://www.cnblogs.com/poorpool/p/8760748.html

递归版FFT

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <cmath>
using namespace std;
// Capacity of the global work arrays a, b and buf. It must be larger than the
// transform length lim (the smallest power of two > n + m).
// NOTE(review): presumably sized for n, m <= 1e6 (so lim <= 2^21) — confirm
// against the problem's input limits.
const int MAXN = 4000005;
// pi, computed at program start-up via acos(-1).
const double PI = acos(-1);
// Reads the next integer from stdin.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If EOF is reached before any digit, returns 0 instead of
// looping forever (the original `char` loop could never see EOF reliably).
int init() {
    int rv = 0, fh = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable from data
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') fh = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return rv * fh;
}
// Minimal complex number: x is the real part, y the imaginary part.
// The global FFT work arrays a, b and buf are declared together with the type.
struct Complex {
    double x, y;
    Complex(double re = 0.0, double im = 0.0) : x(re), y(im) {}
    Complex operator + (const Complex &o) const { return Complex(x + o.x, y + o.y); }
    Complex operator - (const Complex &o) const { return Complex(x - o.x, y - o.y); }
    Complex operator * (const Complex &o) const {
        const double re = x * o.x - y * o.y;
        const double im = x * o.y + y * o.x;
        return Complex(re, im);
    }
} a[MAXN], b[MAXN], buf[MAXN];
// Degrees of the two input polynomials; main reads n + 1 and m + 1 coefficients.
int n, m;
// Recursive radix-2 Cooley-Tukey FFT, applied in place to a[0..lim).
// lim is expected to be a power of two (lim == 1 is the base case).
// opt is the sign of the twiddle angle: 1 for the forward transform, -1 for
// the inverse one, which is left unscaled (the caller divides by lim).
// Uses, and overwrites, the global scratch array buf[0..lim).
void fft(Complex a[], int lim, int opt) {
    if (lim == 1) return;
    const int half = lim / 2;

    // Even-indexed samples go to the front half, odd-indexed to the back half.
    for (int k = 0; k < half; ++k) {
        buf[k] = a[2 * k];
        buf[half + k] = a[2 * k + 1];
    }
    std::copy(buf, buf + lim, a);

    fft(a, half, opt);
    fft(a + half, half, opt);

    // Butterflies: combine the two half-size transforms with twiddle w = step^k.
    const double angle = PI * 2.0 / lim;
    const Complex step(cos(angle), opt * sin(angle));
    Complex w(1.0, 0.0);
    for (int k = 0; k < half; ++k) {
        const Complex t = w * a[k + half];
        buf[k] = a[k] + t;
        buf[k + half] = a[k] - t;
        w = w * step;
    }
    std::copy(buf, buf + lim, a);
}
// Reads polynomials A (degree n) and B (degree m) from stdin and prints the
// n + m + 1 coefficients of A * B, computed with the recursive FFT.
int main() {
    n = init(); m = init();
    for(int i = 0; i <= n; i++) a[i].x = init();
    for(int i = 0; i <= m; i++) b[i].x = init();
    // Transform length: smallest power of two strictly greater than n + m,
    // so all n + m + 1 product coefficients fit without wrap-around.
    int lim = 1;
    while(lim <= n + m) lim <<= 1;
    fft(a, lim, 1);
    fft(b, lim, 1);
    // Pointwise product over exactly the lim transformed samples [0, lim).
    for(int i = 0; i < lim; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, lim, -1);  // inverse transform; the 1/lim factor is applied below
    for(int i = 0; i <= n + m; i++) {
        // floor(v + 0.5) rounds to nearest for negative values as well;
        // (int)(v + 0.5) truncates toward zero and is off by one when v < 0.
        printf("%d ", static_cast<int>(floor(a[i].x / lim + 0.5)));
    }
    printf("\n");
    return 0;
}

迭代版FFT

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
using namespace std;
// Capacity of the global arrays a, b, buf and rev. It must be larger than the
// transform length lim (the smallest power of two > n + m).
// NOTE(review): presumably sized for n, m <= 1e6 (so lim <= 2^21) — confirm
// against the problem's input limits.
const int MAXN = 4000005;
// pi, computed at program start-up via acos(-1).
const double PI = acos(-1);
// Reads the next integer from stdin.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If EOF is reached before any digit, returns 0 instead of
// looping forever (the original `char` loop could never see EOF reliably).
int init() {
    int rv = 0, fh = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable from data
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') fh = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return rv * fh;
}
// Minimal complex number: x is the real part, y the imaginary part.
// The global FFT work arrays a, b and buf are declared together with the type.
struct Complex {
    double x, y;
    Complex(double re = 0.0, double im = 0.0) : x(re), y(im) {}
    Complex operator + (const Complex &rhs) const { return Complex(x + rhs.x, y + rhs.y); }
    Complex operator - (const Complex &rhs) const { return Complex(x - rhs.x, y - rhs.y); }
    Complex operator * (const Complex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return Complex(re, im);
    }
} a[MAXN], b[MAXN], buf[MAXN];
// n, m: degrees of the input polynomials.
// rev[i]: i with its limcnt low bits reversed (filled in main).
// lim: transform length, a power of two; limcnt: number of doublings, log2(lim).
int n, m, rev[MAXN], lim, limcnt;
// Iterative radix-2 FFT, applied in place to a[0..lim).
// opt == 1: forward transform; opt == -1: inverse transform without the
// 1/lim scaling (the caller must divide by lim).
// Requires the globals lim (a power of two) and rev[0..lim) (bit-reversal
// permutation over limcnt bits) to be set up before the call.
void fft(Complex a[], int opt) {
    // Reorder into bit-reversed order; only indices [0, lim) take part
    // (the original loop also touched index lim, outside the permutation).
    for(int i = 0; i < lim; i++) {
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    }
    // Merge adjacent blocks of size mid into blocks of size 2 * mid.
    for(int mid = 1; mid < lim; mid <<= 1) {
        // Unit root of order 2 * mid; its angle sign follows opt.
        Complex wn = Complex(cos(PI / mid), opt * sin(PI / mid));
        for(int R = mid << 1, j = 0; j < lim; j += R) {
            Complex w = Complex(1.0, 0.0);
            for(int k = 0; k < mid; k++) {
                Complex x = a[j + k], y = w * a[j + mid + k];
                a[j + k] = x + y;
                a[j + mid + k] = x - y;
                w = w * wn;
            }
        }
    }
}
// Reads polynomials A (degree n) and B (degree m) from stdin and prints the
// n + m + 1 coefficients of A * B, computed with the iterative FFT.
int main() {
    n = init(); m = init();
    for(int i = 0; i <= n; i++) a[i].x = init();
    for(int i = 0; i <= m; i++) b[i].x = init();
    // lim: smallest power of two strictly greater than n + m; limcnt = log2(lim).
    lim = 1;
    while(lim <= n + m) {lim <<= 1; limcnt++;}
    // rev[i] = i with its limcnt low bits reversed. When n + m == 0, lim == 1
    // and limcnt == 0: shifting by limcnt - 1 == -1 is undefined behaviour,
    // so the top-bit term is dropped in that case.
    for(int i = 0; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | (limcnt > 0 ? (i & 1) << (limcnt - 1) : 0);
    fft(a, 1);
    fft(b, 1);
    // Pointwise product over exactly the lim transformed samples [0, lim).
    for(int i = 0; i < lim; i++) a[i] = a[i] * b[i];
    fft(a, -1);  // inverse transform; the 1/lim factor is applied below
    for(int i = 0; i <= n + m; i++) {
        // floor(v + 0.5) rounds to nearest for negative values as well;
        // (int)(v + 0.5) truncates toward zero and is off by one when v < 0.
        printf("%d ", static_cast<int>(floor(a[i].x / lim + 0.5)));
    }
    printf("\n");  // terminate the output line, as the other two programs do
    return 0;
}

NTT

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define ll long long
using namespace std;
// MAXN: capacity of the global arrays; it must be larger than lim.
// MOD = 998244353 = 119 * 2^23 + 1, a prime supporting NTT lengths up to 2^23.
// gg = 3 is a primitive root modulo MOD; gi = 332748118 is its inverse
// (3 * 332748118 = 998244354 = MOD + 1).
const int MAXN = 4000005, MOD = 998244353, gg = 3, gi = 332748118;
// Reads the next integer from stdin as a long long (the file's `ll`).
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If EOF is reached before any digit, returns 0 instead of
// looping forever (the original `char` loop could never see EOF reliably).
long long init() {
    long long rv = 0, fh = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable from data
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') fh = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return fh * rv;
}
// lim: transform length, a power of two (doubled in main); limcnt = log2(lim).
// rev[i]: i with its limcnt low bits reversed.
// n, m: degrees of the inputs; a, b: their coefficients (mod MOD).
ll lim = 1, limcnt, rev[MAXN], n, m, a[MAXN], b[MAXN];
// Modular exponentiation by repeated squaring: returns a^k mod MOD.
// k must be non-negative (a negative k never reaches 0 under k >>= 1).
ll ksm(ll a, ll k) {
    ll result = 1;
    for (ll base = a; k != 0; k >>= 1) {
        if (k & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}
// Number-theoretic transform over integers mod MOD, in place on a[0..lim).
// opt == 1: forward transform using root gg; any other opt uses gi = gg^-1
// and then multiplies by lim^-1, giving the full inverse transform.
// Expects a[i] in [0, MOD) and the globals lim / rev[0..lim) to be set up.
void ntt(ll a[], int opt) {
    // Reorder into bit-reversed order; only indices [0, lim) take part.
    for(int i = 0; i < lim; i++) {
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    }
    // Merge adjacent blocks of size mid into blocks of size 2 * mid.
    for(int mid = 1; mid < lim; mid <<= 1) {
        // Root of unity of order 2 * mid in Z/MOD.
        ll wn = ksm(opt == 1 ? gg : gi, (MOD - 1) / (mid << 1));
        for(int R = mid << 1, j = 0; j < lim; j += R) {
            ll w = 1;
            for(int k = 0; k < mid; k++) {
                ll x = a[j + k], y = w * a[j + mid + k] % MOD;
                a[j + k] = (x + y) % MOD;
                a[j + mid + k] = (x - y + MOD) % MOD;
                (w *= wn) %= MOD;
            }
        }
    }
    if(opt == -1) {
        // lim^-1 via Fermat's little theorem (MOD is prime).
        ll inv = ksm(lim, MOD - 2);
        for(int i = 0; i < lim; i++) {
            (a[i] *= inv) %= MOD;
        }
    }
}
// Reads polynomials A (degree n) and B (degree m) from stdin and prints the
// n + m + 1 coefficients of A * B modulo MOD, computed with the NTT.
int main() {
    n = init(); m = init();
    // Reduce coefficients into [0, MOD): the butterflies assume non-negative
    // residues, and a negative input would otherwise yield negative outputs.
    for(int i = 0; i <= n; i++) {
        a[i] = (init() % MOD + MOD) % MOD;
    }
    for(int i = 0; i <= m; i++) b[i] = (init() % MOD + MOD) % MOD;
    // lim: smallest power of two strictly greater than n + m; limcnt = log2(lim).
    while(lim <= (n + m)) lim <<= 1, limcnt++;
    // rev[i] = i with its limcnt low bits reversed. When n + m == 0, limcnt is 0
    // and a shift by limcnt - 1 == -1 would be undefined behaviour.
    for(int i = 0; i < lim; i++)
        rev[i] = (rev[i>>1]>>1) | (limcnt > 0 ? (i&1)<<(limcnt-1) : 0);
    ntt(a, 1);
    ntt(b, 1);
    // Pointwise product over exactly the lim transformed samples [0, lim).
    for(int i = 0; i < lim; i++) (a[i] = a[i] * b[i]) %= MOD;
    ntt(a, -1);
    for(int i = 0; i <= n + m; i++) {
        printf("%lld ", a[i]);
    }
    printf("\n");
    return 0;
}
posted @ 2018-05-23 16:50  Mr_Wolfram  阅读(...)  评论(...)  编辑  收藏