BZOJ 3625: [Codeforces Round #250]小朋友和二叉树

设 \(f[i]\) 为权值和为 \(i\) 的二叉树的种类数。
那么枚举根节点的权值 \(c\)(左右子树的权值和为 \(i-c\)),对 \(i>0\) 有 \(f[i]=\sum\limits_{c \in \mathbb{C}}\sum\limits_{j+k=i-c}f[j]\times f[k]\)
写成生成函数的形式就有 \(F(x)=C(x)*F(x)^2+1\)
常数项 \(1\) 是为了保证 \(f[0]=1\)(空树)
解方程并取减号的那个根(取加号时分子常数项为 \(2\)、分母 \(2C(x)\) 常数项为 \(0\),不是幂级数),得到$$F(x)=\frac{1-\sqrt{1-4C(x)}}{2C(x)}=\frac{2}{1+\sqrt{1-4C(x)}}$$
先做多项式开根,再做多项式求逆即可。
因为 NTT 实现跑起来还没有 MTT 快,所以直接用 MTT 了。

#include <bits/stdc++.h>

// Buffered stdin/stdout helpers for whitespace-separated decimal integers.
namespace IO {
	// buf: input buffer, [p1, p2) is its unread part; buf2: pending output;
	// a: scratch for the digits of one number (least significant first);
	// hh: separator written after every printed value.
	char buf[1 << 21], buf2[1 << 21], a[20], *p1 = buf, *p2 = buf, hh = '\n';
	// p: number of digits currently held in a; p3: index of the last used byte
	// of buf2 (-1 when buf2 is empty).
	int p, p3 = -1;
	// Base cases that end the variadic recursion of read/print.
	void read() {}
	void print() {}
	// Returns the next input byte as a value in [0, 255], refilling buf from
	// stdin when it is exhausted, or EOF once stdin yields no more data.
	inline int getc() {
		if (p1 == p2) {
			p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin);
			if (p1 == p2) return EOF;
		}
		// Go through unsigned char: a plain (possibly signed) char would turn
		// bytes >= 0x80 into negative values, which isdigit() must not be given
		// and which would make 0xFF look like EOF.
		return static_cast<unsigned char>(*p1++);
	}
	// Writes everything buffered in buf2 to stdout and empties the buffer.
	inline void flush() {
		fwrite(buf2, 1, p3 + 1, stdout), p3 = -1;
	}
	// Reads one integer into each argument, left to right. Characters before
	// the first digit are skipped; a '-' among them makes the value negative.
	// If the input ends before a digit is found the argument is set to 0.
	template <typename T, typename... T2>
	inline void read(T &x, T2 &... oth) {
		T f = 1; x = 0;
		int ch = getc();
		// Also stop at EOF: without this a missing number would loop forever.
		while (ch != EOF && !isdigit(ch)) { if (ch == '-') f = -1; ch = getc(); }
		// isdigit(EOF) is false, so this loop ends at end of input as well.
		while (isdigit(ch)) { x = x * 10 + (ch - 48); ch = getc(); }
		x *= f;
		read(oth...);
	}
	// Appends each argument in decimal, followed by hh, to the output buffer.
	// Nothing reaches stdout until flush() runs (it is triggered here once the
	// buffer is more than half full).
	template <typename T, typename... T2>
	inline void print(T x, T2... oth) {
		if (p3 > 1 << 20) flush();
		if (x < 0) buf2[++p3] = 45;
		// Take the magnitude digit by digit instead of negating x first, so the
		// most negative value of a signed type does not overflow.
		do {
			int digit = static_cast<int>(x % 10);
			if (digit < 0) digit = -digit;
			a[++p] = static_cast<char>(48 + digit);
		} while (x /= 10);
		// Digits were collected least significant first; emit them reversed.
		do {
			buf2[++p3] = a[p];
		} while (--p);
		buf2[++p3] = hh;
		print(oth...);
	}
}

// Short aliases used throughout the file.
using ll = long long;
using db = double;
// Modulus of all integer arithmetic; inv2 is the inverse of 2 modulo MOD
// (2 * inv2 == MOD + 1).
constexpr int MOD = 998244353;
constexpr int inv2 = (MOD + 1) / 2;
// Number of entries in every statically sized array of this file.
constexpr int N = 1 << 19;

// Subtracts MOD once if x >= MOD, then adds MOD once if x < 0; this brings any
// x from [-MOD, 2 * MOD) into [0, MOD).
inline void M(int &x) {
  x = (x >= MOD) ? x - MOD : x;
  x = (x < 0) ? x + MOD : x;
}
inline int mul_mod(int x, int y) {return (ll)x * y % MOD;}
// Binary exponentiation: returns a^n reduced with mul_mod. The exponent
// defaults to MOD - 2, which is how poly_inv obtains the inverse of a
// constant term.
inline int qp(int a, int n = MOD - 2) {
  int result = 1;
  a %= MOD;
  while (n) {
    // Multiply in the current power of a whenever the lowest exponent bit is set.
    if (n & 1) {
      result = mul_mod(result, a);
    }
    a = mul_mod(a, a);
    n >>= 1;
  }
  return result;
}

// Double-precision FFT and the modular polynomial routines built on it:
// multiplication (coefficients split into 15-bit halves), inverse and square
// root. Every buffer is a static array of N entries.
namespace FFT {
  // Rounds the double `val` to the nearest integer and reduces it with % MOD.
  // Each use of the argument is parenthesised so that an expression argument
  // groups as written. The argument is still evaluated twice, so it must not
  // have side effects.
  #define toll(val) ((long long)((val)+((val)<0?-0.5:0.5))%MOD)
  const db PI = acos(-1.0L);
  // Coefficients are split as hi * M0 + lo before being transformed (see convo).
  const int M0 = 1 << 15;
  // Minimal complex number; operator! is the complex conjugate.
  struct Complex {
    db re, im;
    Complex(db _re = 0, db _im = 0) : re(_re), im(_im) {}
    Complex operator!() const {return Complex(re, -im);}
    Complex operator+(const Complex &b) const {return Complex(re + b.re, im + b.im);}
    Complex operator-(const Complex &b) const {return Complex(re - b.re, im - b.im);}
    Complex operator*(const Complex &b) const {return Complex(re * b.re - im * b.im, re * b.im + im * b.re);}
    Complex& operator+=(const Complex &b) {re += b.re, im += b.im; return *this;}
    Complex& operator*=(const Complex &b) {db x = re * b.re - im * b.im, y = re * b.im + im * b.re; re = x, im = y; return *this;}
    Complex& operator/=(db d) {re /= d, im /= d; return *this;}
  };
  // log2 of the transform size the two tables below currently cover.
  int base = 1;
  // rev[i] is i with its lowest `base` bits reversed.
  int rev[N] = {0, 1};
  // For a power of two k and 0 <= j < k, roots[k + j] = exp(2*pi*i * j / (2k)):
  // the twiddle factors of the butterfly stage with half-width k.
  Complex roots[N] = {{0, 0}, {1, 0}};
  // Extends rev/roots so that transforms of size 1 << nbase can be served.
  inline void ensure_base(int nbase) {
    // The tables already cover this size; rebuilding rev would only repeat
    // work and produce the same values.
    if (base >= nbase) return;
    // Both tables hold N entries, so larger sizes cannot be supported.
    assert(nbase <= 30 && (1 << nbase) <= N);
    for (int i = 0; i < (1 << nbase); i++)
      rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
    for (; base < nbase; ++base) {
      db angle = 2 * PI / (1 << (base + 1));
      for (int i = 1 << (base - 1); i < (1 << base); i++) {
        // Even slots reuse the root of the previous level; odd slots are the
        // newly needed odd multiples of `angle`.
        db angle_i = angle * (2 * i + 1 - (1 << base));
        roots[i << 1] = roots[i];
        roots[(i << 1) + 1] = Complex(cos(angle_i), sin(angle_i));
      }
    }
  }
  // In-place iterative radix-2 transform of a[0..n); n must be a power of two
  // not exceeding N. With isInv the conjugated twiddles are used and the
  // result is divided by n, so _fft(a, n, true) undoes _fft(a, n).
  inline void _fft(Complex a[], int n, bool isInv = false) {
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    // rev stores base-bit reversals; dropping the low (base - zeros) bits
    // gives the zeros-bit reversal this length needs.
    int shift = base - zeros;
    for (int i = 0; i < n; ++i)
      if (i < (rev[i] >> shift)) std::swap(a[i], a[rev[i] >> shift]);
    for (int k = 1; k < n; k <<= 1) {
      for (int i = 0; i < n; i += 2 * k) {
        for (int j = 0; j < k; ++j) {
          Complex z(a[j + i + k] * (isInv ? !roots[j + k] : roots[j + k]));
          a[j + i + k] = a[j + i] - z, a[j + i] += z;
        }
      }
    }
    if (isInv) for (int i = 0; i < n; ++i) a[i] /= n;
  }
  // Smallest power of two that is >= x; returns 1 for every x <= 1.
  inline int get_n(int x) {
    // __builtin_clz has an undefined result for 0, and negative lengths are
    // meaningless, so those inputs are answered directly.
    if (x <= 1) return 1;
    int n = 0x80000000u >> __builtin_clz(x);
    return n < x ? n << 1 : n;
  }
  // Work buffers of convo.
  Complex pa[N], pb[N], tem[N];
  // Schoolbook product: res[0 .. len_a + len_b - 1) = a * b with coefficients
  // reduced into [0, MOD). res may be the same array as a or b.
  inline void mul_force(int a[], int len_a, int b[], int len_b, int res[]) {
    // An empty operand has no product coefficients; without this check the
    // memset below would be handed a negative size.
    if (len_a <= 0 || len_b <= 0) return;
    static int tempSum[N];
    int *sum = res, tot = len_a + len_b - 1;
    // Accumulate in a scratch array when the output overlaps an input.
    if (res == a || res == b) sum = tempSum;
    memset(sum, 0, sizeof(int)*tot);
    for (int i = 0; i < len_a; ++i) {
      for (int j = 0; j < len_b; ++j) {
        sum[i + j] += mul_mod(a[i], b[j]);
        if (sum[i + j] >= MOD) sum[i + j] -= MOD;
      }
    }
    if (sum == tempSum) memcpy(res, sum, sizeof(int)*tot);
  }
  // res[0 .. len_a + len_b - 1) = a * b modulo MOD. Small products use
  // mul_force; larger ones use FFTs on coefficients split as hi * M0 + lo.
  // Coefficients are expected in [0, MOD), and get_n(len_a + len_b - 1) must
  // not exceed N. res may alias a or b: both inputs are copied into pa/pb
  // before res is written.
  inline void convo(int a[], int len_a, int b[], int len_b, int res[]) {
    // Nothing to compute for an empty operand (avoids negative memset sizes).
    if (len_a <= 0 || len_b <= 0) return;
    if (ll(len_a) * len_b <= 1 << 18) {
      mul_force(a, len_a, b, len_b, res); return;
    }
    int tot = len_a + len_b - 1, n = get_n(tot);
    // Pack the high part into the real slot and the low part into the
    // imaginary slot, so one transform handles both halves.
    for (int i = 0; i < len_a; ++i)
      pa[i] = Complex(a[i] / M0, a[i] % M0);
    memset(pa + len_a, 0, sizeof(Complex) * (n - len_a));
    _fft(pa, n);
    if (a == b && len_a == len_b) {
      // Squaring: reuse the transform that was just computed.
      memcpy(pb, pa, sizeof(Complex)*n);
    } else {
      for (int i = 0; i < len_b; ++i)
        pb[i] = Complex(b[i] / M0, b[i] % M0);
      memset(pb + len_b, 0, sizeof(Complex) * (n - len_b));
      _fft(pb, n);
    }
    // pa[i] - conj(pa[n - i]) equals 2i times the transform of the low parts,
    // so the product divided by -4 is the low*low spectrum.
    tem[0] = Complex(pa[0].im * pb[0].im, 0);
    for (int i = 1; i < n; ++i)
      (tem[i] = (pa[i] - !pa[n - i]) * (pb[i] - !pb[n - i])) /= -4; // tem = Fra*Frb
    // Adding tem cancels the -Fra*Frb term contained in pa*pb.
    for (int i = 0; i < n; ++i)
      (pa[i] *= pb[i]) += tem[i];
    _fft(pa, n, true), _fft(tem, n, true);// IDFT(Fka*Fkb + i(Fra*Fkb + Fka*Frb)) and IDFT(Fra*Frb)
    // Recombine: hi*hi * M0^2 + (hi*lo + lo*hi) * M0 + lo*lo.
    for (int i = 0; i < tot; ++i)
      res[i] = (toll(pa[i].re) * M0 * M0 + toll(pa[i].im) * M0 + toll(tem[i].re)) % MOD;
  }
  // Writes to b[0..n) the first n coefficients of the power-series inverse of
  // a; a[0] must be invertible modulo MOD. Does nothing for n <= 0.
  void poly_inv(int a[], int b[], int n) {
    // Without this guard the halving recursion would never reach n == 1.
    if (n <= 0) return;
    if (n == 1) {b[0] = qp(a[0], MOD - 2); return;}
    static int res[N];
    int hN = (n + 1) >> 1;
    poly_inv(a, b, hN);
    // Newton step b <- 2b - a*b^2. The low hN coefficients are unchanged by
    // it, and above hN the old b is zero, so only -(a*b^2)[i] remains.
    convo(b, hN, b, hN, res);
    convo(a, n, res, hN * 2 - 1, res);
    for (int i = hN; i < n; ++i) b[i] = res[i] == 0 ? 0 : MOD - res[i];
  }
  // Writes to b[0..n) the first n coefficients of a power-series square root
  // of a, for n >= 1. The root's constant term is seeded with 1, so this
  // presumably expects a[0] == 1 (main passes such an a). Coefficients of b
  // from n up to the next power of two are set to zero.
  void poly_sqr(int *a, int *b, int n) {
    static int A[N], B[N];
    b[0] = 1;
    int len, lim;
    // Each round doubles the precision: b <- (b + a / b) / 2 modulo x^len.
    for (len = 1; len < (n << 1); len <<= 1) {
      lim = len << 1;
      for (int i = 0; i < len; i++) A[i] = a[i];
      poly_inv(b, B, lim >> 1);
      convo(A, len, B, len, A);
      for (int i = 0; i < len; i++) b[i] = 1ll * (b[i] + A[i]) * inv2 % MOD;
      // The next round reads b up to twice this length; keep that part zero.
      for (int i = len; i < lim; i++) b[i] = 0;
    }
    // Leave the scratch arrays clean and drop coefficients beyond n.
    for (int i = 0; i < len; i++) A[i] = B[i] = 0;
    for (int i = n; i < len; i++) b[i] = 0;
  }
}

// a: how often each weight occurs in the input, rewritten in place into the
//    coefficients of 1 - 4*C(x);
// b: sqrt(1 - 4*C(x)), then 1 + sqrt(1 - 4*C(x));
// c: the power-series inverse of b;
// n: number of weights in the input; m: largest total weight to answer for.
int a[N], b[N], n, m, c[N];

// Reads n, m and n weights, then prints for every i in 1..m twice the i-th
// coefficient of 1 / (1 + sqrt(1 - 4*C(x))), one value per line.
int main() {
	IO::read(n, m);
	// No answers are requested for m < 1.
	if (m < 1) return 0;
	// The FFT buffers hold N entries and the series routines need up to
	// 2 * (m + 1) of them, so larger m cannot be handled.
	if (m >= N / 2) return 1;
	while (n-- > 0) {
		int x;
		IO::read(x);
		// Only weights in 1..m can influence the first m + 1 coefficients;
		// anything else is skipped, which also keeps the index inside a[].
		if (x >= 1 && x <= m) a[x]++;
	}
	a[0] = 1;
	const int len = m + 1;
	for (int i = 1; i < len; i++) {
		// -4 * count lies in (-MOD, 0] after %, so adding MOD gives (0, MOD];
		// M folds the value MOD itself back to 0.
		a[i] = static_cast<int>(-4LL * a[i] % MOD + MOD);
		M(a[i]);
		assert(a[i] >= 0 && a[i] < MOD);
	}
	FFT::poly_sqr(a, b, len);
	b[0]++;
	M(b[0]);
	FFT::poly_inv(b, c, len);
	for (int i = 1; i <= m; i++) {
		// The answer is 2 / (1 + sqrt(1 - 4C)), hence the doubling.
		M(c[i] += c[i]);
		IO::print(c[i]);
	}
	IO::flush();
	return 0;
}
posted @ 2020-02-10 13:27  Mrzdtz220  阅读(72)  评论(0)    收藏  举报