题解:P5644 [PKUWC2018] 猎人杀

简单题。

题意:给出 \(n\) 个数 \(w_i\)。现在每轮删除一个数,假设现在有 \(i_1,i_2\cdots i_k\) 这些下标的数还在,那么对于 \(i_x\) 就有 \(\frac{w_{i_x}}{\sum w_{i_j}}\) 的概率删除。问 \(1\) 号元素最后一个被删除的概率。\(n\le 10^5,\sum w_i\le 10^5\)

做法:

首先因为有个 \(\sum w\) 在分母上,这个事情很烦,但是稍加考虑,我们其实可以认为每个数被删除的概率还是 \(\frac{w_i}{sum}\),这里 \(sum\) 是所有数总和,下文同。每个数可以多次被删,执行无限轮,只是后面是无效删除而已,有效删除的概率是不变的。

然后就是考虑最后一个被删这个条件了,显然很可以容斥,设 \(S\) 是我钦定他们必须在 \(1\) 后被删掉,那么答案应该就是 \(\sum\limits_S P(S)(-1)^{|S|}\),这里 \(P(S)\)\(S\) 全在 \(1\) 后被删掉的概率。

考虑如何计算 \(P(S)\),那么就要求只要不碰到 \(1,S\) 这些元素随便选,记 \(s(S)\) 是集合 \(S\)\(w\) 之和,那么概率为 \(\frac{sum-w_1-s(S)}{sum}\),枚举执行多少轮后 \(1\) 被删除,那么概率就是 \(\sum (\frac{sum-w_1-s(S)}{sum})^i\frac{w_1}{sum}\)

\(\frac{w_1}{sum}\) 这个常量提出再用等比数列求和稍微化简,得到 \(P(S)=\frac{w_1}{w_1+s(S)}\),很漂亮的柿子。

注意到题目中有 \(\sum w_i \le 10^5\),考虑直接枚举 \(s(S)\) 然后计算容斥系数的贡献即可,这个直接用多项式分治乘去做就可以,复杂度 \(O(n\log^2 n)\)

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int rev[maxn];
void init(int n) {
	for (int i = 1; i < n; i++) {
		rev[i] = rev[i >> 1] >> 1;
		if(i & 1)
			rev[i] |= (n >> 1);
	}
}
int qpow(int x, int k, int p) {
	int res = 1;
	while(k) {
		if(k & 1)
			res = res * x % p;
		x = x * x % p, k >>= 1;
	}
	return res;
}
struct Poly {
	vector<int> a;
	int& operator[](int x) {
		return a[x];
	}
	void resize(int N) {
		a.resize(N);
	}
	int size() {
		return a.size();
	}
	Poly() {
		
	}
	Poly(int N) {
		a.resize(N);
	}
	void NTT(int f) {
		int n = size();
		for (int i = 0; i < n; i++)
			if(i < rev[i])
				swap(a[i], a[rev[i]]);
		for (int h = 2; h <= n; h <<= 1) {
			int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);
			for (int i = 0; i < n; i += h) {
				int nw = 1;
				for (int j = i; j < i + h / 2; j++) {
					int a0 = a[j], a1 = a[j + h / 2] * nw % mod;
					a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;
					nw = nw * d % mod;
				}
			}
		}
		if(f == -1) {
			int inv = qpow(n, mod - 2, mod);
			for (int i = 0; i < n; i++)
				a[i] = a[i] * inv % mod;
		}
	}
	friend Poly operator*(Poly f, Poly g) {
		int len = 1, t = f.size() + g.size() - 1;
		while(len < t)
			len <<= 1;
		init(len), f.resize(len), g.resize(len);
		f.NTT(1), g.NTT(1);
		for (int i = 0; i < len; i++)
			f[i] = f[i] * g[i] % mod;
		f.NTT(-1);
		f.resize(t);
		return f;
	}
};
int n, a[maxn];
Poly solve(int l, int r) {
	if(l == r) {
		Poly f(a[l] + 1); f[0] = 1, f[a[l]] = mod - 1;
		return f;
	}
	int mid = l + r >> 1;
	return solve(l, mid) * solve(mid + 1, r);
}
signed main() {
	cin >> n;
	for (int i = 1; i <= n; i++)
		cin >> a[i];
	Poly res = solve(2, n);
	int ans = 0;
	for (int i = 0; i < res.size(); i++)
		ans = (ans + qpow(a[1] + i, mod - 2, mod) * a[1] % mod * res[i]) % mod;
	cout << ans << endl;
	return 0;
}
posted @ 2025-10-24 09:52  LUlululu1616  阅读(9)  评论(0)    收藏  举报