题解:HDU6339 Eat Cards, Have Fun / ZROI3223 商品检查

题解:HDU6339 Eat Cards, Have Fun / ZROI3223 商品检查

题目描述

\(n\) 张不重复整数的卡片排成一个圆圈,按顺时针顺序从 \(1\)\(n\) 编号 \(a_1, a_2, \cdots, a_n\)

一开始,Kazari 拿着一个空数组 \(A\) 站在卡片 \(1\) 处。

她会一直执行以下两个操作,直到所有卡片都被吃掉。

  • 她将当前卡片上的数字添加到 \(A\) 中,并以概率 \(p=\frac{p'} {q'}\) 吃掉当前卡片。
  • 移动到顺时针顺序中尚未被吃掉的下一张卡片。

很明显,最后得到的 \(A\) 是一个 \(n\) -排列。如果它在所有 \(n\) -排列中是第 \(k\) 个字典序最小的,那么定义它的值为 \(k\)。请帮她计算出 \(A\) 的期望值。

\(n\leq 5000\)

题解

\(q=1-p\)。令 \(t_i\) 表示 \(a_i\) 最终在 \(A_{t_i}\) 的位置(注意这是随机变量)。根据康托展开的知识,将答案写为:

\[\sum_i\sum_j[a_i>a_j]\mathbb{P}[t_i<t_j](n-t_i)! \]

考虑计算 \(\mathbb{P}[t_i<t_j](n-t_i)!\) 这一项的期望。由于有删除操作以及环的问题,我们考虑设 \(f_{L, i, j}\) 表示环长为 \(L\)\(\mathbb{P}[t_i<t_j](n-t_i)!\) 的期望,显然 \(i\neq j\),考虑计算它。

  • \(i=1\) 时:\(f_{L,i,j}=qf_{L,L,j-1}+p(L-1)!\),也就是要么没吃掉,要么吃掉了并结算。
  • \(j=1\) 时:\(f_{L,i,j}=qf_{L,i-1,L}\),也就是不能吃掉。
  • 否则:\(f_{L,i,j}=qf_{L,i-1,j-1}+pf_{L-1,i-1,j-1}\)

不用写高斯消元,按照 \(L\) 从小到大的顺序计算,并发现 \(f_{L,i,j}\) 按照 \((i-j)\bmod L\) 形成 \(L\) 个环,可以逐个环计算。复杂度 \(O(n^3)\)

打表可以发现,\(f_L\) 矩阵相同的数很多,同一行只有两种数字:当 \(i<j\) 时,\(f_{L,i,j}=f_{L,i,i+1}\);当 \(i>j\) 时,\(f_{L,i,j}=f_{L,i,i-1}\)(下标对 \(L\) 取模)。所以只需要枚举 \((i-j)\bmod L\in \{1,-1\}\) 就可以做到 \(O(n^2)\)。做法非常自然。

代码

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
template <unsigned umod>
struct modint {/*{{{*/
  static constexpr int mod = umod;
  unsigned v;
  modint() = default;
  template <class T, enable_if_t<is_integral<T>::value, int> = 0>
    modint(const T& y) : v((unsigned)(y % mod + (is_signed<T>() && y < 0 ? mod : 0))) {}
  modint& operator+=(const modint& rhs) { v += rhs.v; if (v >= umod) v -= umod; return *this; }
  modint& operator-=(const modint& rhs) { v -= rhs.v; if (v >= umod) v += umod; return *this; }
  modint& operator*=(const modint& rhs) { v = (unsigned)(1ull * v * rhs.v % umod); return *this; }
  modint& operator/=(const modint& rhs) { assert(rhs.v); return *this *= qpow(rhs, mod - 2); }
  friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }
  template <class T> friend modint qpow(modint a, T b) {
    modint r = 1;
    for (assert(b >= 0); b; b >>= 1, a *= a) if (b & 1) r *= a;
    return r;
  }
  friend int raw(const modint& self) { return self.v; }
  friend ostream& operator<<(ostream& os, const modint& self) { return os << raw(self); }
  explicit operator bool() const { return v != 0; }
  modint operator-() const { return modint(0) - *this; }
  bool operator==(const modint& rhs) const { return v == rhs.v; }
  bool operator!=(const modint& rhs) const { return v != rhs.v; }
};/*}}}*/
using mint = modint<1000000007>;
constexpr int N = 5010;
int n, a[N];
mint p, q, f[2][N][N], fac[N];
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  cin >> n >> p.v, q = 1 - p;
  fac[0] = 1;
  for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
  for (int i = 1; i <= n; i++) cin >> a[i];
  f[0][1][2] = p / (1 - q * q);
  f[0][2][1] = q * f[0][1][2];
  for (int L = 3; L <= n; L++) {
    for (int d : {1, L - 1}) {
      // set x = f[L][L][d]
      mint prek = 1, preb = 0;
      for (int i = 1; i <= L; i++) {
        int j = (d + i - 1) % L + 1;
        prek *= q, preb *= q;
        if (i == 1) preb += p * fac[L - 1];
        else if (j != 1) preb += p * f[(L - 1) & 1][i - 1][j - 1];
      }
      // k * x + b == x
      mint pre = f[L & 1][L][d] = preb / (1 - prek);
      for (int i = 1; i <= L; i++) {
        int j = (d + i - 1) % L + 1;
        f[L & 1][i][j] = pre * q;
        if (i == 1) f[L & 1][i][j] += p * fac[L - 1];
        else if (j != 1) f[L & 1][i][j] += p * f[(L - 1) & 1][i - 1][j - 1];
        pre = f[L & 1][i][j];
      }
    }
  }
  mint ans = 0;
  for (int i = 1; i <= n; i++) {
    for (int j = 1; j <= n; j++) debug("%d%c", raw(f[n & 1][i][j]), " \n"[j == n]);
    for (int j = 1; j <= n; j++) if (a[i] > a[j]) {
      if (i < j) ans += f[n & 1][i][i % n + 1];
      else ans += f[n & 1][i][(i + n - 2) % n + 1];
    }
  }
  cout << ans + 1 << endl;
  return 0;
}
posted @ 2025-05-31 10:23  caijianhong  阅读(57)  评论(0)    收藏  举报