NTT 多项式相乘模板（模数固定为 P = 998244353，并非任意模数）与 FFT 模板

注意：mul 返回的系数可能为负数（即使传入的 a 和 b 中全是非负整数），但每个系数都与真实结果模 P 同余；若需要 [0, P) 内的值，可对 mul 返回的每个元素执行 if (x < 0) x += P。

constexpr int P = 998244353;
// Computes a^b mod P by binary exponentiation (square-and-multiply).
// b is consumed bit by bit from the least significant end.
// If a is negative, the result may be a negative representative in (-P, P).
int power(int a, int b) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b % 2 != 0) {
            result = result * base % P;
        }
        b /= 2;
        base = base * base % P;
    }
    return static_cast<int>(result);
}
// Bit-reversal permutation cached for the most recently used transform length.
vector<int> rev;
// Twiddle-factor table. For each power of two K built so far, roots[K + j]
// (0 <= j < K) equals e_K^j with e_K = 31^(2^23 / (2K)), which is intended to
// be a primitive 2K-th root of unity mod P. Extended lazily by dft().
vector<int> roots {0, 1};

// In-place forward number-theoretic transform modulo P.
// Requirements: a.size() is a power of two no larger than 2^__builtin_ctz(P - 1)
// (2^23 for P = 998244353), and entries lie in (-P, P) so the butterfly sums
// below fit in int. Results lie in (-P, P) and may be negative, because
// (u - v) % P is not shifted back into [0, P).
void dft(vector<int> &a) {
    int n = a.size();
    // Length-0 and length-1 transforms are the identity. Returning here also
    // avoids __builtin_ctz(0) (undefined) and the shift by -1 that n == 1 would
    // trigger while building rev.
    if (n <= 1) {
        return;
    }
    // Rebuild the bit-reversal table only when the length changes.
    if (int(rev.size()) != n) {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; i++) {
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
        }
    }
    for (int i = 0; i < n; i++) {
        if (rev[i] < i) {
            swap(a[i], a[rev[i]]);
        }
    }
    // Extend the twiddle table up to length n. At each level, e must have order
    // exactly 2^(k + 1). This relies on 31 having multiplicative order 2^23 mod P.
    if (int(roots.size()) < n) {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1 << k) < n) {
            int e = power(31, 1 << (__builtin_ctz(P - 1) - k - 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++) {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = 1LL * roots[i] * e % P;
            }
            k++;
        }
    }
    // Iterative Cooley-Tukey butterflies; k is the half-length of the current blocks.
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                int u = a[i + j];
                int v = 1LL * a[i + j + k] * roots[k + j] % P;
                a[i + j] = (u + v) % P;
                a[i + j + k] = (u - v) % P;
            }
        }
    }
}

// In-place inverse number-theoretic transform modulo P.
// Has the same size requirements as dft. Outputs lie in (-P, P) and may be
// negative; they are congruent mod P to the exact inverse transform.
void idft(vector<int> &a) {
    int n = a.size();
    // Nothing to invert for an empty vector. Without this check, a.begin() + 1
    // would be past the end and (1 - P) / n would divide by zero.
    if (n == 0) {
        return;
    }
    // Reversing all entries except a[0] (index j -> n - j) makes the forward
    // transform evaluate at inverse roots, i.e. compute n times the inverse.
    reverse(a.begin() + 1, a.end());
    dft(a);
    // n divides P - 1, so (1 - P) / n is exact, and n * ((1 - P) / n) = 1 - P,
    // which is 1 mod P. This is a (negative) representative of n^{-1} mod P.
    int inv = (1 - P) / n;
    for (int i = 0; i < n; i++) {
        a[i] = 1LL * a[i] * inv % P;
    }
}

// Multiplies two polynomials modulo P (coefficient i is the coefficient of x^i).
// Returns the a.size() + b.size() - 1 coefficients of the product, or an empty
// vector if either factor is empty.
// Each returned coefficient is congruent mod P to the true coefficient but may
// be negative. This can happen on the NTT path (tot >= 128), or on the
// schoolbook path when inputs are negative.
// For the NTT path, inputs should lie in (-P, P) and tot must not exceed
// dft's length limit (2^23).
vector<int> mul(vector<int> a, vector<int> b) {
    // An empty factor gives an empty product. Without this check, the size
    // computation below underflows (both empty) or produces a bogus all-zero
    // vector (one empty).
    if (a.empty() || b.empty()) {
        return {};
    }
    int tot = a.size() + b.size() - 1;
    // Small products: schoolbook multiplication beats three transforms.
    if (tot < 128) {
        vector<int> c(tot);
        for (size_t i = 0; i < a.size(); i++) {
            for (size_t j = 0; j < b.size(); j++) {
                c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % P;
            }
        }
        return c;
    }
    // Transform length: smallest power of two that holds all tot coefficients.
    int n = 1;
    while (n < tot) {
        n *= 2;
    }
    a.resize(n);
    b.resize(n);
    dft(a);
    dft(b);
    for (int i = 0; i < n; i++) {
        a[i] = 1LL * a[i] * b[i] % P;
    }
    idft(a);
    a.resize(tot);
    // If non-negative coefficients are required, normalize here. The bound is
    // tot, not n, because a has already been shrunk to tot elements:
    // for (int i = 0; i < tot; i++) if (a[i] < 0) a[i] += P;
    return a;
}

FFT模板

using Z = complex<double>;  // complex number type used by the FFT

const double pi = acos(-1);
// Root-of-unity table shared by FFT; filled by init(n).
vector<Z> w;

// Fills w with the n-th roots of unity: w[idx] = exp(i * 2*pi*idx / n).
void init(int n) {
    w.resize(n);
    const double step = 2 * pi / n;  // angle between consecutive roots
    for (int idx = 0; idx < n; ++idx) {
        w[idx] = exp(Z(0, step * idx));
    }
}
void FFT(vector<Z>& a, int inv) {
    int n = (int)a.size();
    for (int i = 0, j = 0; i < n; ++i) {
        if (i < j) swap(a[i], a[j]);
        for (int k = n / 2; (j ^= k) < k; k /= 2);
    }
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += k * 2) {
            for (int j = 0; j < k; ++j) {
                Z wn(w[(int)w.size() / k / 2 * j]);
                if (inv) wn = conj(wn);
                Z t = a[i + j + k] * wn;
                a[i + j + k] = a[i + j] - t;
                a[i + j] += t;
            }
        }
    }
    if (inv) for (auto& x : a) x /= n;
}
vector<int> PolyMul(vector<int> a, vector<int> b) {
    vector<Z> f(a.begin(), a.end()), g(b.begin(), b.end());
    int N = 1;
    for (; N < int(f.size() + g.size()); N *= 2);
    f.resize(N), g.resize(N);
    init(N);
 
    FFT(f, 0), FFT(g, 0);
    vector<Z> h(N);
    for (int i = 0; i < N; ++i) {
        h[i] = f[i] * g[i];
    }
    FFT(h, 1);
    vector<int> c(N);
    for (int i = 0; i < N; ++i) {
        c[i] = round(h[i].real());
    }
    return c;
}
posted @ 2025-05-15 22:46  MENDAXZ  阅读(17)  评论(0)    收藏  举报