LOJ2541 「PKUWC2018」猎人杀

LOJ2541 「PKUWC2018」猎人杀

题目大意

题目链接

\(n\) 个猎人,每个猎人有一个权值 \(w_i\)。每个猎人死去后,会开枪打死一个还活着的猎人。假设当前还活着的猎人为 \(\{i_1, \dots, i_m\}\),那么有 \(\frac{w_{i_{k}}}{\sum_{j = 1}^{m}w_{i_{j}}}\) 的概率向 \(i_k\) 开枪。第一枪由你打响,目标的选择方法和猎人一样。由于开枪导致的连锁反应,所有猎人最终都会死亡。请求出 \(1\) 号猎人最后一个死的的概率。答案对 \(998244353\) 取模。

数据范围 \(w_i > 0\)\(1\leq \sum_{i = 1}^{n}w_i\leq 10^5\)

前置知识

  1. \(\forall x \in[0, 1):\sum_{i = 0}^{\infty} x^i = \frac{1}{1 - x}\)
  2. 给定 \(\{a_i\}, \{b_i\}, \{c_i\}\) 序列,求关于 \(x\) 的多项式:\(\prod_{i = 1}^{n}(c_i + a_ix^{b_i})\),其中 \(n\leq B = \sum b_i\leq 10^5\)。可以用分治 FFT 做。时间复杂度 \(\mathcal{O}(B\log^2 B)\)

本题题解

初步转化

因为每个猎人死后,概率的分母也会改变,这给我们的计算带来了极大的不便。于是考虑不改变分母:也就是不管死了多少个猎人,我们开枪时仍然在 \(n\) 个猎人里进行选择,如果选到了已经死去的猎人,就假装无事发生,再选一次,直到某次选中了活着的猎人为止

结论:这样转化后,每个还活着的猎人被选中的概率不会变化

以下是简单的证明。设当前还活着的人的集合为 \(A = \{a_1, \dots, a_{|A|}\}\),已经死去的人的集合为 \(D = \{d_1, \dots, d_{|D|}\}\)。设 \(W = \sum_{i = 1}^{n}w_i\)(即所有猎人的权值之和),\(T = \sum_{i \in D} w_i\)(即所有已死的猎人的权值之和)。设原问题中,杀死当前还活着的人 \(a_k\) 的概率为 \(P_1\),转化后杀死他的概率为 \(P_2\)。显然:

\[P_1 = \frac{w_{a_k}}{W - T} \]

\(P_2\),可以枚举在选到活着的猎人之前,选了几次已死的猎人。则:

\[\begin{align} P_2 =& \sum_{i = 0}^{\infty}(\frac{T}{W})^{i}\cdot \frac{w_{a_k}}{W}\\ =& \frac{w_{a_k}}{W}\cdot \frac{1}{1 - \frac{T}{W}}\\ =& \frac{w_{a_k}}{W}\cdot \frac{W}{W - T}\\ =& P_1 \end{align} \]


容斥

接下来做这个初步转化后的问题。考虑容斥,设集合 \(S\) 里的这些人死的时间晚于 \(1\)\(S\subseteq\{2, \dots, n\}\)),其他人死的时间不限(可以早于 \(1\) 也可以晚于 \(1\))。我们要计算这种情况发生的概率,然后乘以 \((-1)^{|S|}\),累加进答案。

\(\mathrm{sum}(S) = \sum_{i\in S} w_i\)(即 \(S\) 集合里点的权值和)。则:

\[\mathrm{ans} = \sum_{S}(-1)^{|S|}\sum_{i = 0}^{\infty}(\frac{W - \mathrm{sum}(S) - w_1}{W})^{i} \cdot \frac{w_1}{W} \]

因为 \(\frac{W - \mathrm{sum}(S) - w_1}{W}\in[0, 1)\),故考虑使用公式:\(\forall x \in[0, 1):\sum_{i = 0}^{\infty} x^i = \frac{1}{1 - x}\)。则:

\[\mathrm{ans} = \sum_{S}(-1)^{|S|}\cdot \frac{w_1}{W}\cdot \frac{W}{\mathrm{sum}(S) + w_1} \]

\(f_i\) 表示 \(\mathrm{sum}(S) = i\) 的集合 \(S\)\((-1)^{|S|}\) 之和。则:

\[\mathrm{ans} = \sum_{i = 0}^{W - w_1} f_i\cdot \frac{w_1}{i + w_1} \]

问题转化为求 \(f\) 数组。


生成函数

考虑把一组 \(w\) 贡献到 \(f\) 里的过程:下标相加,系数相乘。容易联想到生成函数。

具体来说,构造 \(F(x) = \sum_{i = 0}^{\infty} f_i x^{i}\)。则:

\[F(x) = \prod_{i = 2}^{n} (1 - x^{w_i}) \]

因为 \(\sum w_i\leq 10^5\),上式可以用分治 NTT 求出。时间复杂度 \(\mathcal{O}(n\log ^2 n)\)(此处认为 \(W, n\) 同阶)。

参考代码

实际提交时建议使用输入、输出优化,详见本博客公告。

// problem: LOJ2541
#include <bits/stdc++.h>
using namespace std;

#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T &x, T y) {
    x = (y > x ? y : x);
}
template<typename T> inline void ckmin(T &x, T y) {
    x = (y < x ? y : x);
}

const int MAXN = 1e5;
const int MOD = 998244353;

inline int mod1(int x) {
    return x < MOD ? x : x - MOD;
}
inline int mod2(int x) {
    return x < 0 ? x + MOD : x;
}
inline void add(int &x, int y) {
    x = mod1(x + y);
}
inline void sub(int &x, int y) {
    x = mod2(x - y);
}
inline int pow_mod(int x, int i) {
    int y = 1;

    while (i) {
        if (i & 1)
            y = (ll)y * x % MOD;

        x = (ll)x * x % MOD;
        i >>= 1;
    }

    return y;
}

namespace PolyNTT {
int rev[MAXN * 4 + 5];
int f[MAXN * 4 + 5], g[MAXN * 4 + 5];
void NTT(int *a, int n, int flag) {
    for (int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(a[i], a[rev[i]]);

    for (int i = 1; i < n; i <<= 1) {
        int T = pow_mod(3, (MOD - 1) / (i << 1));

        if (flag == -1)
            T = pow_mod(T, MOD - 2);

        for (int j = 0; j < n; j += (i << 1)) {
            for (int k = 0, t = 1; k < i; ++k, t = (ll)t * T % MOD) {
                int Nx = a[j + k], Ny = (ll)a[i + j + k] * t % MOD;
                a[j + k] = mod1(Nx + Ny);
                a[i + j + k] = mod2(Nx - Ny);
            }
        }
    }

    if (flag == -1) {
        int invn = pow_mod(n, MOD - 2);

        for (int i = 0; i < n; ++i)
            a[i] = (ll)a[i] * invn % MOD;
    }
}
void mul(int n, int m) {
    int lim = 1, ct = 0;

    while (lim <= n + m)
        lim <<= 1, ct++;

    for (int i = n; i <= lim; ++i)
        f[i] = 0;

    for (int i = m; i <= lim; ++i)
        g[i] = 0; //clear

    for (int i = 0; i < lim; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (ct - 1));

    NTT(f, lim, 1);
    NTT(g, lim, 1);

    for (int i = 0; i < lim; ++i)
        f[i] = (ll)f[i] * g[i] % MOD;

    NTT(f, lim, -1);
}
}//namespace PolyNTT

typedef vector<int> Poly;

Poly operator*(const Poly &a, const Poly &b) {
    if (!SZ(a) && !SZ(b))
        return Poly();

    Poly res;
    res.resize(SZ(a) + SZ(b) - 1);

    if (SZ(a) <= 50 && SZ(b) <= 50) {
        for (int i = 0; i < SZ(a); ++i)
            for (int j = 0; j < SZ(b); ++j)
                add(res[i + j], (ll)a[i]*b[j] % MOD);

        return res;
    }

    for (int i = 0; i < SZ(a); ++i)
        PolyNTT::f[i] = a[i];

    for (int i = 0; i < SZ(b); ++i)
        PolyNTT::g[i] = b[i];

    PolyNTT::mul(SZ(a), SZ(b));

    for (int i = 0; i < SZ(res); ++i)
        res[i] = PolyNTT::f[i];

    return res;
}
Poly &operator*=(Poly &lhs, const Poly &rhs) {
    lhs = lhs * rhs;
    return lhs;
}
Poly operator+(const Poly &a, const Poly &b) {
    Poly res;
    res.resize(max(SZ(a), SZ(b)));

    for (int i = 0; i < SZ(res); ++i) {
        res[i] = mod1((i >= SZ(a) ? 0 : a[i]) + (i >= SZ(b) ? 0 : b[i]));
    }

    return res;
}
Poly &operator+=(Poly &lhs, const Poly &rhs) {
    lhs = lhs + rhs;
    return lhs;
}

int n, w[MAXN + 5], W;

Poly solve(int l, int r) {
    if (l == r) {
        Poly res;
        res.resize(w[l] + 1);
        res[0] = 1;
        res[w[l]] = MOD - 1;
        return res;
    }

    int mid = (l + r) >> 1;
    return solve(l, mid) * solve(mid + 1, r);
}

int main() {
    cin >> n;

    if (n == 1) {
        cout << 1 << endl;
        return 0;
    }

    for (int i = 1; i <= n; ++i) {
        cin >> w[i];
        W += w[i];
    }

    Poly f = solve(2, n);
    assert(SZ(f) == W - w[1] + 1);
    int ans = 0;

    for (int i = 0; i <= W - w[1]; ++i) {
        add(ans, (ll)f[i] * pow_mod(i + w[1], MOD - 2) % MOD);
    }

    ans = (ll)ans * w[1] % MOD;
    cout << ans << endl;
    return 0;
}
posted @ 2021-03-06 21:48  duyiblue  阅读(187)  评论(0编辑  收藏  举报