QOJ 10174. Lost Table
Statements
有一个 \(n \times m\) 的矩阵,现在你知道每行的最大值 \(a_i\) 和每列的最大值 \(b_i\),求有多少种满足条件的矩阵。
Solution
交换行和交换列对答案无影响,所以先对 \(a,b\) 排序。
设 \(c_{i,j} = \min(a_i, b_j)\) 即 \((i,j)\) 填数的上界。
考虑排序后 \(c\) 矩阵形成了若干个 L 形的东西,里面的数相同,假设为 \(t\):
则总共有 \(S=am+bn-ab\) 个数为 \(t\)。
注意这里的 \(n,m\) 指的是 L 形的边长而不是原题大矩阵的边长。
然后考虑容斥计算“每个位置都 \(\leq t\) 且每行每列至少有一个为 \(t\)”的方案数。
设 \(g_{i,j}\) 为至少有 \(i\) 行 \(j\) 列没有元素顶到最大值的方案数,\(f_{i,j}\) 为恰好,容斥推一下可以得到对于 \(t\) 的答案:
\[\sum\limits_{i=0}^{a} \sum\limits_{j=0}^{b} (-1)^{i+j} \binom{a}{i} \binom{b}{j} t^{im+jn-ij} (t+1)^{S-im+jn-ij}
\]
可以做到 \(O(nm \log nm)\),提出和 \(i\) 相关项:
\[\sum\limits_{i=0}^{a} (-1)^{i} \binom{a}{i} (t+1)^{S} \sum\limits_{j=0}^{b} (-1)^{j} \binom{b}{j} (\frac{t}{t+1})^{im+jn-ij}
\]
\[\sum\limits_{i=0}^{a} (-1)^{i} \binom{a}{i} (t+1)^{S} (\frac{t}{t+1})^{im} \sum\limits_{j=0}^{b} (-1)^{j} \binom{b}{j} \Big[(\frac{t}{t+1})^{n-i}\Big]^{j}
\]
发现后面的 \(\sum\) 是一个二项式定理:
\[\sum\limits_{i=0}^{a} (-1)^{i} \binom{a}{i} (t+1)^{S} (\frac{t}{t+1})^{im} \Big(1 - (\frac{t}{t+1})^{n-i}\Big)^{b}
\]
于是做到了 \(O((n+m) \log nm)\)。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5, mod = 1e9 + 7;
int n, m, a[N], b[N], arr[N << 1], tot;
inline int qmi(int a, long long k) {
int res = 1;
while (k) {
if (k & 1) res = res * 1ll * a % mod;
a = a * 1ll * a % mod, k >>= 1;
} return res;
}
int fac[N], inv[N];
inline void init() {
fac[0] = 1; for (int i = 1; i <= 200000; i++) fac[i] = fac[i - 1] * 1ll * i % mod;
inv[200000] = qmi(fac[200000], mod - 2);
for (int i = 199999; i >= 0; i--) inv[i] = inv[i + 1] * 1ll * (i + 1) % mod;
}
inline int C(int n, int m) { return fac[n] * 1ll * inv[m] % mod * 1ll * inv[n - m] % mod; }
int main() {
init();
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), a[i]--, arr[i] = a[i];
for (int i = 1; i <= m; i++) scanf("%d", &b[i]), b[i]--, arr[i + n] = b[i];
sort(a + 1, a + 1 + n), sort(b + 1, b + 1 + m);
sort(arr + 1, arr + 1 + n + m), tot = unique(arr + 1, arr + 1 + n + m) - arr - 1;
int x = 1, y = 1;
long long ans = 1;
for (int e = 1; e <= tot; e++) {
int a = 0, b = 0, t = arr[e];
int nn = n - x + 1, mm = m - y + 1;
while (x <= n && ::a[x] == t) x++, a++;
while (y <= m && ::b[y] == t) y++, b++;
long long S = a * 1ll * mm + b * 1ll * nn - a * 1ll * b, base = t * 1ll * qmi(t + 1, mod - 2) % mod;
long long res = 0;
for (int i = 0; i <= a; i++) {
long long v = qmi(mod + 1 - qmi(base, nn - i), b);
if (i & 1) v = mod - v;
(v *= C(a, i)) %= mod, (v *= qmi(t + 1, S)) %= mod, (v *= qmi(base, i * 1ll * mm)) %= mod;
(res += v) %= mod;
}
(ans *= res) %= mod;
} printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号