多项式模板

总算把之前学多项式时摸鱼欠下的内容补回来了一些……

常数应该不算特别大（基于模 998244353 的 NTT 实现）。

代码如下：
namespace Polys {
    // Kept as macros (not `using` aliases) because code outside the namespace
    // may refer to them unqualified.
    #define Poly std::vector <int>

    #define ll long long
    // NTT-friendly prime 998244353 with primitive root G = 3.
    const int G = 3, MOD = 998244353;

    // Returns a^b mod MOD (b >= 0). With the default exponent MOD - 2 this is the
    // modular inverse of a (Fermat); power(0) yields 0.
    // `a` is normalized first so negative / large bases cannot overflow a * a.
    ll power(ll a, ll b = MOD - 2) {
        a %= MOD;
        if (a < 0) a += MOD;
        ll ret = 1;
        for (; b; b >>= 1) {
            if (b & 1) ret = ret * a % MOD;
            a = a * a % MOD;
        }
        return ret;
    }

    const int invG = power(G);

    // Bit-reversal permutation for the most recently used transform length.
    std::vector <int> r;

    // Fills r so that r[i] is i with its log2(n) low bits reversed (n a power of two).
    // For n == 1 the table is just {0}; the loop never shifts by a negative amount.
    void initr(int n) {
        if ((int)r.size() == n) return;
        r.assign(n, 0);
        int bits = 0;
        while ((1 << bits) < n) bits++;
        for (int i = 1; i < n; i++) {
            r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
        }
    }

    // Zero / copy `s` ints. Arguments are parenthesized and no trailing ';' is
    // baked in, so the macros behave like ordinary statements (e.g. in if/else).
    #define clr(a, s) memset((a), 0, sizeof(int) * (s))
    #define cpy(a, b, s) memcpy((a), (b), sizeof(int) * (s))

    // Pointwise product: a[i] = a[i] * b[i] mod MOD for i in [0, n).
    void mul(int *a, int *b, int n) {
        for (int i = 0; i < n; i++) {
            a[i] = 1ll * a[i] * b[i] % MOD;
        }
    }

    // In-place NTT of p[0..n), n a power of two, entries in [0, MOD).
    // type == 1: forward transform; otherwise the inverse roots are used, and
    // type == -1 additionally applies the 1/n scaling.
    // Works directly on p, so no fixed-size scratch buffer (and no MAXN) is needed.
    void NTT(int *p, int n, int type) {
        initr(n);
        for (int i = 0; i < n; i++) {
            if (i < r[i]) std::swap(p[i], p[r[i]]);
        }

        for (int len = 2; len <= n; len <<= 1) {
            int half = len >> 1;
            int base = power(type == 1 ? G : invG, (MOD - 1) / len);
            for (int i = 0; i < n; i += len) {
                int w = 1;
                for (int j = 0; j < half; j++, w = 1ll * w * base % MOD) {
                    int &lo = p[i + j], &hi = p[i + j + half];
                    int cur = 1ll * hi * w % MOD;
                    hi = lo - cur;
                    if (hi < 0) hi += MOD;
                    lo = lo + cur;
                    if (lo >= MOD) lo -= MOD;
                }
            }
        }

        if (type == -1) {
            int invn = power(n);
            for (int i = 0; i < n; i++) p[i] = 1ll * p[i] * invn % MOD;
        }
    }

    // Coefficient-wise x - y mod MOD; the result has max(|x|, |y|) coefficients.
    Poly operator - (const Poly &x, const Poly &y) {
        Poly ret(x.size() > y.size() ? x.size() : y.size());
        for (int i = 0; i < (int)ret.size(); i++) {
            int v = (i < (int)x.size() ? x[i] : 0) - (i < (int)y.size() ? y[i] : 0);
            if (v >= MOD) v -= MOD;
            if (v < 0) v += MOD;
            ret[i] = v;
        }
        return ret;
    }

    // Coefficient-wise x + y mod MOD; the result has max(|x|, |y|) coefficients.
    Poly operator + (const Poly &x, const Poly &y) {
        Poly ret(x.size() > y.size() ? x.size() : y.size());
        for (int i = 0; i < (int)ret.size(); i++) {
            int v = (i < (int)x.size() ? x[i] : 0) + (i < (int)y.size() ? y[i] : 0);
            if (v >= MOD) v -= MOD;
            if (v < 0) v += MOD;
            ret[i] = v;
        }
        return ret;
    }

    // Scales every coefficient by c mod MOD (c may be negative or >= MOD).
    Poly operator * (const Poly &x, int c) {
        int k = c % MOD;
        if (k < 0) k += MOD;
        Poly ret(x.size());
        for (int i = 0; i < (int)ret.size(); i++) {
            ret[i] = 1ll * x[i] * k % MOD;
        }
        return ret;
    }

    // Polynomial product via NTT. The result has |x| + |y| - 1 coefficients,
    // or is empty when either factor is empty.
    Poly operator * (const Poly &x, const Poly &y) {
        if (x.empty() || y.empty()) return Poly();
        int need = (int)x.size() + (int)y.size() - 1;
        int lim = 1;
        while (lim < need) lim <<= 1;
        Poly a(x), b(y);
        a.resize(lim);
        b.resize(lim);
        NTT(&a[0], lim, 1);
        NTT(&b[0], lim, 1);
        mul(&a[0], &b[0], lim);
        NTT(&a[0], lim, -1);
        a.resize(need);
        return a;
    }

    // Appends coefficients so that b = a^{-1} mod z^n. Expects b empty on entry,
    // a.size() >= n, and a[0] invertible mod MOD.
    void Getinv(const Poly &a, Poly &b, int n) {
        if (n <= 0) return;  // empty input: nothing to compute (avoids endless recursion)
        if (n == 1) {
            b.push_back(power(a[0]));
            return;
        }
        if (n & 1) {
            // Odd length: extend the length-(n-1) inverse by one coefficient,
            // b_m = -a_0^{-1} * sum_{i<m} b_i * a_{m-i}, with m = n - 1.
            --n;
            Getinv(a, b, n);
            int sum = 0;
            for (int i = 0; i < n; i++) {
                sum += 1ll * b[i] * a[n - i] % MOD;
                if (sum >= MOD) sum -= MOD;
            }
            b.push_back(1ll * sum * power(MOD - a[0]) % MOD);
            return;
        }
        // Even length: Newton step b <- 2b - a b^2 (mod z^n).
        Getinv(a, b, n >> 1);
        Poly head(a.begin(), a.begin() + n);
        Poly t = head * b;
        t.resize(n);  // only coefficients below z^n matter
        t = t * b;
        b = b * 2 - t;
        b.resize(n);
    }

    // Returns x^{-1} mod z^{|x|} (empty for empty x); x[0] must be invertible.
    Poly Inv(const Poly &x) {
        Poly ret;
        Getinv(x, ret, (int)x.size());
        return ret;
    }

    // Formal derivative, padded with a trailing zero to keep |x| coefficients.
    Poly Der(const Poly &x) {
        Poly ret(x.size());
        for (int i = 1; i < (int)x.size(); i++) {
            ret[i - 1] = 1ll * i * x[i] % MOD;
        }
        return ret;
    }

    // Cache of modular inverses: invs[i] = i^{-1} mod MOD (invs[0] = 0).
    std::vector <int> invs;

    // Extends invs to cover every index in [0, n] using the linear recurrence
    // inv(i) = -(MOD / i) * inv(MOD % i), where MOD % i < i is already known.
    void initinv(int n) {
        int cur = (int)invs.size();
        if (cur > n) return;
        invs.resize(n + 1);
        for (int i = cur; i <= n; i++) {
            if (i <= 1) invs[i] = i;  // inv(0) := 0, inv(1) = 1
            else invs[i] = 1ll * (MOD - MOD / i) * invs[MOD % i] % MOD;
        }
    }

    // Antiderivative with zero constant term, truncated to |x| coefficients.
    Poly Inter(const Poly &x) {
        Poly ret(x.size());
        if (ret.empty()) return ret;  // avoid writing ret[0] on empty input
        initinv((int)ret.size());
        for (int i = 1; i < (int)ret.size(); i++) {
            ret[i] = 1ll * x[i - 1] * invs[i] % MOD;
        }
        return ret;  // ret[0] is already 0
    }

    // ln(x) mod z^{|x|} = integral(x' / x); requires x[0] == 1.
    Poly ln(const Poly &x) {
        if (x.empty()) return Poly();
        Poly der = Der(x) * Inv(x);
        der.resize(x.size());  // higher terms cannot affect the truncated integral
        return Inter(der);
    }

    // Appends coefficients so that b = exp(a) mod z^n. Expects b empty on entry,
    // a.size() >= n, and a[0] == 0.
    void Getexp(const Poly &a, Poly &b, int n) {
        if (n <= 0) return;  // empty input: nothing to compute (avoids endless recursion)
        if (n == 1) {
            b.push_back(1);
            return;
        }
        if (n & 1) {
            // Odd length: from B' = A' B,
            // (k+1) b_{k+1} = sum_{i=0}^{k} (i+1) a_{i+1} b_{k-i}, with k = n - 2.
            Getexp(a, b, n - 1);
            int k = n - 2;
            int sum = 0;
            for (int i = 0; i <= k; i++) {
                sum += 1ll * (i + 1) * a[i + 1] % MOD * b[k - i] % MOD;
                if (sum >= MOD) sum -= MOD;
            }
            initinv(k + 1);
            b.push_back(1ll * sum * invs[k + 1] % MOD);
            return;
        }
        // Even length: Newton step b <- b (1 - ln b + a) (mod z^n).
        Getexp(a, b, n >> 1);
        Poly lnb = b;
        lnb.resize(n);
        lnb = ln(lnb);
        for (int i = 0; i < n; i++) {
            lnb[i] = a[i] - lnb[i];
            if (lnb[i] < 0) lnb[i] += MOD;
        }
        lnb[0]++;
        if (lnb[0] >= MOD) lnb[0] -= MOD;
        b = lnb * b;
        b.resize(n);
    }

    // Returns exp(x) mod z^{|x|} (empty for empty x); requires x[0] == 0.
    Poly exp(const Poly &x) {
        Poly ret;
        Getexp(x, ret, (int)x.size());
        return ret;
    }
}
posted @ 2023-09-12 23:31  Katyusha_Lzh  阅读(27)  评论(3)    收藏  举报