BZOJ3625: 小朋友和二叉树

传送门

Sol

\(f_x\) 表示权值为 \(x\) 的二叉树的个数
\(s_x\) 表示是否有 \(x\) 这种权值可以选择
那么对 \(n\ge 1\),枚举根节点的权值 \(i\) 与左子树的权值 \(j\)(右子树的权值即为 \(n-i-j\)),有

\[f_n=\sum_{i=0}^{n}\sum_{j=0}^{n-i}f_jf_{n-i-j}s_i \]

构造

\[F(x)=\sum_{i=0}f_ix^i \]

\[S(x)=\sum_{i=0}s_ix^i \]

由于 \(s_0=0\),\(f_0=1\)(空树),且上面的递推对所有 \(n\ge 1\) 成立
那么比较两边各项系数(常数项两边都为 \(0\))可得
\(F^2(x)S(x)=F(x)-1\)
所以可以求得

\[F(x)=\frac{1\pm \sqrt{1-4S(x)}}{2S(x)} \]

由于 \(F(0)=1,S(0)=0\):若取正号,则 \(x\to 0\) 时分子趋于 \(2\) 而分母趋于 \(0\),与 \(F(0)=1\) 矛盾,所以只能取负号

\[F(x)=\frac{1- \sqrt{1-4S(x)}}{2S(x)} \]

但由于 \(S(0)=0\),分母 \(2S(x)\) 的常数项为 \(0\),无法直接求逆
所以将分子分母同乘 \(1+\sqrt{1-4S(x)}\) 进行有理化,化简得到

\[F(x)=\frac{2}{1+ \sqrt{1-4S(x)}} \]

对 \(1-4S(x)\) 开根(其常数项是 \(1\),平方根的常数项直接取 \(1\),所以不用二次剩余),再求逆即可

# include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for intermediate modular products.
using ll = long long;

// Capacity of every global array (4e5 + 5). The NTTs run at size 2 * len,
// so this must be at least twice the largest power of two used.
constexpr int maxn = 400005;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; Init uses 3 as its root.
constexpr int mod = 998244353;
// Multiplicative inverse of 2 modulo mod (499122177), derived from mod so the
// two constants cannot drift apart.
constexpr int inv2 = (mod + 1) / 2;

// Fast modular exponentiation: returns x^y mod `mod` for y >= 0
// (a negative y would never terminate, since y >>= 1 sticks at -1).
// The base is reduced into [0, mod) first, so negative bases or bases >= mod
// are accepted without overflowing the x * x product.
// `register` was dropped: that specifier is ill-formed since C++17.
inline int Pow(ll x, int y) {
    x %= mod;
    if (x < 0) x += mod;
    ll ret = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1) ret = ret * x % mod;
    return static_cast<int>(ret);
}

// In-place modular addition: x becomes (x + y) mod `mod`.
// A single conditional subtraction suffices when both x and y lie in [0, mod).
inline void Inc(int &x, int y) {
    x += y;
    if (x >= mod) {
        x -= mod;
    }
}

// Scratch buffers shared by the polynomial routines; Inv clears a and b, and
// Sqrt clears a and c, over [0, 2 * len) before returning.
int a[maxn];
int b[maxn];
int c[maxn];
// Root-of-unity tables filled by Init: w[0] forward, w[1] inverse.
int w[2][maxn];
// Current transform length (a power of two), set by Init.
int deg;
// Bit-reversal permutation for the current deg.
int r[maxn];
// log2(deg), set by Init.
int l;

inline void Init(int n) {
	register int i, k, wn, iwn;
	for (deg = 1, l = 0; deg < n; deg <<= 1) ++l;
	for (i = 0; i < deg; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	for (i = 1; i < deg; i <<= 1) {
		w[0][0] = w[1][0] = 1;
		wn = Pow(3, (mod - 1) / (i << 1)), iwn = Pow(wn, mod - 2);
		for (k = 1; k < i; ++k) {
			w[0][deg / i * k] = 1LL * w[0][deg / i * (k - 1)] * wn % mod;
			w[1][deg / i * k] = 1LL * w[1][deg / i * (k - 1)] * iwn % mod;
		}
	}
}

// In-place iterative number-theoretic transform of p[0..deg), using the r / w
// tables prepared by Init (which must have been called for the current deg).
// opt == 1: forward transform; opt == -1: inverse transform, including the
// final multiplication by deg^{-1}. Inputs in [0, mod) give outputs in [0, mod).
inline void NTT(int *p, int opt) {
	const int *root = w[opt == -1];
	for (int i = 0; i < deg; ++i)
		if (r[i] < i) swap(p[r[i]], p[i]);
	for (int i = 1; i < deg; i <<= 1) {
		const int step = deg / i;  // stride into the root table for this layer
		for (int j = 0; j < deg; j += i << 1)
			for (int k = 0; k < i; ++k) {
				const int x = p[j + k];
				const int y = 1LL * root[step * k] * p[i + j + k] % mod;
				// Both operands are below mod, so x + y < 2 * mod < 2^31 fits in int.
				int sum = x + y;
				int diff = x - y;
				if (sum >= mod) sum -= mod;
				if (diff < 0) diff += mod;
				p[j + k] = sum;
				p[i + j + k] = diff;
			}
	}
	if (opt == -1) {
		const int inv_deg = Pow(deg, mod - 2);
		for (int i = 0; i < deg; ++i) p[i] = 1LL * p[i] * inv_deg % mod;
	}
}

// n: number of allowed node weights; m: largest tree weight to report.
int n;
int m;
// f: receives sqrt(1 - 4S(x)) in main, then has 1 added to its constant term.
int f[maxn];
// g: receives the inverse of f in main, then is doubled into the answer series.
int g[maxn];
// s: weight indicator counts, overwritten in main with the coefficients of 1 - 4S(x).
int s[maxn];

// Power-series inverse: sets q[0..len) so that p * q == 1 (mod x^len).
// len must be a power of two and p[0] must be nonzero modulo mod.
// Newton iteration q <- q * (2 - p * q), doubling the precision each level.
// Uses the scratch arrays a and b and leaves them zeroed over [0, 2 * len).
void Inv(int *p, int *q, int len) {
	if (len == 1) {
		q[0] = Pow(p[0], mod - 2);
		return;
	}
	const int half = len >> 1;
	Inv(p, q, half);  // q is now correct modulo x^half
	const int tmp = len << 1;
	for (int i = 0; i < len; ++i) {
		a[i] = p[i];
		// Use only the q[0..half) produced above. Any stale higher terms would
		// push p * q * q past degree 2 * len and wrap around in the cyclic NTT.
		b[i] = i < half ? q[i] : 0;
	}
	Init(tmp);
	NTT(a, 1);
	NTT(b, 1);
	for (int i = 0; i < tmp; ++i) a[i] = 1LL * a[i] * b[i] % mod * b[i] % mod;
	NTT(a, -1);  // a = p * q^2
	for (int i = 0; i < len; ++i) {
		const int prev = i < half ? q[i] : 0;
		q[i] = (2LL * prev + mod - a[i]) % mod;  // q = 2q - p q^2
	}
	for (int i = 0; i < tmp; ++i) a[i] = b[i] = 0;
}

// Power-series square root: sets q[0..len) so that q * q == p (mod x^len),
// for len a power of two. Uses the Newton iteration q <- (q + p / q) / 2.
// The constant term is the floating-point sqrt of p[0]. That is only correct
// when p[0] is a perfect square as an integer (p[0] == 1 in this program); a
// general p[0] would need a modular square root (quadratic residue) instead.
// Uses the scratch arrays a and c and leaves them zeroed over [0, 2 * len).
void Sqrt(int *p, int *q, int len) {
    if (len == 1) {
        q[0] = sqrt(p[0]);
        return;
    }
    Sqrt(p, q, len >> 1);  // q is now correct modulo x^(len/2)
    Inv(q, c, len);        // c = q^{-1} modulo x^len
    const int tmp = len << 1;
    for (int i = 0; i < len; ++i) a[i] = p[i];
    Init(tmp);
    NTT(a, 1);
    NTT(c, 1);
    for (int i = 0; i < tmp; ++i) a[i] = 1LL * a[i] * c[i] % mod;
    NTT(a, -1);  // a = p / q
    // Average q and p / q; halving is multiplication by the inverse of 2.
    for (int i = 0; i < len; ++i) q[i] = 1LL * (q[i] + a[i]) % mod * inv2 % mod;
    for (int i = 0; i < tmp; ++i) a[i] = c[i] = 0;
}

int main() {
	register int i, len, v;
	scanf("%d%d", &n, &m);
	for (i = 1; i <= n; ++i) scanf("%d", &v), ++s[v];
	for (len = 1; len <= m; len <<= 1);
	for (i = 1; i <= m; ++i) s[i] = 1LL * s[i] * 4 % mod, s[i] = mod - s[i];
	s[0] = 1, Sqrt(s, f, len), Inc(f[0], 1), Inv(f, g, len);
	for (i = 0; i < len; ++i) g[i] = 2LL * g[i] % mod;
	for (i = 1; i <= m; ++i) printf("%d\n", g[i]);
    return 0;
}
posted @ 2018-11-29 17:25  Cyhlnj  阅读(...)  评论(...)  编辑  收藏