多项式板子(待补充)

已经实现的操作

  • 乘法NTT, mul
  • 求逆inverse
  • 积分integral,微分deriv
  • 对数ln
  • 指数exp
  • 快速幂power(要求常数项为\(1\))
  • 开根sqrt(要求常数项为\(1\))

说明

  • 所有操作都是原地操作(结果直接写回传入的数组),以后有空再改(大量 copy 比较麻烦、耗时);inverse/ln/exp/power/sqrt 的长度 n 需为 2 的幂,且 2n 不超过 N
  • 常数还可以

Code

const int N = 262144;       // 2^18: largest transform length supported
const int G = 3;            // generator used to build the roots of unity
const int iG = 332748118;   // G^{-1} mod p (3 * 332748118 == p + 1)
const int p = 998244353;    // NTT-friendly prime 119 * 2^23 + 1
const int maxlen = N;       // length the twiddle tables in pre() are built for

// Modular arithmetic on residues in [0, p).
inline int add(int x, int y){ return (x + y) % p; }
inline int sub(int x, int y){ return (x + p - y) % p; }
inline int mul(int x, int y){ return (long long)x * y % p; }

int last;                   // transform length rev[] was last built for
int rev[N];                 // bit-reversal permutation
int w[N], invw[N];          // forward / inverse twiddle tables
int pinv[N];                // pinv[i] = i^{-1} mod p

// x^y mod p by square-and-multiply (y >= 0).
inline int qp(int x, int y){
	int res = 1;
	while(y){
		if(y & 1) res = mul(res, x);
		x = mul(x, x);
		y >>= 1;
	}
	return res;
}

// Precompute the twiddle tables w / invw (powers of G^((p-1)/maxlen) and of
// iG^((p-1)/maxlen)) and the modular inverses pinv[1..maxlen-1].
// Must run once before any transform.
void pre(){
	const int root = qp(G, (p-1) / maxlen);
	const int iroot = qp(iG, (p-1) / maxlen);
	w[0] = 1; invw[0] = 1;
	w[1] = root; invw[1] = iroot;
	for(int k=2; k<maxlen; k++){
		w[k] = mul(w[k-1], root);
		invw[k] = mul(invw[k-1], iroot);
	}
	// Classic linear-time inverse recurrence: k^{-1} = -(p/k) * (p%k)^{-1}.
	pinv[1] = 1;
	for(int k=2; k<maxlen; k++) pinv[k] = mul(p - p / k, pinv[p % k]);
}

// Build the bit-reversal permutation rev[0..n) for a length-n NTT (n a power
// of two). The result is cached: nothing is recomputed when n equals the
// length of the previous call.
void init(int n){
	if(last == n) return ;
	last = n;
	// For n == 1 the top-bit index below would be -1, and shifting by a
	// negative count is undefined behaviour (n == 0 would also loop forever).
	if(n <= 1){ rev[0] = 0; return ; }
	int tp = n, lg = -1; // lg = log2(n) - 1: position of the highest index bit
	while(tp != 1) tp >>= 1, ++lg;
	for(int i=0; i<n; i++)
		rev[i] = (rev[i>>1] >> 1) | ((i & 1) << lg);
}
// In-place iterative radix-2 NTT of length n (a power of two).
// Pass the global w for the forward transform and invw for the inverse one.
// The inverse result is NOT divided by n; callers scale it themselves.
void NTT(int *f, int n, int *w){
	init(n);
	for(int i=0; i<n; i++){
		if(i < rev[i]) swap(f[i], f[rev[i]]);
	}
	for(int half=1; half<n; half*=2){
		const int len = half * 2;
		const int stride = maxlen / len; // table step for a length-len butterfly
		for(int base=0; base<n; base+=len){
			for(int k=0; k<half; k++){
				int u = f[base + k];
				int v = mul(f[base + k + half], w[k * stride]);
				f[base + k] = add(u, v);
				f[base + k + half] = sub(u, v);
			}
		}
	}
}
// f <- f * g as a cyclic convolution of length n (a power of two).
// g is overwritten with its transform, so callers must not reuse it.
void mul(int *f, int *g, int n){ // will make g unavailable
	const int scale = qp(n, p-2);
	NTT(f, n, w); NTT(g, n, w);
	// The inverse transform is linear mod p, so the 1/n normalisation can be
	// applied in the pointwise product instead of afterwards.
	for(int i=0; i<n; i++) f[i] = mul(mul(f[i], g[i]), scale);
	NTT(f, n, invw);
}

// Power-series inverse, in place: f <- f^{-1} mod x^n, via the Newton step
// g <- g * (2 - f * g), which doubles the number of correct terms each round.
// Requires f[0] != 0, n a power of two and 2n <= N (largest NTT is 2n long).
// Only f[0..n) is read. The static buffers g and tp start zeroed and are
// zeroed again before returning; every round relies on that.
void inverse(int *f, int n){ // n: power-of-two target length
	static int g[N], tp[N];
	g[0] = qp(f[0], p-2);
	for(int i=1; i<n; i<<=1){ // i: number of coefficients of g already correct
		// tp = f mod x^{2i}; tp[2i..4i) is still zero from earlier rounds.
		copy(f, f+(i<<1), tp);
		NTT(tp, i<<2, w); NTT(g, i<<2, w);
		// Pointwise g * (2 - f*g). The product g*g*tp has degree < 4i, so
		// the length-4i cyclic transform has no wrap-around.
		for(int j=0; j<(i<<2); j++) g[j] = mul(g[j], sub(2, mul(tp[j], g[j])));
		NTT(g, i<<2, invw);
		int invLen = qp((i<<2), p-2);
		for(int j=0; j<(i<<2); j++) g[j] = mul(g[j], invLen);
		// Keep only the 2i coefficients that are now correct.
		fill(g + (i << 1), g + (i << 2), 0);
	}
	copy(g, g+n, f);
	fill(g, g+n, 0);
	fill(tp, tp+(n<<1), 0);
}

// In place: a <- a' truncated to n coefficients (a[n-1] becomes 0).
void deriv(int *a, int n){
	if(n <= 0) return ; // nothing to differentiate; also avoids writing a[-1]
	for(int i=1; i<n; i++) a[i-1] = mul(a[i], i);
	a[n-1] = 0;
}
// In place: a <- integral of a with constant term 0, truncated to n
// coefficients. The term coming from a[n-1] (degree n) is dropped.
// Requires pre() to have filled pinv.
void integral(int *a, int n){
	if(n <= 0) return ;
	// Stop at i == 1: the old loop ran down to i == 0 and read a[-1],
	// which is out of bounds, even though a[0] is overwritten right after.
	for(int i=n-1; i>0; i--) a[i] = mul(a[i-1], pinv[i]);
	a[0] = 0;
}
// Natural logarithm of a power series, in place: a <- integral(a' / a) mod x^n.
// The result's constant term is always 0; the routine is intended for a[0] == 1.
// Requires a[0] != 0 (for inverse), n a power of two and 2n <= N. a[n..2n)
// must be zero on entry, because the product below uses length 2n; it is
// zero again on exit.
void ln(int *a, int n){
	static int tp[N];
	// tp = a' truncated to n coefficients.
	for(int i=0; i<n-1; i++) tp[i] = mul(a[i+1], i+1);
	tp[n-1] = 0;
	inverse(a, n);
	// a = a^{-1} * a'. Length 2n holds the full product; tp is consumed by mul.
	mul(a, tp, n<<1);
	fill(a+n, a+(n<<1), 0);
	integral(a, n);
	// Restore the static scratch buffer to all zeros for the next call.
	fill(tp, tp+(n<<1), 0);
}
// Power-series exponential, in place: a <- exp(a) mod x^n, via the Newton
// step f <- f0 * (1 - ln f0 + a), which doubles the number of correct terms.
// Intended for a[0] == 0. Requires n a power of two and 2n <= N.
// The static buffers start zeroed and are zeroed again before returning.
void exp(int *a, int n){
	static int f[N], lnf[N], f0[N];
	f[0] = 1;
	for(int i=1; i<n; i<<=1){ // i: number of coefficients of f already correct
		copy(f, f + i, lnf); copy(f, f + i, f0);
		ln(lnf, i << 1);
		// f = a - ln(f0) + 1  (mod x^{2i})
		for(int j=0; j<(i << 1); j++) f[j] = sub(a[j], lnf[j]);
		// lnf[0] is 0 (integral clears it), so with a[0] == 0 this makes f[0] == 1.
		++f[0];
		// f = f0 * f; the product has degree < 3i, so length 4i does not wrap.
		mul(f, f0, i << 2);
		fill(f + (i << 1), f + (i << 2), 0);
		fill(f0 + (i << 1), f0 + (i << 2), 0);
	}
	copy(f, f+n, a);
	fill(f, f + n, 0);
	fill(lnf, lnf + n, 0);
	fill(f0, f0 + n, 0);
}
// In place: a <- a^k mod x^n, computed as exp(k * ln a).
// Requires a[0] == 1, n a power of two, 2n <= N, and a[n..2n) == 0 (ln's
// precondition). k is only used through mul(), i.e. modulo p. Presumably
// callers pass it already reduced into [0, p); verify at call sites.
// The exponent used to be named `p`, shadowing the global modulus `p`; it is
// renamed so the modulus stays visible in this body.
void power(int *a, int n, int k){
	if(k == 1) return ;
	ln(a, n);
	// ln(a)[0] == 0, so scaling starts at index 1.
	for(int i=1; i<n; i++) a[i] = mul(a[i], k);
	exp(a, n);
}
// Power-series square root, in place: a <- sqrt(a) mod x^n, via the Newton
// step f <- f0 - (f0^2 - a) / (2 f0). Returns the root with constant term 1,
// so a[0] must be 1. Requires n a power of two and 2n <= N.
// Every static buffer must be zero beyond the prefix copied in each round,
// so all of them are cleared before returning.
void sqrt(int *a, int n){
	static int f[N], f0[N], invf[N], tp[N];
	f[0] = 1;
	for(int i=1; i<n; i<<=1){ // i: number of coefficients of f already correct
		copy(a, a + (i << 1), tp);
		copy(f, f + i, invf);
		inverse(invf, i<<1);
		copy(f, f + i, f0);
		NTT(tp, i << 2, w); NTT(f, i << 2, w); NTT(invf, i << 2, w);
		// Pointwise (f^2 - a) * invf / 2; all products have degree < 4i.
		for(int j=0; j<(i<<2); j++) f[j] = mul(sub(mul(f[j], f[j]), tp[j]), mul(pinv[2], invf[j]));
		NTT(f, i << 2, invw);
		int invLen = qp(i << 2, p - 2);
		// Relies on f0[i..2i) == 0.
		for(int j=0; j<(i<<1); j++) f[j] = sub(f0[j], mul(f[j], invLen));
		fill(f + (i << 1), f + (i << 2), 0);
		fill(invf + (i << 1), invf + (i << 2), 0);
	}
	copy(f, f+n, a);
	fill(f, f + n, 0);
	// f0 (written up to n/2) and invf (transformed up to 2n) used to be left
	// dirty, so the next call read stale f0[i..2i) and transformed a
	// non-zero invf[2i..4i) (wrap-around), producing wrong roots.
	fill(f0, f0 + n, 0);
	fill(invf, invf + (n << 1), 0);
	fill(tp, tp + (n << 1), 0);
}

基于 std::vector 的板子

常数比较好,功能暂不完善。

不开 O2 时,常数大约会变为原来的十倍。

const int N = 1 << 18, M = 998244353;
using ll = long long;

// Factorials and inverse factorials mod M, filled by prefac().
int fac[N], ifac[N];

// A polynomial is its coefficient vector: element i is the coefficient of x^i.
class poly : public vector<int> { using vector<int>::vector; };

// x^y mod M by binary exponentiation (y >= 0).
inline int power(int x, int y) {
  int p = 1;
  for (; y; y >>= 1, x = 1LL * x * x % M) if (y & 1) p = 1LL * p * x % M;
  return p;
}
// Modular inverse via Fermat's little theorem (M is prime); x must be non-zero mod M.
inline int inv(int x) { return power(x, M - 2); }
// Fill fac[0..n] and ifac[0..n]; requires 0 <= n < N.
void prefac(int n) {
  fac[0] = 1;
  for (int i = 1; i <= n; ++i) fac[i] = 1LL * fac[i - 1] * i % M;
  ifac[n] = power(fac[n], M - 2);
  for (int i = n - 1; i >= 0; --i) ifac[i] = 1LL * ifac[i + 1] * (i + 1) % M;
}
// Binomial coefficient n choose m mod M (needs prefac(>= n)).
// Returns 0 when m is outside [0, n]; previously ifac was indexed out of range.
int C(int n, int m) {
  if (m < 0 || m > n) return 0;
  return 1LL * fac[n] * ifac[m] % M * ifac[n - m] % M;
}

int A[N], B[N];
int w[N], iw[N], rev[N], last = 0, _w[N], Inv[N + 1];

void polyinit() {
  auto calc = [](int *w, int *iw, int N) {
    w[0] = 1; iw[0] = 1;
    w[1] = power(3, (M - 1) / N); iw[1] = power(w[1], M - 2);
    for (int i = 2; i < N / 2; ++i) w[i] = 1LL * w[i - 1] * w[1] % M;
    for (int i = 2; i < N / 2; ++i) iw[i] = 1LL * iw[i - 1] * iw[1] % M;
  };
  int *_w = w, *_iw = iw;
  for (int n = N; n > 1; n >>= 1) calc(_w, _iw, n), _w += n / 2, _iw += n / 2;
  Inv[1] = 1;
  for (int i = 2; i <= N; ++i) Inv[i] = M - ll(M / i) * Inv[M % i] % M;
}
// Fill rev[0..n) with the bit-reversal permutation for a length-n transform
// (n a power of two). Skipped when n equals the cached `last`.
void binrev(int n) {
  if (n == last) return;
  last = n;
  const int topBit = __builtin_ctz(n) - 1;  // log2(n) - 1 for a power of two
  for (int i = 1; i < n; ++i) {
    const int shifted = rev[i >> 1] >> 1;
    rev[i] = (i & 1) ? (shifted | (1 << topBit)) : shifted;
  }
}

// Modular helpers for residues already in [0, M).
inline int add(int x, int y) {
  int s = x + y;
  if (s >= M) s -= M;
  return s;
}
inline int sub(int x, int y) {
  int d = x - y;
  if (d < 0) d += M;
  return d;
}
inline void inc(int &x, int y) { x = add(x, y); }

void dft(int *a, int n, int *w = ::w) {
  binrev(n);
  for (int i = 1; i < n; ++i)
    if (rev[i] > i) swap(a[i], a[rev[i]]);
  for (int l = 2, k = 1; k < n; l <<= 1, k <<= 1) {
    int *W = w + N - l;
    for (int *i = a, *t = a + n; i != t; i += l)
      for (int *p = i, *q = i + k, *omega = W, *t = W + k; omega != t; ++p, ++q, ++omega) {
        int x = *p, y = (ll)*q * *omega % M;
        *p = add(x, y); *q = sub(x, y);
      }
  }
}
void dot(int *a, int *b, int n) {
  for (int i = 0; i < n; ++i) a[i] = (ll)a[i] * b[i] % M;
}
// Smallest power of two that is >= n (1 for n <= 1).
int length(int n) {
  if (n <= 1) return 1;
  // Bit width of n - 1 gives the exponent of the next power of two.
  return 1 << (32 - __builtin_clz((unsigned)(n - 1)));
}

// Polynomial product via NTT. The result has f.size() + g.size() - 1
// coefficients. An empty operand yields an empty result; previously the
// unsigned size computation wrapped around. The padded length must not
// exceed N (shared scratch buffers A and B), and polyinit() must have run.
poly operator*(const poly &f, const poly &g) {
  if (f.empty() || g.empty()) return poly();
  poly h(f.size() + g.size() - 1);
  int l = length(h.size());
  fill(A + f.size(), A + l, 0); fill(B + g.size(), B + l, 0);
  copy(f.begin(), f.end(), A); copy(g.begin(), g.end(), B);
  dft(A, l); dft(B, l);
  for (int i = 0; i < l; ++i) A[i] = 1LL * A[i] * B[i] % M;
  dft(A, l, iw);
  int in = power(l, M - 2);
  for (size_t i = 0; i < h.size(); ++i)
    h[i] = 1LL * A[i] * in % M;
  return h;
}

// f^{-1} mod x^n by Newton iteration g <- g * (2 - f g). Requires f[0] != 0.
// The result has n coefficients, or max(n, SIZE) when SIZE != -1. Callers
// such as ln() ask for extra zero-padded room so they can transform in place.
// Uses the shared scratch buffers A and B.
poly inverse(const poly &f, int n, int SIZE = -1) {
  poly g = {inv(f[0])};
  // The loop below writes g[0..n), so never allocate fewer than n slots.
  // A requested SIZE < n used to make it write out of bounds.
  g.resize(SIZE == -1 ? n : max(n, SIZE));
  for (int len = 1; len < n; len <<= 1) {
    // A = f mod x^{2len}, B = g mod x^{len}, both zero-padded to 4len.
    int cu = min(len * 2, (int)f.size());
    copy_n(f.begin(), cu, A); fill(A + cu, A + len * 4, 0);
    copy_n(g.begin(), len, B); fill(B + len, B + len * 4, 0);
    dft(A, len * 4); dft(B, len * 4);
    // Pointwise 2B - B^2 A, normalised into [0, M).
    for (int j = 0, li = len * 4; j < li; ++j)
      A[j] = (2 * B[j] - (ll)B[j] * B[j] % M * A[j]) % M, A[j] += M & (A[j] >> 31);
    dft(A, len * 4, iw);
    int invN = inv(len * 4);
    // Only the newly determined coefficients [len, 2len) are written.
    for (int j = len, li = min(len * 2, n); j < li; ++j) g[j] = (ll)A[j] * invN % M;
  }
  return g;
}

// ln f mod x^n, computed as integral(f' * f^{-1}). Intended for f[0] == 1;
// the result's constant term is always 0. The result has n coefficients,
// or SIZE when SIZE != -1. Uses the shared scratch buffer A.
poly ln(const poly &f, int n, int SIZE = -1) {
  // Only f'[0..li), built from f[1..li], can affect the result mod x^n.
  int li = min((int)f.size() - 1, n);
  // The transform must hold the full product f' * f^{-1} (li + n - 1 terms)
  // and at least n slots, because g[0..n) and A[0..n-1) are used below.
  // For f.size() == 1 the old bound length(n - 1) could be smaller than n,
  // e.g. ln({1}, 2), and both inverse() and the final loop ran past g's end.
  int len = length(max(n, n + li - 1));
  poly g = inverse(f, n, max(len, SIZE));
  for (int i = 0; i < li; ++i) A[i] = (ll)f[i + 1] * (i + 1) % M;
  fill(A + li, A + len, 0);
  dft(A, len); dft(g.data(), len); dot(A, g.data(), len);
  dft(A, len, iw);
  // Integrate: coefficient i of the result is (f'/f)[i-1] / i.
  g[0] = 0;
  int invN = inv(len);
  for (int i = 1; i < n; ++i) g[i] = (ll)A[i - 1] * Inv[i] % M * invN % M;
  if (SIZE == -1) g.resize(n);
  else g.resize(SIZE);
  return g;
}

// exp f mod x^n by Newton iteration g <- g0 * (1 - ln g0 + f), doubling the
// number of correct coefficients each round. Intended for f[0] == 0.
// Uses the shared scratch buffers A and B.
poly exp(const poly &f, int n) {
  poly g = {1};
  g.resize(n);
  for (int cur = 1; cur < n; cur <<= 1) {
    const int target = min(cur * 2, n);  // coefficients correct after this round
    const int sz = cur * 4;              // transform length, no wrap-around
    poly lg = ln(g, target);
    const int have = min((int)f.size(), target);
    // A = f - ln(g) + 1 (mod x^target), zero-padded to sz.
    for (int j = 0; j < target; ++j) A[j] = sub(j < have ? f[j] : 0, lg[j]);
    if (++A[0] == M) A[0] = 0;
    fill(A + target, A + sz, 0);
    // B = g mod x^cur, zero-padded to sz.
    copy_n(g.begin(), cur, B);
    fill(B + cur, B + sz, 0);
    dft(A, sz);
    dft(B, sz);
    dot(A, B, sz);
    dft(A, sz, iw);
    const int invSz = inv(sz);
    for (int i = cur; i < target; ++i) g[i] = (ll)A[i] * invSz % M;
  }
  return g;
}

// f^k mod x^n, computed as exp(k * ln f). Requires f[0] == 1 (asserted).
poly power(const poly &f, int k, int n) {
  assert(f[0] == 1);
  poly g = ln(f, n);
  for (size_t i = 0; i < g.size(); ++i) g[i] = int((ll)g[i] * k % M);
  return exp(g, n);
}
posted @ 2020-01-21 17:37  RiverHamster  阅读(352)  评论(0)    收藏  举报
\