题解:P5644 [PKUWC2018] 猎人杀
简单题。
题意:给出 \(n\) 个数 \(w_i\)。现在每轮删除一个数,假设现在有 \(i_1,i_2\cdots i_k\) 这些下标的数还在,那么对于 \(i_x\) 就有 \(\frac{w_{i_x}}{\sum w_{i_j}}\) 的概率删除。问 \(1\) 号元素最后一个被删除的概率。\(n\le 10^5,\sum w_i\le 10^5\)
做法:
首先因为有个 \(\sum w\) 在分母上,这个事情很烦,但是稍加考虑,我们其实可以认为每个数被删除的概率还是 \(\frac{w_i}{sum}\),这里 \(sum\) 是所有数总和,下文同。每个数可以多次被删,执行无限轮,只是后面是无效删除而已,有效删除的概率是不变的。
然后就是考虑最后一个被删这个条件了,显然很可以容斥,设 \(S\) 是我钦定他们必须在 \(1\) 后被删掉,那么答案应该就是 \(\sum\limits_S P(S)(-1)^{|S|}\),这里 \(P(S)\) 是 \(S\) 全在 \(1\) 后被删掉的概率。
考虑如何计算 \(P(S)\),那么就要求只要不碰到 \(1,S\) 这些元素随便选,记 \(s(S)\) 是集合 \(S\) 的 \(w\) 之和,那么概率为 \(\frac{sum-w_1-s(S)}{sum}\),枚举执行多少轮后 \(1\) 被删除,那么概率就是 \(\sum (\frac{sum-w_1-s(S)}{sum})^i\frac{w_1}{sum}\)。
把 \(\frac{w_1}{sum}\) 这个常量提出再用等比数列求和稍微化简,得到 \(P(S)=\frac{w_1}{w_1+s(S)}\),很漂亮的柿子。
注意到题目中有 \(\sum w_i \le 10^5\),考虑直接枚举 \(s(S)\) 然后计算容斥系数的贡献即可,这个直接用多项式分治乘去做就可以,复杂度 \(O(n\log^2 n)\)。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int rev[maxn];
void init(int n) {
for (int i = 1; i < n; i++) {
rev[i] = rev[i >> 1] >> 1;
if(i & 1)
rev[i] |= (n >> 1);
}
}
int qpow(int x, int k, int p) {
int res = 1;
while(k) {
if(k & 1)
res = res * x % p;
x = x * x % p, k >>= 1;
}
return res;
}
struct Poly {
vector<int> a;
int& operator[](int x) {
return a[x];
}
void resize(int N) {
a.resize(N);
}
int size() {
return a.size();
}
Poly() {
}
Poly(int N) {
a.resize(N);
}
void NTT(int f) {
int n = size();
for (int i = 0; i < n; i++)
if(i < rev[i])
swap(a[i], a[rev[i]]);
for (int h = 2; h <= n; h <<= 1) {
int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);
for (int i = 0; i < n; i += h) {
int nw = 1;
for (int j = i; j < i + h / 2; j++) {
int a0 = a[j], a1 = a[j + h / 2] * nw % mod;
a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;
nw = nw * d % mod;
}
}
}
if(f == -1) {
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++)
a[i] = a[i] * inv % mod;
}
}
friend Poly operator*(Poly f, Poly g) {
int len = 1, t = f.size() + g.size() - 1;
while(len < t)
len <<= 1;
init(len), f.resize(len), g.resize(len);
f.NTT(1), g.NTT(1);
for (int i = 0; i < len; i++)
f[i] = f[i] * g[i] % mod;
f.NTT(-1);
f.resize(t);
return f;
}
};
int n, a[maxn];
Poly solve(int l, int r) {
if(l == r) {
Poly f(a[l] + 1); f[0] = 1, f[a[l]] = mod - 1;
return f;
}
int mid = l + r >> 1;
return solve(l, mid) * solve(mid + 1, r);
}
signed main() {
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
Poly res = solve(2, n);
int ans = 0;
for (int i = 0; i < res.size(); i++)
ans = (ans + qpow(a[1] + i, mod - 2, mod) * a[1] % mod * res[i]) % mod;
cout << ans << endl;
return 0;
}

浙公网安备 33010602011771号