题解:qoj8329 Excuse

题意:给出一个数 \(n\),现在我通过下面这个算法生成一个长为 \(n\) 的序列 \(a\)

  • 先进行 \(n\) 次随机扔一个硬币,然后如果你最后连续 \(k\) 次投出的是正面朝上,那么就将一个 \(k\) 加入序列末尾。

问序列 \(a\)\(\operatorname{mex}\) 的期望值。

做法:

首先转化为 \(\operatorname{mex}\ge j\) 的概率和。

然后枚举 \(k\in[1,n]\),那么就要求 \([0,k-1]\) 全都要选至少一个,剩余的都可以。我们不妨认为,对于 \(i\in[0,k-1]\) 出现的概率为 \(2^{-(i+1)}\),而 \(i\ge k\) 的概率是 \(2^{-(k+1)}\)。那么我们枚举每个元素出现次数,注意我可以安排他们的位置,可以得到概率是:

\[P_k = n!\sum_{a_0+a_1+\cdots +a_k=n,\forall t\in [0,k-1],a_t\not=0} \prod_{i=0}^k (2^{-(i+1)})^{a_i}\frac{1}{a_i!} \]

把所有东西按 \(i\) 离开,发现就是个 exp 状物,那么可以改写为:

\[P_k=n!\prod_{i=0}^{k-1}(e^{\frac{x}{2^{i+1}}} - 1)\times e^{\frac{x}{2^{k+1}}} \]

然后注意到中间这个 \(\prod\) 感觉很有性质,因为他们长得都很像。我们这里设 \(xf(x)=e^x-1\),至于为什么需要前面有个 \(x\) 后面会说,再改写柿子:

\[P_k=n!\frac{x^k}{2^{\frac{k(k+1)}{2}}}\prod_{i=0}^{k-1}f(\frac{x}{2^{i+1}})\times e^{\frac{x}{2^{k+1}}} \]

然后我们考虑,如果我们能算出来 \(g(x) = \prod\limits_{i=0}^{\infty}f(\frac{x}{2^{i+1}})\) 这个无限乘积,那么我就可以用 \(g(x)\times g^{-1}(\frac{x}{2^{k}})\) 直接算出来中间这个 \(\prod\)

考虑怎么计算 \(g\),求积很麻烦,直接取 \(\log\) 换成求和,因为我们特意上面凑了一个 \(x\) 使得 \(f\) 的零次项是 \(1\) 所以可以直接取 \(\log\),得到:

\[\log g = \sum_{i=0} \log f(\frac{x}{2^{i+1}}) \]

然后直接展开右边的每一项,那么第 \(n\) 项会多带一个 \(\sum \frac{1}{2^{ni}}\) 的系数,也就是 \(\frac{1}{2^n-1}\),直接给这个 \(f\) 算出来然后很容易算出 \(g\)

之后为了更方便看出变化的部分,会用颜色标记。

然后我们带回到整体的柿子里去,可以得到:

\[ans = n![x^n]\sum_{k=1}\frac{x^k}{2^{\frac{k(k+1)}{2}}} g(x)\textcolor{red}{g^{-1}(\frac{x}{2^{k}})e^{\frac{x}{2^k}}} \]

\(h(x) = g^{-1}(x)e^{x}\),有:

\[ans = n![x^n]g(x)\sum_{k=1}\textcolor{blue}{\frac{x^k}{2^{\frac{k(k+1)}{2}}}} \textcolor{red}{h(\frac{x}{\textcolor{green}{2^k}})} \]

把后面这个东西的卷积展开,有:

\[a_n=\sum_{k=1}\textcolor{blue}{2^{-\frac{k(k+1)}{2}}}\times \textcolor{green}{2^{-k(n-k)}} \times \textcolor{red}{h_{n-k}} \]

这个东西中间用 Chirp Z 变换一下,\(k(n-k)=\binom{n}{2}-\binom k 2 -\binom {n-k}2\) 然后分离一下 \(n,k,n-k\),可以得到:

\[a_n2^{\binom n 2}=\sum_{k=1}\textcolor{blue}{2^{-k}}\textcolor{red}{h_{n-k}2^{\binom {n-k}2}} \]

发现这个东西是可以递推计算的。

然后直接带回去解就可以了,复杂度 \(O(n\log n)\),瓶颈在于求 exp。

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int qpow(int x, int k, int p) {
	int res = 1;
	while(k) {
		if(k & 1)
			res = res * x % p;
		x = x * x % p, k >>= 1;
	}
	return res;
}
int rev[maxn], inv[maxn];
void prepare(int n) {
	inv[1] = inv[0] = 1;
	for (int i = 2; i <= n; i++)
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
void init(int len) {
	for (int i = 0; i < len; i++) {
		rev[i] = rev[i >> 1] >> 1;
		if(i & 1)
			rev[i] |= (len >> 1);
	}
}
struct Poly {
	vector<int> a;
	void resize(int N) {
		a.resize(N);
	}
	void clear() {
		a.clear();
	}
	int size() {
		return a.size();
	}
	int& operator[](int x) {
		return a[x];
	}
	Poly() {}
	Poly(int N) {a.resize(N);}
	void NTT(int f) {
		for (int i = 0; i < size(); i++)
			if(i < rev[i])
				swap(a[i], a[rev[i]]);
		for (int h = 2; h <= size(); h <<= 1) {
			int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);
			for (int i = 0; i < size(); i += h) {
				int nw = 1;
				for (int j = i; j < i + h / 2; j++) {
					int a0 = a[j], a1 = a[j + h / 2] * nw % mod;
					a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;
					nw = nw * d % mod;
				}
			}
		}
		if(f == -1) {
			int inv = qpow(size(), mod - 2, mod);
			for (int i = 0; i < size(); i++)	
				a[i] = a[i] * inv % mod;
		}
	}
	friend Poly operator*(Poly f, Poly g) {
		int len = 1, t = f.size() + g.size() - 1;
		while(len < t)
			len <<= 1;
		init(len), f.resize(len), g.resize(len);
		f.NTT(1), g.NTT(1);
		for (int i = 0; i < len; i++)
			f[i] = f[i] * g[i] % mod;
		f.NTT(-1);
		f.resize(t);
		return f; 
	}
	friend Poly operator+(Poly f, Poly g) {
		int d = max(f.size(), g.size());
		f.resize(d), g.resize(d);
		for (int i = 0; i < d; i++)
			f[i] = (f[i] + g[i]) % mod;
		return f;
	}
	friend Poly operator-(Poly f, Poly g) {
		int d = max(f.size(), g.size());
		f.resize(d), g.resize(d);
		for (int i = 0; i < d; i++)
			f[i] = (f[i] - g[i] + mod) % mod;
		return f;
	}
	void print() {
		for (int i = 0; i < size(); i++)
			cout << a[i] << " ";
		cout << endl;
	}
	friend Poly operator+(Poly f, int v) {
		f[0] = (f[0] + v) % mod;
		return f;
	}
	friend Poly operator-(Poly f, int v) {
		f[0] = (f[0] - v + mod) % mod;
		return f;
	}
} ;
Poly get_deriv(Poly f) {
	Poly g(f.size() - 1);
	for (int i = 0; i < g.size(); i++)
		g[i] = f[i + 1] * (i + 1) % mod;
	return g;
}
Poly get_integ(Poly f) {
	Poly g(f.size() + 1);
	for (int i = 1; i < g.size(); i++)
		g[i] = f[i - 1] * inv[i] % mod;
	return g;
}
Poly get_inv(Poly f, int lim) {
	if(lim == 1) {
		f.resize(1);
		f[0] = qpow(f[0], mod - 2, mod);
		return f;
	}
	Poly g = get_inv(f, lim + 1 >> 1);
	int len = 1;
	while(len < lim * 2)
		len <<= 1;
	init(len);
	f.resize(lim), f.resize(len), g.resize(len);
	f.NTT(1), g.NTT(1);
	for (int i = 0; i < len; i++)
		f[i] = (2 * g[i] - f[i] * g[i] % mod * g[i] % mod + mod) % mod;
	f.NTT(-1);
	f.resize(lim);
	return f;
}
Poly get_ln(Poly f, int lim) {
	return get_integ(get_deriv(f) * get_inv(f, lim));
}
Poly get_exp(Poly f, int lim) {
	if(lim == 1) {
		Poly ans(1); ans[0] = 1;
		return ans;
	}
	f.resize(lim);
	Poly g = get_exp(f, lim + 1 >> 1), h = get_ln(g, lim);
	h = (f - h + 1);
	g = g * h;
	g.resize(lim);
	return g;
} 
Poly shift(Poly f) {
	for (int i = 0; i < f.size() - 1; i++)
		f[i] = f[i + 1];
	f.resize(f.size() - 1);
	return f;
}
int n;
Poly g, h;
void get_gh() {
	Poly f; f.resize(n + 2);
	f[1] = 1;
	f = shift(get_exp(f, n + 2) - 1);
	f = get_ln(f, n + 1);
	g.resize(n + 1);
	for (int i = 1; i <= n; i++)
		g[i] = f[i] * qpow(qpow(2, i, mod) - 1, mod - 2, mod) % mod;
	g = get_exp(g, n + 1);
	f.clear(), f.resize(n + 1);
	f[1] = 1, f = get_exp(f, n + 1);
	h = get_inv(g, n + 1) * f; h.resize(n + 1);
}
signed main() {
	cin >> n;
	prepare(n + 2);
	get_gh();
	Poly res; res.resize(n + 1);
	for (int i = 0; i <= n; i++)
		h[i] = h[i] * qpow(2, i * (i - 1) / 2, mod) % mod;
	int inv2 = (mod + 1) / 2;
	for (int i = 1; i <= n; i++) 
		res[i] = (res[i - 1] * inv2 % mod + inv2 * h[i - 1] % mod) % mod;
	for (int i = 1; i <= n; i++)
		res[i] = res[i] * qpow(qpow(2, i * (i - 1) / 2, mod), mod - 2, mod) % mod;
	res = res * g;
	for (int i = 1; i <= n; i++)
		res[n] = res[n] * i % mod;
	cout << res[n] << endl;
	return 0; 
}
posted @ 2025-10-17 18:25  LUlululu1616  阅读(12)  评论(0)    收藏  举报