模拟赛T4 分析
题目概述
随机 \(2n\) 个数,值域为 \([0,m]\),求前 \(n\) 个数比后 \(n\) 个数大的概率(对质数 \(P\) 取模),其中 \(10^8\leq P\leq 10^9\)。
数据范围:\(1\leq n,m,T\leq 2000\)。
分析
好好玩。
显然可以转化为计数题目。
赛时想了一个 \(\mathcal{O}(Tn^2m)\) 做法。
首先我们不难想到设 \(f_{i,j}\) 表示前 \(i\) 个数和为 \(j\) 的方案。转移是简单的,求答案也是简单的:
两者都可以用前缀和优化。
正解肯定不是这么乱搞的。
首先可以分成三种情况:
- 前面 \(n\) 个小于后面 \(n\) 个。
- 前面 \(n\) 个等于后面 \(n\) 个。
- 前面 \(n\) 个大于后面 \(n\) 个。
我们发现第一个和第三个其实是一样的。
于是我们的答案就是总共方案减去等于的方案再除以 \(2\)。
我们发现等于的方案很难去做,很难刻画。
考虑做一些精妙的转变:
将后 \(n\) 个数取反,相当于前面的数的值域为 \([0,m]\),后面的为 \([-m,0]\),等于就变成了相加为 \(0\)。
再将后 \(n\) 个数每个数加上 \(m\),相当于前面的数的值域为 \([0,m]\),后面的数也是,相加为 \(0\) 就变成了相加为 \(nm\)。
我们发现这个是好做的,至于怎么想到的,确实很难,不过灵光一现还是有可能的。
我们现在假设去掉这个限制(经典容斥片头曲),那么相当于在 \(2n\) 个非负数中得到的和为 \(nm\) 的方案。
现在转化成了这个问题:
有 \(2n\) 个非负整数 \(x_i\),且需要满足 \(x_i\leq m\),那么请问有多少组 \((x_1,\dots,x_{2n})\) 的解满足 \(\sum_{i=1}^{2n} x_i=nm\)。可以去看我对经典例题的题解:https://www.cnblogs.com/high-sky/p/19151402
听说这个问题 Oi-Wiki 上是有的,我们推导一遍加深印象。
考虑没有限制的情况。
那么其实就是插板分组求每组的个数一一对应 \(x_i\),方案是 \(C_{m+2n-1}^{2n-1}\)。
怎么理解呢?
如果是正整数的情况直接 \(C_{m-1}^{2n}\) 就可以了,但是现在有 \(0\) 的情况——即两个板可以插在同一个位置上,于是多 \(2n\) 个位置就行了。
现在考虑有限制的情况。
那么我们直接快乐容斥就行了。
我们可以先算出其中不满足限制的至少是 \(0\) 个的,然后再减去不满足限制至少是 \(1\) 个的……以此类推(很像二项式反演)。
那么我们现在要考虑的是不满足限制至少是 \(k\) 个的答案怎么计算。
首先,不满足的限制为:\(x_i>m\)。
那么也就是说 \(x_i\) 至少为 \(m+1\)。
一个巧妙的地方在于:因为我是至少 \(k\) 个不满足限制,所以说,我先把我选出来的不合法的数全部令为 \(m+1\),即使可能的实际值比这个大,我也可以暂且先分给其他满足条件的位置,这样虽说可能变成不满足限制的了,但是我定义的毕竟是至少嘛,所以说不影响。(这个可以当作一个trick,很牛)
所以说最后我们要求的便是:
代码
时间复杂度 \(\mathcal{O}(Tn)\).
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#define int long long
#define N 5000005
using namespace std;
int mod;
int qpow(int a,int b) {
int res = 1;
while(b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int jc[N],inv[N];
int C(int a,int b) {
if (a < 0 || b < 0 || a < b) return 0;
return jc[a] * inv[b] % mod * inv[a - b] % mod;
}
signed main() {
// freopen("pr.in","r",stdin),freopen("pr.out","w",stdout);
cin >> mod;
jc[0] = jc[1] = inv[0] = inv[1] = 1;
for (int i = 2;i < N;i ++) jc[i] = jc[i - 1] * i % mod,inv[i] = (mod - mod / i) * inv[mod % i] % mod;
for (int i = 2;i < N;i ++) inv[i] = inv[i - 1] * inv[i] % mod;
// cout << C(5,3) <<'f';
int T;
cin >> T;
int n,m;
int inv2 = mod / 2 + 1;
for (;T--;) {
scanf("%lld%lld",&n,&m);
int ans = 0;
for (int i = 0,t = 1;i <= 2 * n;i ++,t = -t)
ans = (ans + (t * C(2 * n,i) % mod * C(n * m - (m + 1) * i + 2 * n - 1,2 * n - 1) % mod + mod) % mod) % mod;
printf("%lld\n",((qpow(m + 1,2 * n) - ans + mod) % mod) * inv2 % mod * qpow(qpow(m + 1,2 * n),mod - 2) % mod);
}
return 0;
}

浙公网安备 33010602011771号