【日常训练】取数问题
Description
给出一个长度为 \(n\) 的序列 \(a\),接下来会有 \(m\) 次询问。
每次询问会给出一个区间 \([l, r]\) 和一个数 \(x\),你的任务如下。
-
给出一种取数的方式:
- 从区间 \([1, r - l + 1]\) 等概率地选取一个数 \(K\)。
- 从区间 \([l, r]\) 内等概率地选取 \(K\) 个数。
-
你需要回答:选取的所有数都大于等于 \(x\) 的概率。
答案对 \(998244353\) 取模。
数据范围:\(1 \leq n, m \leq 10^6\),\(1 \leq a_i, x \leq 10^5\)。
时空限制:\(1000 \ \text{ms} / 512 \ \text{MiB}\)。
Solution
对于任意一组询问 \((l, r, x)\),我们记区间长度为 \(L\),区间中大于等于 \(x\) 的数有 \(H\) 个。
那么答案为:
将组合数拆开,化简,再配成组合数:
重点在于后半段,把它写开就是:
观察到第一项组合数的下指标为 \(0\),那么可以把上指标调成 \(L - H + 1\),答案就可以变化为:
然后,由大家都知道的 \(\tbinom{n}{m} = \tbinom{n - 1}{m} + \tbinom{n - 1}{m - 1}\),答案就可以变化为:
然后继续:
继续继续继续继续继续 ......,最后得到:
然后原式就变化为:
先预处理出 \(1 \sim n\) 在模 \(998244353\) 意义下的逆元。如果知道了 \(L\) 和 \(H\) 就可以 \(\mathcal{O}(1)\) 算出答案。
于是,对于每一组询问 \((l, r, x)\),就是要求出区间 \([l, r]\) 内有多少个数大于等于 \(x\)。
在线的话就直接主席树;离线的话就直接从大到小加数,用 BIT 维护。
时间复杂度 \(\mathcal{O}((n + m) \log n)\),非常的优秀。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
inline int read() {
int x = 0, f = 1; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -f; s = getchar(); }
while (s >= '0' && s <= '9') { x = x * 10 + s - '0'; s = getchar(); }
return x * f;
}
int power(int a, int b, int p) {
int ans = 1;
for (; b; b >>= 1) {
if (b & 1) ans = 1ll * ans * a % p;
a = 1ll * a * a % p;
}
return ans;
}
const int N = 1000100;
const int SIZE = 100100;
const int mod = 998244353;
int n, m;
int a[N];
int inv[N];
struct range {
int l, r, id;
range() {}
range(int A, int B, int C) : l(A), r(B), id(C) {}
};
vector<int> H[SIZE];
vector<range> attend[SIZE];
int c[N];
void add(int x, int val) {
for (; x <= n; x += x & -x) c[x] += val;
}
int ask(int x) {
int ans = 0;
for (; x; x -= x & -x) ans += c[x];
return ans;
}
int cur[N];
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i ++)
inv[i] = power(i, mod - 2, mod);
for (int i = 1; i <= n; i ++)
a[i] = read();
for (int i = 1; i <= n; i ++)
H[a[i]].push_back(i);
for (int i = 1; i <= m; i ++) {
int l = read(), r = read(), x = read();
attend[x].push_back((range) { l, r, i });
}
for (int x = 100000; x >= 1; x --) {
for (int i = 0; i < (int)H[x].size(); i ++) {
int pos = H[x][i];
add(pos, 1);
}
for (int i = 0; i < (int)attend[x].size(); i ++) {
range G = attend[x][i];
int l = G.l, r = G.r;
int cnt = ask(r) - ask(l - 1),
len = r - l + 1;
int val = cnt;
val = 1ll * val * inv[len] % mod;
val = 1ll * val * inv[len - cnt + 1] % mod;
cur[G.id] = val;
}
}
for (int i = 1; i <= m; i ++)
printf("%d\n", cur[i]);
return 0;
}
此题严重暴露了我数学水平低下的弱点。