「BZOJ 3645」小朋友与二叉树

「BZOJ 3645」小朋友与二叉树

解题思路

设 \(G(x)\) 为可选权值集合 \(c\) 的生成函数,即

\[G(x)=\sum[i\in c ] x^i \]

设 \(F(x)\) 的第 \(n\) 项系数为权值为 \(n\) 的二叉树的方案数,显然有

\[F(x)=F(x)^2G(x)+1\\ F^2(x)G(x)-F(x)+1=0 \\ F(x)=\dfrac{1\pm\sqrt{1-4G(x)}}{2G(x)} \]

当 \(x\to 0\) 时,\(F(x)\) 的值应为 \(1\)(常数项 \(F(0)=1\),对应空树),而取加号的时候发现

\[\lim_{x\to0} F(x)=\lim_{x\to0}\dfrac{1+\sqrt{1-4G(x)}}{2G(x)}=\lim_{x\to0}\dfrac{1}{G(x)}=\infty \]

所以

\[F(x)=\dfrac{1-\sqrt{1-4G(x)}}{2G(x)} \]

由于 \(2G(x)\) 的常数项为 \(0\) 不存在逆元,所以要稍作一些变化

\[F(x)=\dfrac{4G(x)}{2G(x)(1+\sqrt{1-4G(x)})} \\ =\dfrac{2}{1+\sqrt{1-4G(x)}} \]

\(\sqrt{1-4G(x)}\) 的常数项为 \(1\) ,一遍开根一遍求逆就好了,复杂度 \(\mathcal O(n\log n)\) ,下面代码拖了多项式板子所以有用不到的部分。

code

/*program by mangoyang*/ 
#include<bits/stdc++.h>
// Large int sentinel whose four bytes are all 0x7f.
#define inf (0x7f7f7f7f)
// Function-like max/min. Each argument may be evaluated twice, so do not pass
// expressions with side effects.
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
// Short alias for 64-bit arithmetic.
typedef long long ll;
using namespace std;
// Reads one decimal integer from stdin into x.
//
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the result is negated. The first non-digit character after
// the number is consumed as well.
//
// If end-of-input is reached before a digit is found, x is left at 0 and the
// function returns (instead of looping forever on EOF).
template <class T>
inline void read(T &x){
    int ch = getchar(), f = 0; x = 0;
    // Skip leading junk, but stop at EOF: getchar() keeps returning EOF once
    // the stream is exhausted, so an unguarded loop would never terminate.
    for(; ch != EOF && !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
    // Subtract '0' (48) before adding so the intermediate value never exceeds
    // the final one.
    for(; ch != EOF && isdigit(ch); ch = getchar()) x = x * 10 + (ch - 48);
    if(f) x = -x;
}
// N: capacity of every coefficient / scratch array.
// P: the modulus. G: the base from which init() derives the transform roots.
const int N = (1 << 22) + 5, P = 998244353, G = 3;
// poly: truncated power-series arithmetic modulo P on top of an in-place
// iterative transform.
//
// Conventions shared by getinv / getsqrt / getln / getexp:
//  - they compute the first n coefficients of the result into the output array;
//  - the output array is transformed in place, so it needs room for the
//    transform length (smallest power of two greater than 2n - 1), and every
//    entry the routine does not set itself must be zero on entry;
//  - each routine zeroes its static scratch buffers before returning;
//  - a call with n <= 0 does nothing;
//  - init() must have been called once before any transform is run.
namespace poly{
	int rev[N], W[N], invW[N], len, lg;
	// Returns a^b mod P by binary exponentiation. b must be non-negative.
	inline int Pow(int a, int b){
		int ans = 1;
		for(; b; b >>= 1, a = 1ll * a * a % P)
			if(b & 1) ans = 1ll * ans * a % P;
		return ans;
	}
	// Precomputes, for every power of two k < N, W[k] = G^((P-1)/k) and its
	// modular inverse invW[k]. Only power-of-two indices are filled.
	inline void init(){
		for(int k = 2; k < N; k <<= 1)
			W[k] = Pow(G, (P - 1) / k), invW[k] = Pow(W[k], P - 2);
	}
	// Sets len to the smallest power of two strictly greater than lenth, lg to
	// its base-2 logarithm, and rebuilds the bit-reversal table rev[0..len).
	inline void timesinit(int lenth){
		for(len = 1, lg = 0; len <= lenth; len <<= 1, lg++);
		// len >> 1 is the top bit of an lg-bit index; using it directly avoids
		// a shift by (lg - 1), which would be negative when lg == 0.
		for(int i = 0; i < len; i++)
			rev[i] = (rev[i>>1] >> 1) | ((i & 1) ? (len >> 1) : 0);
	}
	// In-place transform of a[0..len). sgn == 1 is the forward transform;
	// sgn == -1 is the inverse, including the final division by len.
	inline void DFT(int *a, int sgn){
		for(int i = 0; i < len; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
		for(int k = 2; k <= len; k <<= 1){
			// ~sgn is zero exactly when sgn == -1.
			int w = ~sgn ? W[k] : invW[k];
			for(int i = 0; i < len; i += k){
				int now = 1;
				for(int j = i; j < i + (k >> 1); j++){
					int x = a[j], y = 1ll * a[j+(k>>1)] * now % P;
					a[j] = (x + y) % P, a[j+(k>>1)] = (x - y + P) % P;
					now = 1ll * now * w % P;
				}
			}
		}
		if(sgn == -1){
			int Inv = Pow(len, P - 2);
			for(int i = 0; i < len; i++) a[i] = 1ll * a[i] * Inv % P;
		}
	}
	// b = 1 / a (mod x^n), by doubling the precision with b <- b * (2 - a * b).
	// a[0] must be invertible modulo P.
	inline void getinv(int *a, int *b, int n){
		static int tmp[N];
		if(n <= 0) return;
		if(n == 1) return (void) (b[0] = Pow(a[0], P - 2));
		getinv(a, b, (n + 1) / 2);
		timesinit(n * 2 - 1);
		for(int i = 0; i < len; i++) tmp[i] = i < n ? a[i] : 0;
		DFT(tmp, 1), DFT(b, 1);
		for(int i = 0; i < len; i++) 
			b[i] = 1ll * (2 - 1ll * tmp[i] * b[i] % P + P) % P * b[i] % P;
		DFT(b, -1);
		for(int i = n; i < len; i++) b[i] = 0;
		for(int i = 0; i < len; i++) tmp[i] = 0;
	}
	// b = sqrt(a) (mod x^n), by doubling the precision with b <- (b + a / b) / 2.
	// The base case hard-codes b[0] = 1, so a[0] is expected to be 1.
	inline void getsqrt(int *a, int *b, int n){
		static int tmp1[N], tmp2[N];
		if(n <= 0) return;
		if(n == 1) return (void) (b[0] = 1);
		getsqrt(a, b, (n + 1) / 2);
		for(int i = 0; i < n; i++) tmp1[i] = a[i];
		// getinv leaves len at a different value during its recursion, so the
		// transform length is re-established afterwards.
		getinv(b, tmp2, n), timesinit(n * 2 - 1);
		DFT(tmp1, 1), DFT(tmp2, 1);
		for(int i = 0; i < len; i++) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % P;
		DFT(tmp1, -1);
		// Inverse of 2 modulo the odd prime P; computed once, not per element.
		const int inv2 = (P + 1) / 2;
		for(int i = 0; i < len; i++) 
			b[i] = 1ll * (b[i] + tmp1[i]) % P * inv2 % P; 
		for(int i = n; i < len; i++) b[i] = 0;
		for(int i = 0; i < len; i++) tmp1[i] = tmp2[i] = 0;
	}
	// b = ln(a) (mod x^n), computed as the integral of a' / a.
	// The constant term of the result is set to 0, so a[0] is expected to be 1.
	inline void getln(int *a, int *b, int n){
		static int tmp[N];
		if(n <= 0) return;
		getinv(a, b, n), timesinit(n * 2 - 1);
		// tmp holds the derivative a'.
		for(int i = 1; i < n; i++) tmp[i-1] = 1ll * a[i] * i % P;
		DFT(tmp, 1), DFT(b, 1);
		for(int i = 0; i < len; i++) b[i] = 1ll * tmp[i] * b[i] % P;
		DFT(b, -1);
		// Integrate in place, from high to low so b[i-1] is still the
		// un-integrated value. Only indices below n are kept, so only those
		// are computed; everything from n upward is zeroed below.
		for(int i = n - 1; i > 0; i--) b[i] = 1ll * b[i-1] * Pow(i, P - 2) % P;
		b[0] = 0;
		for(int i = n; i < len; i++) b[i] = 0;
		for(int i = 0; i < len; i++) tmp[i] = 0;
	}
	// b = exp(a) (mod x^n), by doubling the precision with
	// b <- b * (1 - ln(b) + a).
	// The base case hard-codes b[0] = 1, so a[0] is expected to be 0.
	inline void getexp(int *a, int *b, int n){
		static int tmp[N]; 
		if(n <= 0) return;
		if(n == 1) return (void) (b[0] = 1);
		getexp(a, b, (n + 1) / 2);
		getln(b, tmp, n), timesinit(n * 2 - 1);
		for(int i = 0; i < n; i++) tmp[i] = (!i - tmp[i] + a[i] + P) % P;
		DFT(tmp, 1), DFT(b, 1);
		for(int i = 0; i < len; i++) b[i] = 1ll * b[i] * tmp[i] % P;
		DFT(b, -1);
		for(int i = n; i < len; i++) b[i] = 0;
		for(int i = 0; i < len; i++) tmp[i] = 0;
	}
	// b[0..m) = a(x)^k (mod x^m), where a has n coefficients and k >= 0.
	//
	// Let fir be the index of the lowest non-zero coefficient of a. The series
	// a / (a[fir] * x^fir) has constant term 1, so its k-th power is
	// exp(k * ln(.)); the result is then scaled by a[fir]^k and shifted up by
	// fir * k. b is left all-zero when a is zero or when fir * k >= m.
	// Only b[0..m) is written. a and b must not overlap.
	inline void power(int *a, int *b, int n, int m, long long k){
		static int tmp[N];
		for(int i = 0; i < m; i++) b[i] = 0;
		int fir = -1;
		for(int i = 0; i < n; i++) if(a[i]){ fir = i; break; }
		// Early exit before forming fir * k, which could be huge for large k.
		if(fir && k >= m) return;
		if(fir == -1 || 1ll * fir * k >= m) return;
		const long long shift = 1ll * fir * k;
		// Loop-invariant scalars, hoisted out of the coefficient loops.
		const int invLead = Pow(a[fir], P - 2);
		// The exponent of a non-zero residue may be reduced modulo P - 1.
		const int leadPow = Pow(a[fir], k % (P - 1));
		// Only the first m coefficients matter; copying more would write
		// past b[m - 1].
		int cnt = n - fir;
		if(cnt > m) cnt = m;
		for(int i = 0; i < cnt; i++) 
			b[i] = 1ll * a[i + fir] * invLead % P;
		getln(b, tmp, m);
		for(int i = 0; i < m; i++) 
			b[i] = 1ll * tmp[i] * (k % P) % P, tmp[i] = 0;
		getexp(b, tmp, m);
		// Highest index written is m - 1: b only holds m coefficients.
		for(int i = m - 1; i >= shift; i--) 
			b[i] = 1ll * tmp[i - shift] * leadPow % P;
		for(int i = 0; i < shift; i++) b[i] = 0;
		for(int i = 0; i < m; i++) tmp[i] = 0;
	}
}
// Make selected poly helpers usable without the poly:: prefix.
using poly::Pow;
using poly::DFT;
using poly::timesinit;
// a: input series filled in main (P - 4, i.e. -4 mod P, at every value read,
//    then a[0] incremented).
// b: receives the square root of a; main then adds 1 to b[0].
// c: receives the inverse of b.
// n: how many values main reads.
// m: read from input and then incremented; used as the number of coefficients.
int a[N], b[N], c[N], n, m;
// Reads n and m, then n values x. Builds a(x) = 1 - 4 * sum x^(value), computes
// c(x) = 1 / (1 + sqrt(a(x))) to m + 1 coefficients and prints 2 * c[i] mod P
// for i = 1..m, one per line.
int main(){
	poly::init(), read(n), read(m), m++;
	for(int i = 1, x; i <= n; i++){
		read(x);
		// Only coefficients below m take part in the computation, so larger
		// values cannot change the output; skipping them (and negatives)
		// also keeps the write inside the array.
		if(x >= 0 && x < m) a[x] = P - 4;
	}
	a[0]++, poly::getsqrt(a, b, m); 
	b[0] = (b[0] + 1) % P;
	poly::getinv(b, c, m);
	for(int i = 1; i < m; i++) printf("%lld\n", 2ll * c[i] % P);
	return 0;
}
posted @ 2019-04-11 19:35  Joyemang33  阅读(211)  评论(1编辑  收藏  举报