返回顶部

模板 - 快速数论变换

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// MAXN: size hint for the global coefficient buffers, which are declared with
//       MAXN + 5 elements. The literal 4e6 is a double converted to int.
// mod:  the modulus used by pow_mod/add_mod/sub_mod and by NTT, which takes
//       powers of 3 modulo this value as its transform roots.
const int MAXN = 4e6, mod = 998244353;

// Computes x^n modulo `mod` by binary exponentiation (square-and-multiply).
//
// x: base, any 64-bit value. It is reduced into [0, mod) up front so that the
//    products x * x and res * x stay below mod^2 and cannot overflow, and so
//    that a negative base still produces a non-negative result.
// n: exponent. Any n <= 0 yields 1 (the loop is skipped); testing n > 0 rather
//    than n != 0 keeps a negative exponent from looping forever.
// Returns x^n mod `mod`, in the range [0, mod).
inline int pow_mod(ll x, int n) {
    x %= mod;
    if(x < 0)
        x += mod;
    ll res = 1;
    for(; n > 0; n >>= 1, x = x * x % mod)
        if(n & 1)
            res = res * x % mod;
    return static_cast<int>(res);
}

// Modular addition: returns x + y, with `mod` subtracted once if the sum has
// reached `mod`. A single subtraction only brings the result back below `mod`
// when the sum is below 2 * mod.
inline int add_mod(int x, int y) {
    const int sum = x + y;
    if(sum >= mod)
        return sum - mod;
    return sum;
}

// Modular subtraction: returns x - y, with `mod` added once if the difference
// went negative. A single addition only makes the result non-negative when the
// difference is at least -mod.
inline int sub_mod(int x, int y) {
    int diff = x - y;
    if(diff < 0)
        diff += mod;
    return diff;
}

// In-place number-theoretic transform of a[0..n) modulo `mod`.
//
// a:  coefficient array; overwritten with the transform.
// n:  transform length. The bit-reversal pass and the doubling butterfly loop
//     are written for a power of two (assumes n is a power of two that divides
//     mod - 1 -- TODO confirm against callers).
// op: -1 selects the inverse transform; any other value leaves the forward
//     transform in `a`.
void NTT(int a[], int n, int op) {
    // Bit-reversal permutation: `rev` follows the bit-reversed value of `i`.
    const int top = n >> 1;
    int rev = top;
    for(int i = 1; i < n - 1; ++i) {
        if(i < rev)
            std::swap(a[i], a[rev]);
        // Increment `rev` in reversed bit order: clear the leading run of set
        // bits, then set the first clear bit below them.
        int bit = top;
        for(; bit <= rev; bit >>= 1)
            rev -= bit;
        rev += bit;
    }

    // Butterfly passes over blocks of length 2, 4, ..., n.
    for(int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        // 3^((mod - 1) / len): a primitive len-th root of unity provided 3 is
        // a primitive root of `mod` and len divides mod - 1.
        const int root = pow_mod(3, (mod - 1) / len);
        for(int start = 0; start < n; start += len) {
            int w = 1;
            for(int k = 0; k < half; ++k) {
                int &lo = a[start + k];
                int &hi = a[start + k + half];
                const int u = lo;
                const int t = 1ll * hi * w % mod;
                lo = add_mod(u, t);
                hi = sub_mod(u, t);
                w = 1ll * w * root % mod;
            }
        }
    }

    if(op != -1)
        return;

    // Inverse: reversing a[1..n) is equivalent to having used the inverse
    // roots; then every entry is scaled by n^(mod - 2), the inverse of n when
    // `mod` is prime.
    std::reverse(a + 1, a + n);
    const int inv_n = pow_mod(n, mod - 2);
    for(int i = 0; i < n; ++i)
        a[i] = 1ll * a[i] * inv_n % mod;
}

// Global coefficient buffers: main reads the two input polynomials into A and
// B, and convolution leaves the product in A (B is overwritten as scratch).
// NOTE(review): convolution pads to the next power of two, so a combined
// length above 2^21 = 2097152 rounds up to 2^22 = 4194304, which is larger
// than MAXN + 5 = 4000005 -- confirm the input limits keep n + m + 1 <= 2^21.
int A[MAXN + 5], B[MAXN + 5];

// Returns the smallest power of two that is >= x. Any x <= 1 gives 1.
int pow2(int x) {
    int power = 1;
    for(;;) {
        if(power >= x)
            return power;
        power <<= 1;
    }
}

// Multiplies two polynomials modulo `mod` via NTT, in place.
//
// A, B:         coefficient arrays holding Asize and Bsize coefficients.
// Asize, Bsize: number of coefficients in A and B.
//
// Both arrays are zero-padded to len = pow2(Asize + Bsize - 1) and must have
// room for that many elements. On return A[0..len) holds the product and B
// holds its own forward transform (its original contents are lost).
void convolution(int A[], int B[], int Asize, int Bsize) {
    const int len = pow2(Asize + Bsize - 1);

    // Clear everything past the supplied coefficients up to the padded length.
    int pos = Asize;
    while(pos < len)
        A[pos++] = 0;
    pos = Bsize;
    while(pos < len)
        B[pos++] = 0;

    // Forward transforms, pointwise product, inverse transform.
    NTT(A, len, 1);
    NTT(B, len, 1);
    for(int k = 0; k < len; ++k) {
        const long long prod = 1ll * A[k] * B[k];
        A[k] = prod % mod;
    }
    NTT(A, len, -1);
}

int main() {
#ifdef Yinku
    freopen("Yinku.in", "r", stdin);
#endif // Yinku
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 0; i <= n; ++i) {
        scanf("%d", &A[i]);
        A[i] = add_mod(A[i], mod);
    }
    for(int i = 0; i <= m; ++i) {
        scanf("%d", &B[i]);
        B[i] = add_mod(B[i], mod);
    }
    convolution(A, B, n + 1, m + 1);
    for(int i = 0; i <= n + m; i++) {
        printf("%d%c", A[i], " \n"[i == n + m]);
    }
    return 0;
}
posted @ 2019-09-10 22:06  Inko  阅读(163)  评论(0)  编辑  收藏  举报