Poly
目录
qpow：快速幂。
namespace QuadRes：求二次剩余。
namespace poly_real：实数相关，包含乘法。
namespace MTT：任意模数相关，包含乘法和求逆。
namespace poly：正常板子。包含乘、求逆、除、开根、ln、exp、快速幂、sin、cos、asin、atan、RFP转换（见注释）、FFP乘法、多点求值、快速插值、普通多项式转FFP、FFP转普通多项式。
均由数组实现。
711行，22KB。
代码
// made by creation_hy
// RFP = Rising Factorial Polynomial
// FFP = Falling Factorial Polynomial
//
// Array-based polynomial toolkit. Unless stated otherwise arithmetic is modulo
// `mod` (998244353, primitive root G = 3), a length argument is a number of
// coefficients, and routines taking `lg` work on transforms of size 1 << lg,
// which must stay below N. Almost every routine keeps function-local static
// scratch buffers, so none of them is reentrant or thread-safe.
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;

const int N = 1.5e5 + 5;
const int S = 150;
const int mod = 998244353;
const int G = 3;
int n, m, rev[N], sq;

// x^y modulo `mod` (y >= 0).
inline ll qpow(ll x, ll y) {
    ll res = 1;
    x %= mod; // keep x * x inside 64 bits even if the caller passes x >= mod
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1) res = res * x % mod;
    return res;
}

const ll Gi = qpow(G, mod - 2);        // G^{-1}
const ll inv2 = qpow(2, mod - 2);      // 1/2
const ll img = qpow(G, (mod - 1) / 4); // a square root of -1, used by sin/cos
ll fac[N], inv[N];                     // factorials / inverse factorials, filled by poly::init

// Modular square roots via Cipolla's algorithm.
namespace QuadRes {
ll isq; // w^2 for the adjoined element w of the current solve() call

// Element real + imag * w of F_p[w], where w^2 = isq.
struct Complex {
    ll real, imag;
    inline Complex(ll _real = 0, ll _imag = 0) { real = _real, imag = _imag; }
    inline bool operator==(const Complex &t) { return real == t.real && imag == t.imag; }
    inline Complex operator*(const Complex &t) {
        return Complex((real * t.real + isq * imag % mod * t.imag) % mod,
                       (imag * t.real + real * t.imag) % mod);
    }
    // Binary exponentiation.
    inline Complex operator^(int x) {
        Complex res(1, 0), base = *this;
        for (; x; x >>= 1, base = base * base)
            if (x & 1) res = res * base;
        return res;
    }
};

// Euler's criterion: true iff x^((mod-1)/2) == 1, i.e. x is a nonzero quadratic residue.
inline bool check(ll x) { return (Complex(x, 0) ^ ((mod - 1) >> 1)) == Complex(1, 0); }

// Returns the smaller of the two square roots of x modulo `mod`.
// If x is not a quadratic residue the returned value is meaningless.
inline ll solve(ll x) {
    x %= mod;
    // val^2 - 0 is always a residue, so the search below would never end for x == 0.
    if (x == 0) return 0;
    ll val = rand() % mod;
    // Find val with val^2 - x a non-residue, then adjoin w = sqrt(val^2 - x).
    while (!val || check((val * val + mod - x) % mod)) val = rand() % mod;
    isq = (val * val + mod - x) % mod;
    ll ans = (Complex(val, 1) ^ ((mod + 1) >> 1)).real;
    return min(ans, mod - ans);
}
} // namespace QuadRes

// Double-precision complex FFT multiplication.
namespace poly_real {
const double pi = acos(-1);

// Bit-reversal permutation for length 1 << lg.
inline void getrev(int lg) {
    for (int i = 1; i < 1 << lg; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (lg - 1);
}

struct Complex {
    double x, y;
    inline Complex(double _x = 0, double _y = 0) { x = _x, y = _y; }
    inline Complex operator+(const Complex &t) const { return Complex(x + t.x, y + t.y); }
    inline Complex operator-(const Complex &t) const { return Complex(x - t.x, y - t.y); }
    inline Complex operator*(const Complex &t) const {
        return Complex(x * t.x - y * t.y, x * t.y + y * t.x);
    }
};

// In-place FFT of length 1 << lg (rev must be prepared by getrev(lg)).
// type = 1: forward, type = -1: inverse. The inverse only rescales the real part.
inline void FFT(Complex *a, int lg, int type) {
    for (int i = 0; i < 1 << lg; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int i = 1; i < 1 << lg; i <<= 1) {
        Complex tmp(cos(pi / i), type * sin(pi / i));
        for (int j = 0; j < 1 << lg; j += i << 1) {
            Complex cur(1, 0);
            for (int k = 0; k < i; k++, cur = cur * tmp) {
                Complex cx = a[j + k], cy = a[i + j + k];
                a[j + k] = cx + cur * cy;
                a[i + j + k] = cx - cur * cy;
            }
        }
    }
    if (type == -1)
        for (int i = 0; i < 1 << lg; i++) a[i].x /= 1 << lg;
}

// r[0..n+m-2] = f[0..n-1] * g[0..m-1].
inline void mul(int n, int m, Complex *f, Complex *g, Complex *r) {
    static Complex a[N], b[N];
    int lg = 0;
    while (1 << lg <= n + m) lg++;
    memcpy(a, f, sizeof(Complex) * n);
    memcpy(b, g, sizeof(Complex) * m);
    for (int i = n; i < 1 << lg; i++) a[i] = Complex();
    for (int i = m; i < 1 << lg; i++) b[i] = Complex();
    getrev(lg);
    FFT(a, lg, 1), FFT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i];
    FFT(a, lg, -1);
    memcpy(r, a, sizeof(Complex) * (n + m - 1));
}
} // namespace poly_real

// Arbitrary-modulus arithmetic: NTT under three primes, then CRT reconstruction.
namespace MTT {
const ll mod[3] = {998244353, 1004535809, 469762049};

// x^y modulo `mod` (the parameter shadows the prime table).
inline ll qpow(ll x, ll y, ll mod) {
    ll res = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1) res = res * x % mod;
    return res;
}

const ll inv1 = qpow(mod[0], mod[1] - 2, mod[1]);                         // mod0^{-1} mod mod1
const ll inv2 = qpow(1ll * mod[0] * mod[1] % mod[2], mod[2] - 2, mod[2]); // (mod0*mod1)^{-1} mod mod2

// A value represented by its residues modulo the three primes.
struct Tuple {
    ll x, y, z;
    inline Tuple() : x(0), y(0), z(0) {}
    inline Tuple(ll _x) { x = y = z = _x; }
    inline Tuple(ll _x, ll _y, ll _z) { x = _x, y = _y, z = _z; }
    inline Tuple operator+(const Tuple &t) {
        return Tuple((x + t.x) % mod[0], (y + t.y + mod[1]) % mod[1], (z + t.z + mod[2]) % mod[2]);
    }
    inline Tuple operator-(const Tuple &t) {
        return Tuple((x - t.x + mod[0]) % mod[0], (y - t.y + mod[1]) % mod[1],
                     (z - t.z + mod[2]) % mod[2]);
    }
    inline Tuple operator*(const Tuple &t) {
        return Tuple(1ll * x * t.x % mod[0], 1ll * y * t.y % mod[1], 1ll * z * t.z % mod[2]);
    }
    // CRT-reconstructs the integer from its three residues and reduces it modulo p.
    inline ll val(const int &p) {
        ll k = (y - x + mod[1]) % mod[1] * inv1 % mod[1] * mod[0] + x;
        return ((z - k % mod[2] + mod[2]) % mod[2] * inv2 % mod[2] * (mod[0] * mod[1] % p) % p + k) % p;
    }
};

// Bit-reversal permutation for length 1 << lg.
inline void getrev(int lg) {
    for (int i = 1; i < 1 << lg; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (lg - 1);
}

// In-place NTT of length 1 << lg under all three primes (type = 1 forward, -1 inverse).
inline void NTT(Tuple *f, int lg, int type) {
    for (int i = 0; i < 1 << lg; i++)
        if (i < rev[i]) swap(f[i], f[rev[i]]);
    for (int i = 1; i < 1 << lg; i <<= 1) {
        Tuple w(qpow(type == 1 ? qpow(G, mod[0] - 2, mod[0]) : G, (mod[0] - 1) / (i << 1), mod[0]),
                qpow(type == 1 ? qpow(G, mod[1] - 2, mod[1]) : G, (mod[1] - 1) / (i << 1), mod[1]),
                qpow(type == 1 ? qpow(G, mod[2] - 2, mod[2]) : G, (mod[2] - 1) / (i << 1), mod[2]));
        for (int j = 0; j < 1 << lg; j += i << 1) {
            Tuple cur = 1;
            for (int k = 0; k < i; k++, cur = cur * w) {
                Tuple x = f[j + k], y = cur * f[i + j + k];
                f[j + k] = x + y;
                f[i + j + k] = x - y;
            }
        }
    }
    if (type == -1) {
        Tuple invs(qpow(1 << lg, mod[0] - 2, mod[0]), qpow(1 << lg, mod[1] - 2, mod[1]),
                   qpow(1 << lg, mod[2] - 2, mod[2]));
        for (int i = 0; i < 1 << lg; i++) f[i] = f[i] * invs;
    }
}

// r[0..n+m-2] = (f * g) mod p.
inline void mul(int n, int m, int p, Tuple *f, Tuple *g, ll *r) {
    static Tuple a[N], b[N];
    int lg = 0;
    while (1 << lg <= n + m) lg++;
    memcpy(a, f, sizeof(Tuple) * n);
    memcpy(b, g, sizeof(Tuple) * m);
    // The scratch buffers are reused across calls; clear everything past the inputs.
    for (int i = n; i < 1 << lg; i++) a[i] = Tuple(0);
    for (int i = m; i < 1 << lg; i++) b[i] = Tuple(0);
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i];
    NTT(a, lg, -1);
    for (int i = 0; i < n + m - 1; i++) r[i] = a[i].val(p);
}

// g = f^{-1} mod x^x modulo p, by Newton iteration (1 << lg must be >= 2x).
inline void Inv(int x, int p, int lg, Tuple *f, Tuple *g) {
    static Tuple a[N], b[N];
    if (x == 1) {
        g[0] = qpow(f[0].val(p), p - 2, p);
        return;
    }
    int half = (x + 1) >> 1;
    Inv(half, p, lg - 1, f, g);
    memcpy(a, f, sizeof(Tuple) * x);
    // Only g[0..half) is valid at this point; never copy the caller's upper half.
    memcpy(b, g, sizeof(Tuple) * half);
    for (int i = x; i < 1 << lg; i++) a[i] = 0;
    for (int i = half; i < 1 << lg; i++) b[i] = 0;
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i];
    NTT(a, lg, -1);
    for (int i = 0; i < 1 << lg; i++) a[i] = (p - a[i].val(p)) % p; // -(f * g) mod p
    NTT(a, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i];
    NTT(a, lg, -1);
    // New coefficients: g_new = 2g - f g^2, and 2g contributes nothing at indices >= half.
    for (int i = half; i < x; i++) g[i] = a[i].val(p);
}
} // namespace MTT

namespace poly {
ll qp[N]; // qp[i] = w^i for a primitive (1 << lg)-th root of unity w (set by getrev)

// Prepares the bit-reversal permutation and the root-power table for size 1 << lg.
inline void getrev(int lg) {
    for (int i = 1; i < 1 << lg; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (lg - 1);
    ll base = qpow(G, (mod - 1) / (1 << lg));
    qp[0] = 1;
    for (int i = 1; i <= 1 << lg; i++) qp[i] = qp[i - 1] * base % mod;
}

// In-place NTT of length 1 << lg (type = 1 forward, -1 inverse); expects reduced inputs
// and a table prepared by getrev(lg).
inline void NTT(ll *f, int lg, int type) {
    for (int i = 0; i < 1 << lg; i++)
        if (i < rev[i]) swap(f[i], f[rev[i]]);
    for (int i = 1; i < 1 << lg; i <<= 1) {
        ll t = (1 << lg) / i >> 1; // qp stride between consecutive twiddles at this level
        for (int j = 0; j < 1 << lg; j += i << 1)
            for (int k = 0; k < i; k++) {
                ll cur = type == 1 ? qp[t * k] : qp[(1 << lg) - t * k];
                ll x = f[j + k], y = cur * f[i + j + k] % mod;
                f[j + k] = (x + y) % mod;
                f[i + j + k] = (x - y + mod) % mod;
            }
    }
    if (type == -1) {
        ll inv = qpow(1 << lg, mod - 2);
        for (int i = 0; i < 1 << lg; i++) f[i] = f[i] * inv % mod;
    }
}

// r[0..n+m-2] = f[0..n-1] * g[0..m-1].
inline void mul(int n, int m, ll *f, ll *g, ll *r) {
    static ll a[N], b[N];
    int lg = 0;
    while (1 << lg <= n + m) lg++;
    memcpy(a, f, sizeof(ll) * n);
    memcpy(b, g, sizeof(ll) * m);
    memset(a + n, 0, sizeof(ll) * ((1 << lg) - n));
    memset(b + m, 0, sizeof(ll) * ((1 << lg) - m));
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    // Reduce so the inverse NTT only ever sees values below mod.
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    memcpy(r, a, sizeof(ll) * (n + m - 1));
}

// g = f^{-1} mod x^x by Newton iteration (1 << lg >= 2x, f[0] != 0).
// On return g[x..1<<lg) is zero.
inline void Inv(int x, int lg, ll *f, ll *g) {
    static ll a[N];
    if (x == 1) {
        g[0] = qpow(f[0], mod - 2);
        return;
    }
    int half = (x + 1) >> 1;
    Inv(half, lg - 1, f, g);
    // The transform below reads g up to 1 << lg, so clear it here, not in every caller.
    memset(g + half, 0, sizeof(ll) * ((1 << lg) - half));
    memcpy(a, f, sizeof(ll) * x);
    memset(a + x, 0, sizeof(ll) * ((1 << lg) - x));
    getrev(lg);
    NTT(a, lg, 1), NTT(g, lg, 1);
    for (int i = 0; i < 1 << lg; i++) g[i] = (2 - a[i] * g[i] % mod + mod) % mod * g[i] % mod;
    NTT(g, lg, -1);
    memset(g + x, 0, sizeof(ll) * ((1 << lg) - x));
}

// g = sqrt(f) mod x^x by Newton iteration; f[0] must be a nonzero quadratic residue.
// On return g[x..1<<lg) is zero.
inline void Sqrt(int x, int lg, ll *f, ll *g) {
    static ll inv[N], tmp[N];
    if (x == 1) {
        g[0] = QuadRes::solve(f[0]);
        return;
    }
    int half = (x + 1) >> 1;
    Sqrt(half, lg - 1, f, g);
    memset(g + half, 0, sizeof(ll) * ((1 << lg) - half));
    memset(inv, 0, sizeof(ll) * (1 << lg));
    Inv(x, lg, g, inv);
    getrev(lg);
    memcpy(tmp, f, sizeof(ll) * x);
    memset(tmp + x, 0, sizeof(ll) * ((1 << lg) - x));
    NTT(tmp, lg, 1), NTT(g, lg, 1), NTT(inv, lg, 1);
    // g_new = (f + g^2) / (2g)
    for (int i = 0; i < 1 << lg; i++)
        g[i] = (tmp[i] + g[i] * g[i] % mod) % mod * inv2 % mod * inv[i] % mod;
    NTT(g, lg, -1);
    memset(g + x, 0, sizeof(ll) * ((1 << lg) - x));
}

// Division with remainder: f = q * g + r, where f has n coefficients and g has m
// (g[m-1] != 0). Writes q[0..n-m] and r[0..m-2].
inline void Div(int n, int m, ll *f, ll *g, ll *q, ll *r) {
    static ll a[N], b[N];
    if (n < m) { // quotient is zero; the remainder is f padded to m - 1 terms
        for (int i = 0; i + 1 < m; i++) r[i] = i < n ? f[i] : 0;
        return;
    }
    int len = n - m + 1; // number of quotient coefficients
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    // b = rev(g)^{-1} mod x^len
    memset(a, 0, sizeof(ll) * (1 << lg));
    for (int i = 0; i < m; i++) a[i] = g[m - 1 - i];
    memset(b, 0, sizeof(ll) * (1 << lg));
    Inv(len, lg, a, b);
    // rev(q) = rev(f) * rev(g)^{-1} mod x^len
    memset(a, 0, sizeof(ll) * (1 << lg));
    for (int i = 0; i < len; i++) a[i] = f[n - 1 - i];
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    reverse(a, a + len);
    memcpy(q, a, sizeof(ll) * len);
    // r = f - q * g; only the low m - 1 coefficients are meaningful.
    memset(a + len, 0, sizeof(ll) * ((1 << lg) - len));
    memset(b, 0, sizeof(ll) * (1 << lg));
    memcpy(b, g, sizeof(ll) * m);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    for (int i = 0; i + 1 < m; i++) r[i] = (f[i] - a[i] + mod) % mod;
}

// g = f' (n coefficients; g[n-1] becomes 0).
inline void dev(int n, ll *f, ll *g) {
    for (int i = 1; i < n; i++) g[i - 1] = i * f[i] % mod;
    if (n > 0) g[n - 1] = 0;
}

// g = integral of f with zero constant term (writes g[0..n-1]).
inline void idev(int n, ll *f, ll *g) {
    for (int i = 1; i < n; i++) g[i] = f[i - 1] * qpow(i, mod - 2) % mod;
    g[0] = 0;
}

// g = integral(f' / f) mod x^n; this equals ln f when f[0] == 1.
inline void ln(int n, ll *f, ll *g) {
    static ll a[N], b[N];
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    memset(a, 0, sizeof(ll) * (1 << lg));
    memset(b, 0, sizeof(ll) * (1 << lg));
    dev(n, f, a);
    Inv(n, lg, f, b);
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    idev(n, a, g);
}

// g = exp(f) mod x^x by Newton iteration (f[0] == 0, 1 << lg > 2x).
// On return g[x..1<<lg) is zero.
inline void exp(int x, int lg, ll *f, ll *g) {
    static ll a[N];
    if (x == 1) {
        g[0] = 1;
        return;
    }
    int half = (x + 1) >> 1;
    exp(half, lg - 1, f, g);
    memset(g + half, 0, sizeof(ll) * ((1 << lg) - half));
    ln(x, g, a);
    // g_new = g * (1 - ln g + f)
    for (int i = 0; i < x; i++) a[i] = (f[i] - a[i] + mod) % mod;
    // Clear leftovers of earlier transforms so they cannot wrap into low coefficients.
    memset(a + x, 0, sizeof(ll) * ((1 << lg) - x));
    a[0] = (a[0] + 1) % mod;
    getrev(lg);
    NTT(a, lg, 1), NTT(g, lg, 1);
    for (int i = 0; i < 1 << lg; i++) g[i] = g[i] * a[i] % mod;
    NTT(g, lg, -1);
    memset(g + x, 0, sizeof(ll) * ((1 << lg) - x));
}

// g = f^k mod x^n. k1 scales ln of the normalized series and also drives the
// leading-zero shift; k2 is the exponent applied to the lowest nonzero coefficient.
// NOTE(review): presumably k1 = k mod `mod` and k2 = k mod (mod - 1); the shift
// check uses k1, so callers with k >= mod must handle the shift themselves — confirm.
inline void pow(int n, int k1, int k2, ll *f, ll *g) {
    static ll a[N], b[N];
    ll d = 0;
    while (d < n && !f[d]) d++;
    if (d == n || d * k1 >= n) { // result vanishes modulo x^n
        memset(g, 0, sizeof(ll) * n);
        return;
    }
    ll x = qpow(f[d], mod - 2), y = qpow(f[d], k2);
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    // a = f / (f[d] x^d), without reading past f[n - 1]
    memset(a, 0, sizeof(ll) * (1 << lg));
    for (ll i = 0; i + d < n; i++) a[i] = f[i + d] * x % mod;
    ln(n, a, b);
    for (int i = 0; i < n; i++) a[i] = b[i] * k1 % mod;
    memset(b, 0, sizeof(ll) * (1 << lg));
    exp(n, lg, a, b);
    d *= k1;
    memset(g, 0, sizeof(ll) * d);
    for (ll i = d; i < n; i++) g[i] = b[i - d] * y % mod;
}

// g = sin f mod x^n = (e^{if} - e^{-if}) / (2i); requires f[0] == 0.
inline void sin(int n, ll *f, ll *g) {
    static ll a[N], b[N], c[N];
    for (int i = 0; i < n; i++) a[i] = f[i] * img % mod;
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    exp(n, lg, a, b);
    Inv(n, lg, b, c);
    ll inv = qpow(img, mod - 2) * inv2 % mod;
    for (int i = 0; i < n; i++) g[i] = (b[i] - c[i] + mod) % mod * inv % mod;
}

// g = cos f mod x^n = (e^{if} + e^{-if}) / 2; requires f[0] == 0.
inline void cos(int n, ll *f, ll *g) {
    static ll a[N], b[N], c[N];
    for (int i = 0; i < n; i++) a[i] = f[i] * img % mod;
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    exp(n, lg, a, b);
    Inv(n, lg, b, c);
    for (int i = 0; i < n; i++) g[i] = (b[i] + c[i]) % mod * inv2 % mod;
}

// g = asin f mod x^n = integral(f' / sqrt(1 - f^2)).
inline void asin(int n, ll *f, ll *g) {
    static ll a[N], b[N];
    int lg = 0;
    while (1 << lg < n << 1) lg++;
    memcpy(a, f, sizeof(ll) * n);
    memset(a + n, 0, sizeof(ll) * ((1 << lg) - n));
    getrev(lg);
    NTT(a, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * a[i] % mod;
    NTT(a, lg, -1);
    for (int i = 0; i < n; i++) a[i] = (-a[i] + mod) % mod;
    a[0]++;
    memset(b, 0, sizeof(ll) * (1 << lg));
    Sqrt(n, lg, a, b);
    memset(a, 0, sizeof(ll) * (1 << lg));
    Inv(n, lg, b, a);
    memset(b, 0, sizeof(ll) * (1 << lg));
    dev(n, f, b);
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    idev(n, a, g);
}

// g = atan f mod x^n = integral(f' / (1 + f^2)).
inline void atan(int n, ll *f, ll *g) {
    static ll a[N], b[N];
    int lg = 0;
    while (1 << lg < n << 1) lg++;
    memcpy(a, f, sizeof(ll) * n);
    memset(a + n, 0, sizeof(ll) * ((1 << lg) - n));
    getrev(lg);
    NTT(a, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * a[i] % mod;
    NTT(a, lg, -1);
    a[0]++;
    memset(b, 0, sizeof(ll) * (1 << lg));
    Inv(n, lg, a, b);
    memset(a, 0, sizeof(ll) * (1 << lg));
    dev(n, f, a);
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    idev(n, a, g);
}

// Fills fac[0..lim] and inverse factorials inv[0..lim] (lim < N).
inline void init(int lim) {
    fac[0] = 1;
    for (int i = 1; i <= lim; i++) fac[i] = fac[i - 1] * i % mod;
    inv[lim] = qpow(fac[lim], mod - 2);
    for (int i = lim; i; i--) inv[i - 1] = inv[i] * i % mod;
}

// f = f * sum x^i / i! (type = 1) or f * sum (-x)^i / i! (type = -1), computed in place.
// Only f[0..n) is used as input; f must have room for (1 << lg) entries. Requires init(n - 1).
inline void desDFT(ll *f, int n, int type) {
    static ll a[N];
    for (int i = 0; i < n; i++) a[i] = type == 1 ? inv[i] : (i & 1 ? mod - inv[i] : inv[i]);
    int lg = 0;
    while (1 << lg <= n << 1) lg++;
    memset(a + n, 0, sizeof(ll) * ((1 << lg) - n));
    // Whatever the caller left past f[n - 1] would otherwise wrap into the result.
    memset(f + n, 0, sizeof(ll) * ((1 << lg) - n));
    getrev(lg);
    NTT(f, lg, 1), NTT(a, lg, 1);
    for (int i = 0; i < 1 << lg; i++) f[i] = f[i] * a[i] % mod;
    NTT(f, lg, -1);
}

// g(x) = f(x + c) for x coefficients (Taylor shift); requires init(x - 1) or larger.
inline void RFPup(int x, int c, ll *f, ll *g) // f(x) to f(x+c)
{
    static ll a[N], b[N];
    int lg = 0;
    while (1 << lg <= x << 1) lg++;
    for (ll i = 0, val = 1; i < x; i++, val = val * c % mod) {
        a[x - i - 1] = f[i] * fac[i] % mod;
        b[i] = val * inv[i] % mod;
    }
    memset(a + x, 0, sizeof(ll) * ((1 << lg) - x));
    memset(b + x, 0, sizeof(ll) * ((1 << lg) - x));
    getrev(lg);
    NTT(a, lg, 1), NTT(b, lg, 1);
    for (int i = 0; i < 1 << lg; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, lg, -1);
    for (int i = 0; i < x; i++) g[i] = a[x - i - 1] * inv[i] % mod;
}

// Subproduct tree: ev[p] holds prod (1 - x_i t) over the node's points
// (the reversed subproduct polynomial) and sz[p] is its length.
ll *ev[N], sz[N];

// Transposed multiplication: r[i] = sum_j f[i + j] * g[j] for i in [0, n).
inline void rmul(int n, int m, ll *f, ll *g, ll *r) {
    static ll a[N], b[N];
    memcpy(a, g, sizeof(ll) * m);
    reverse(a, a + m);
    mul(n, m, f, a, b);
    for (int i = 0; i < n; i++) r[i] = b[i + m - 1];
}

// Builds the subproduct tree for points f[l..r] rooted at node p.
inline void evalbuild(int p, int l, int r, ll *f) {
    delete[] ev[p]; // release a buffer from an earlier build instead of leaking it
    if (l == r) {
        ev[p] = new ll[2];
        sz[p] = 2;
        ev[p][0] = 1, ev[p][1] = (mod - f[l]) % mod;
        return;
    }
    int mid = (l + r) >> 1, ls = p << 1, rs = p << 1 | 1;
    evalbuild(ls, l, mid, f);
    evalbuild(rs, mid + 1, r, f);
    sz[p] = sz[ls] + sz[rs] - 1;
    ev[p] = new ll[sz[p]];
    mul(sz[ls], sz[rs], ev[ls], ev[rs], ev[p]);
}

// Transposed multipoint-evaluation pass: pushes f down the tree and writes the
// value for point i into g[i], i in [l, r].
inline void eval(int p, int l, int r, ll *f, ll *g) {
    if (l == r) {
        g[l] = f[0];
        return;
    }
    int len = r - l + 1;
    int mid = (l + r) >> 1, ls = p << 1, rs = p << 1 | 1;
    // Per-frame buffers (recursion), as std::vector instead of non-standard VLAs.
    vector<ll> a(len), b(len);
    rmul(len, sz[rs], f, ev[rs], a.data());
    rmul(len, sz[ls], f, ev[ls], b.data());
    eval(ls, l, mid, a.data(), g);
    eval(rs, mid + 1, r, b.data(), g);
}

// itp[p] = reversed sum_i f[i] * prod_{j != i} (x - x_j) over the node's points
// (r - l + 1 coefficients); requires the tree from evalbuild.
ll *itp[N];

inline void itpsolve(int p, int l, int r, ll *f) {
    static ll c[N], d[N];
    delete[] itp[p]; // release a buffer from an earlier run instead of leaking it
    if (l == r) {
        itp[p] = new ll[1];
        itp[p][0] = f[l];
        return;
    }
    int mid = (l + r) >> 1, len = r - l + 1, ls = p << 1, rs = p << 1 | 1;
    itpsolve(ls, l, mid, f);
    itpsolve(rs, mid + 1, r, f);
    mul(mid - l + 1, sz[rs], itp[ls], ev[rs], c);
    mul(r - mid, sz[ls], itp[rs], ev[ls], d);
    itp[p] = new ll[len];
    for (int i = 0; i < len; i++) itp[p][i] = (c[i] + d[i]) % mod;
}

// Lagrange interpolation: r[0..n-1] is the polynomial with r(f[i]) = g[i] for
// i in [0, n) (the f[i] must be distinct).
inline void interpolation(int n, ll *f, ll *g, ll *r) {
    static ll a[N], b[N], c[N];
    evalbuild(1, 0, n - 1, f);
    // b = M'(x) with M(x) = prod (x - f[i])
    memcpy(a, ev[1], sizeof(ll) * sz[1]);
    reverse(a, a + sz[1]);
    dev(sz[1], a, b);
    int lg = 0;
    while (1 << lg <= sz[1] << 1) lg++;
    memset(a, 0, sizeof(ll) * (1 << lg));
    Inv(sz[1], lg, ev[1], a);
    rmul(sz[1], sz[1], b, a, c);
    memset(a, 0, sizeof(ll) * (1 << lg));
    eval(1, 0, n - 1, c, a); // a[i] = M'(f[i]), i in [0, n)
    for (int i = 0; i < n; i++) a[i] = g[i] * qpow(a[i], mod - 2) % mod;
    itpsolve(1, 0, n - 1, a);
    memcpy(r, itp[1], sizeof(ll) * n);
    reverse(r, r + n);
}

// g = FFP coefficients of the ordinary polynomial f (n coefficients), via its
// values at 0..n-1. g must have room for the (1 << lg) entries used by desDFT.
inline void PolyToFFP(int n, ll *f, ll *g) {
    static ll a[N], b[N], c[N], d[N];
    init(n << 2);
    memcpy(a, f, sizeof(ll) * n);
    for (int i = 0; i < n; i++) b[i] = i;
    evalbuild(1, 0, n - 1, b);
    int lg = 0, siz = sz[1];
    while (1 << lg <= siz << 1) lg++;
    memset(c, 0, sizeof(ll) * (1 << lg));
    Inv(siz, lg, ev[1], c);
    rmul(n, siz, a, c, d);
    eval(1, 0, n - 1, d, g); // g[i] = f(i)
    for (int i = 0; i < n; i++) g[i] = g[i] * inv[i] % mod;
    desDFT(g, n, -1);
}

// g[0..n] = ordinary coefficients of the FFP with coefficients f[0..n-1], by
// interpolation through the points 0..n.
inline void FFPToPoly(int n, ll *f, ll *g) {
    static ll a[N], b[N], x[N], y[N];
    init(n << 2);
    for (int i = 0; i <= n; i++) a[i] = inv[i], x[i] = i;
    mul(n, n + 1, f, a, b); // b[i] = value at i divided by i!
    // Lagrange weights on 0..n: 1 / (i! (n-i)! (-1)^(n-i))
    for (int i = 0; i <= n; i++) {
        ll v = b[i] * inv[n - i] % mod;
        y[i] = (n - i) & 1 ? (mod - v) % mod : v;
    }
    evalbuild(1, 0, n, x);
    itpsolve(1, 0, n, y);
    memcpy(g, itp[1], sizeof(ll) * (n + 1));
    reverse(g, g + n + 1);
}
} // namespace poly

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    return 0;
}

浙公网安备 33010602011771号