洛谷P4859 已经没有什么好害怕的了 题解 DP + 二项式反演

题目链接:https://www.luogu.com.cn/problem/P4859

给你两个长度为 \(n\) 的数列 \(a\)\(b\),求数列 \(a\) 存在多少个排列满足:

\(a_i \gt b_i\) 的下标数量” \(-\)\(a_i \lt b_i\) 的下标数量” 恰好等于 \(k\)

即满足

\[\sum_{i=1}^n [ a_i \gt b_i ] - \sum_{i=1}^n [a_i \lt b_i] = k \]


注:虽然题目没有说,但是我一开始加了 断言 测试了一下,发现:

  1. \(n + k\) 是偶数;
  2. \(a\) 中元素各不相同;
  3. \(b\) 中元素也各不相同;
  4. 不存在 \(a_i = b_j\) 的情况。

解题思路:

首先我们设 \(x\) 表示满足 “\(a_i \gt b_i\)” 的对数。

则有 \(x - (n - x) = k\),得 \(x = \frac{n+k}{2}\)

首先我们分别给数列 \(a\) 和 数列 \(b\) 从小到大排个序。

然后我们就可以使用双指针在 \(O(n)\) 时间复杂度内求出所有的 \(d_i\),他表示 \(b_1, b_2, \ldots, b_n\) 中存在多少个 \(b_j \lt a_i\)

定义状态 \(g_{i, j}\) 表示 \(a[1..i]\) 中选了 至少 \(j\) 组(一组指的是选出一个 \(a_i\) 和一个 \(b_j\) 且满足 \(a_i \gt b_j\))的方案数,则:

  • 如果不考虑给 \(a_i\) 配对,则 \(g_{i,j}\) 的前一个状态是 \(g_{i-1, j}\)
  • 如果考虑给 \(a_i\) 配对,则可以选择的 \(b_j\) 的个数有 \(d_i\) 个,但是其中有 \(j-1\) 个已经和 \(a[1..i-1]\) 中的元素配对了,所以可以选的未使用的 \(b_j\) 个数是 \(\max(0, d_i - (j-1) )\) 个。此时用 \(\binom{\max(0, d_i - (j-1) )}{1} = \max(0, d_i - (j-1) )\) 种选择。

所以,我们得到最终的状态转移方程为:

  • \(g_{i, 0} = 1\)
  • \(i \gt 0\) 时,\(g_{i, j} = g_{i-1, j} + \max(0, d_i-(j-1)) \cdot g_{i-1,j-1}\)

但是我们最终求的是恰好 \(k\) 组的方案数,而不是至少 \(k\) 组的方案数(\(g_{n, x}\))。

如果我们定义状态 \(f_{i, j}\) 表示 \(a[1..i]\) 恰好 \(j\) 组的方案数,则 \(f_{n, x}\) 就是答案。

对于状态 \(g_{n, i}\),其中有 \(i\) 对配对了,剩余的 \(n-i\) 对可以任选,对于排列方案数 \((n-1)!\)

\[g_{n, x} \cdot (n-i)! = \sum_{i = x} ^ n f_{n, i} \]

其中,\(n\) 是固定的。所以可以使用 二项式反演得到

\[f_{n, x} = \sum_{i = x}^n (-1)^{i-x} \binom{i}{x} g_{n, i} \cdot (n-i)! \]

能够得到答案 \(f_{n, x}\)

时间复杂度 \(O(n^2)\)

示例程序:

#include <bits/stdc++.h>
using namespace std;
const long long mod = 1e9 + 9;
const int maxn = 2005;

int n, k, x, a[maxn], b[maxn], d[maxn];
long long ans, g[maxn][maxn], c[maxn][maxn], fac[maxn] = {1};

void init(int n) {
    for (int i = 1; i <= n; i++)
        fac[i] = fac[i-1] * i % mod;
    for (int i = 0; i <= n; i++) {
        for (int j = 0; j <= i; j++) {
            if (j == 0 || j == i)
                c[i][j] = 1;
            else
                c[i][j] = (c[i-1][j-1] + c[i-1][j]) % mod;
        }
    }
}

int flag(int a) {
    return a % 2 ? -1 : 1;
}

int main() {
    cin >> n >> k;
    init(n);
    x = (n + k) / 2;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n; i++) cin >> b[i];
    sort(a+1, a+n+1);
    sort(b+1, b+n+1);
    for (int i = 1, j = 0; i <= n; i++) {
        while (j < n && a[i] > b[j+1])
            j++;
        d[i] = j;
    }
    for (int i = 0; i <= n; i++) {
        g[i][0] = 1;
        for (int j = 1; j <= i; j++) {
            g[i][j] = (g[i-1][j] + max(0, d[i]-(j-1)) * g[i-1][j-1]) % mod;
        }
    }
    for (int i = x; i <= n; i++)
        ans = (ans + flag(i - x) * (c[i][x] % mod * g[n][i] % mod * fac[n-i] % mod) + mod) % mod;
    cout << ans << endl;
    return 0;
}
posted @ 2026-04-06 17:07  quanjun  阅读(1)  评论(0)    收藏  举报