[刷题] 期望概率

学园祭的乐队

Description

一个吉他有 \(n\) 根弦,小\(A\)在第\(i\)天会检查第\(i\)根弦,如果发现这个弦是坏的,接下来就什么也不做了。特别地,如果小\(A\)检查完了所有的\(n\)根弦,接下来也什么都不做了

\(A\)在每一天晚上检查完弦后,会用一根新的弦等概率替换掉原来吉他上没有被替换过的弦(新的弦一定是好的,被替换的弦不一定是坏的)。

你需要计算小\(A\)检查吉他的天数期望值,答案对\(998244353\)取模。

数据范围 \(1\le n\le 1000000\)

Solution

转换一下题意:选一个全排列\(p\),每到一个点,就将 \(p_i\) 的位置置为0,遇到的第一个\(1\)的位置\(i\)就是答案。

考虑直接枚举最后不合法的弦的出现位置\(i\)

显然,对于\(i\)前面的任意一个为\(1\)的位置\(j\),都要满足\(j在全排列的位置\ <\ j\)

第一个\(j\)的概率是 \(\frac{j-1}{n}\) ,第二个\(j'\)的概率是 \(\frac{j'-1}{n-1}\) ,因为第一个\(j\)已经占了一个位置,所以分母位置是\(j-1\),以此类推……

同时,还要满足\(i在全排列的位置\ \ge \ i\),令\(i\)前面位置\(1\)的个数为\(cnt\),显然概率为 \(\frac{n-i+1}{n-cnt}\)

直接扫一遍计算即可。

时间复杂度 \(O(n)\)

Code

// Author: wlzhouzhuan
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define ull unsigned long long
#define rint register int
#define rep(i, l, r) for (rint i = l; i <= r; i++)
#define per(i, l, r) for (rint i = l; i >= r; i--)
#define mset(s, _) memset(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define mp(a, b) make_pair(a, b)

inline int read() {
  int x = 0, neg = 1; char op = getchar();
  while (!isdigit(op)) { if (op == '-') neg = -1; op = getchar(); }
  while (isdigit(op)) { x = 10 * x + op - '0'; op = getchar(); }
  return neg * x;
}
inline void print(int x) {
  if (x < 0) { putchar('-'); x = -x; }
  if (x >= 10) print(x / 10);
  putchar(x % 10 + '0');
}

const int N = 1000005;
const int mod = 998244853;
int a[N], n, ans;
int qpow(int a, int b) {
  int ret = 1;
  while (b > 0) {
    if (b & 1) ret = 1ll * ret * a % mod;
    a = 1ll * a * a % mod;
    b >>= 1;
  }
  return ret;
}
int inv[N];
void pre(int n) {
  inv[0] = 1, inv[1] = 1;
  for (int i = 2; i <= n; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
}
int main() {
  n = read(), pre(n);
  for (int i = 1; i <= n; i++) scanf("%1d", &a[i]);
  int p = 1, cnt = 0;
  for (int i = 1; i <= n; i++) {
    if (a[i] == 1) {
      ans = (ans + 1ll * p * (n - i + 1) % mod * inv[n - cnt] % mod * i % mod) % mod;
      p = 1ll * p * (i - 1 - cnt) % mod * inv[n - cnt] % mod;
      cnt++;
    }
  }
  ans = (ans + 1ll * p * n % mod) % mod;
  ans = (ans % mod + mod) % mod;
  printf("%d\n", ans);
  return 0; 
} 
posted @ 2020-04-09 18:04  wlzhouzhuan  阅读(189)  评论(0编辑  收藏  举报