洛谷P4859 已经没有什么好害怕的了 题解 DP + 二项式反演
题目链接:https://www.luogu.com.cn/problem/P4859
给你两个长度为 \(n\) 的数列 \(a\) 和 \(b\),求数列 \(a\) 存在多少个排列满足:
“\(a_i \gt b_i\) 的下标数量” \(-\) “\(a_i \lt b_i\) 的下标数量” 恰好等于 \(k\)。
即满足
注:虽然题目没有说,但是我一开始加了 断言 测试了一下,发现:
- \(n + k\) 是偶数;
- \(a\) 中元素各不相同;
- \(b\) 中元素也各不相同;
- 不存在 \(a_i = b_j\) 的情况。
解题思路:
首先我们设 \(x\) 表示满足 “\(a_i \gt b_i\)” 的对数。
则有 \(x - (n - x) = k\),得 \(x = \frac{n+k}{2}\)。
首先我们分别给数列 \(a\) 和 数列 \(b\) 从小到大排个序。
然后我们就可以使用双指针在 \(O(n)\) 时间复杂度内求出所有的 \(d_i\),他表示 \(b_1, b_2, \ldots, b_n\) 中存在多少个 \(b_j \lt a_i\)。
定义状态 \(g_{i, j}\) 表示 \(a[1..i]\) 中选了 至少 \(j\) 组(一组指的是选出一个 \(a_i\) 和一个 \(b_j\) 且满足 \(a_i \gt b_j\))的方案数,则:
- 如果不考虑给 \(a_i\) 配对,则 \(g_{i,j}\) 的前一个状态是 \(g_{i-1, j}\);
- 如果考虑给 \(a_i\) 配对,则可以选择的 \(b_j\) 的个数有 \(d_i\) 个,但是其中有 \(j-1\) 个已经和 \(a[1..i-1]\) 中的元素配对了,所以可以选的未使用的 \(b_j\) 个数是 \(\max(0, d_i - (j-1) )\) 个。此时用 \(\binom{\max(0, d_i - (j-1) )}{1} = \max(0, d_i - (j-1) )\) 种选择。
所以,我们得到最终的状态转移方程为:
- \(g_{i, 0} = 1\);
- 当 \(i \gt 0\) 时,\(g_{i, j} = g_{i-1, j} + \max(0, d_i-(j-1)) \cdot g_{i-1,j-1}\)
但是我们最终求的是恰好 \(k\) 组的方案数,而不是至少 \(k\) 组的方案数(\(g_{n, x}\))。
如果我们定义状态 \(f_{i, j}\) 表示 \(a[1..i]\) 恰好 \(j\) 组的方案数,则 \(f_{n, x}\) 就是答案。
对于状态 \(g_{n, i}\),其中有 \(i\) 对配对了,剩余的 \(n-i\) 对可以任选,对于排列方案数 \((n-1)!\)。
而
其中,\(n\) 是固定的。所以可以使用 二项式反演得到
能够得到答案 \(f_{n, x}\)。
时间复杂度 \(O(n^2)\)。
示例程序:
#include <bits/stdc++.h>
using namespace std;
const long long mod = 1e9 + 9;
const int maxn = 2005;
int n, k, x, a[maxn], b[maxn], d[maxn];
long long ans, g[maxn][maxn], c[maxn][maxn], fac[maxn] = {1};
void init(int n) {
for (int i = 1; i <= n; i++)
fac[i] = fac[i-1] * i % mod;
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= i; j++) {
if (j == 0 || j == i)
c[i][j] = 1;
else
c[i][j] = (c[i-1][j-1] + c[i-1][j]) % mod;
}
}
}
int flag(int a) {
return a % 2 ? -1 : 1;
}
int main() {
cin >> n >> k;
init(n);
x = (n + k) / 2;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i <= n; i++) cin >> b[i];
sort(a+1, a+n+1);
sort(b+1, b+n+1);
for (int i = 1, j = 0; i <= n; i++) {
while (j < n && a[i] > b[j+1])
j++;
d[i] = j;
}
for (int i = 0; i <= n; i++) {
g[i][0] = 1;
for (int j = 1; j <= i; j++) {
g[i][j] = (g[i-1][j] + max(0, d[i]-(j-1)) * g[i-1][j-1]) % mod;
}
}
for (int i = x; i <= n; i++)
ans = (ans + flag(i - x) * (c[i][x] % mod * g[n][i] % mod * fac[n-i] % mod) + mod) % mod;
cout << ans << endl;
return 0;
}
浙公网安备 33010602011771号