The 2024 ICPC Asia Nanjing Regional Contest - I. Bingo
考虑题目所提到的 Bingo 数 \(k\):它是满足下列条件之一的最小整数:
- 存在某一行 \(row\),满足 \(\max\limits_{x \in row}\{x\} \le k\)
- 存在某一列 \(col\),满足 \(\max\limits_{y \in col}\{y\} \le k\)
显然 Bingo 数必然恰好等于网格中的某个数,且该数是其所在行或所在列的最大值,即满足下列条件之一:
- 存在某一行 \(row_i\),满足 \(\max\limits_{x \in row_i}\{x\} = k\)
- 存在某一列 \(col_j\),满足 \(\max\limits_{y \in col_j}\{y\} = k\)
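也就是说我们的 Bingo 数为:
\[k = \min\left(\min\limits_{i = 1}^{n}\max\limits_{x \in row_i}\{x\}, \, \min\limits_{j = 1}^{m}\max\limits_{y \in col_j}\{y\}\right)\]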
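考虑 \(\min - \max\) 反演,记 \(R = \{row_1, \, row_2, \, \cdots, \, row_n\}\),\(C = \{col_1, \, col_2, \, \cdots, \, col_m\}\),有:
\[k = \min\limits_{T \in R \cup C}\max\limits_{x \in T}\{x\} = \sum\limits_{\varnothing \ne S \subseteq R \cup C}(-1)^{|S| + 1}\max\limits_{T \in S}\max\limits_{x \in T}\{x\}\]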
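设 \(E(X) = \bigcup\limits_{x \in T, \, T \in X} \{x\}\) 为集合 \(X\) 中所有行、列覆盖到的元素集合。再将所选集合拆成行集合 \(S \subseteq R\) 与列集合 \(T \subseteq C\),显然可写作:
\[k = \sum\limits_{\substack{S \subseteq R, \, T \subseteq C \\ S \cup T \ne \varnothing}}(-1)^{|S| + |T| + 1}\max\limits_{x \in E(S \cup T)}\{x\}\]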
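设 \(t = |E(S \cup T)| = |S|m + |T|n - |S||T|\),则上式 \(\max\) 部分的组合意义为:网格中某 \(t\) 个固定位置上元素的最大值。由于结果需要求全部 \((nm)!\) 个排列的 Bingo 数之和,而对固定的 \(t\) 个位置,其最大值在所有排列下之和只与 \(t\) 有关、与位置无关。于是我们可以将给出的 \(\{a_1, \, a_2, \, \cdots, \, a_{nm}\}\) 从小到大排序,不妨钦定最大值是 \(a_i\),则这 \(t\) 个数的取值有 \(\dbinom{i - 1}{t - 1}\) 种选择方法,它们在这 \(t\) 个位置上有 \(t!\) 种排列,其余 \(nm - t\) 个数有 \((nm - t)!\) 种排列。因此可以定义 \(H(t)\) 为「固定 \(t\) 个位置的最大值在所有排列下之和」,即:
\[H(t) = \sum\limits_{i = t}^{nm} a_i \dbinom{i - 1}{t - 1} \, t! \, (nm - t)!\]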
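利用 \(\dbinom{i - 1}{t - 1} t! = \dfrac{(i - 1)! \, t}{(i - t)!}\),考虑拆为卷积的形式有:
\[H(t) = t \, (nm - t)! \sum\limits_{i = t}^{nm} a_i (i - 1)! \cdot \dfrac{1}{(i - t)!}\]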
由于答案只和 \(t\) 相关,所以我们需要将 \(i\) 消掉。不妨设 \(F(x) = \sum\limits_{i = 1}^{nm}a_i(i - 1)! \, x^i\),\(G(x) = \sum\limits_{i = 0}^{nm}\dfrac{x^i}{i!}\),则上式的求和是下标之差为 \(t\) 的「差卷积」\(\sum_i [x^i]F(x) \cdot [x^{i - t}]G(x)\)。只需要将 \(G(x)\) 翻转得到 \(G'(x) = \sum\limits_{i = 0}^{nm}\dfrac{x^i}{(nm - i)!}\),该求和就会被卷到 \([x^{nm+t}](F*G')\) 上,乘上系数后即得 \(H(t)\),我们有:
\[H(t) = t \, (nm - t)! \, [x^{nm + t}]\big(F(x) \cdot G'(x)\big)\]
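考虑选取 \(i\) 行 \(j\) 列的方案数为 \(\dbinom{n}{i}\dbinom{m}{j}\),此时覆盖的元素个数为 \(t = im + jn - ij\),故答案即为:
\[\sum\limits_{i = 0}^{n}\sum\limits_{j = 0}^{m}[i + j > 0] \, (-1)^{i + j + 1}\dbinom{n}{i}\dbinom{m}{j}H(im + jn - ij)\]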
时间复杂度 \(O(nm\log{nm})\),可以通过。
#include<bits/stdc++.h>
using namespace std;
// Short aliases for wide integer types.
#define ll long long
#define i128 __int128
#define ull unsigned long long
// Zero / copy / reverse the first n elements of an array.
// NOTE(review): clr and cpy hard-code sizeof(int) -- only use them on int arrays.
// rev is function-like, so the array `rev[...]` declared below is not affected
// (the macro only expands when the name is followed by '(').
#define clr(f, n) memset(f, 0, sizeof(int) * (n))
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n))
#define rev(f, n) reverse(f, f + (n))
// _G: root used by NTT; _i: square root of -1 modulo mod (used by Poly_sin / Poly_cos);
// mod: the NTT modulus; INF: not referenced in the visible code.
const int _G = 3, _i = 86583718, mod = 998244353, INF = 1e9;
// Base array capacity; most polynomial buffers are sized N << 1.
const int N = 3e5 + 10;
// Square-and-multiply modular exponentiation: returns a^k mod `mod`.
// With the default exponent mod - 2 this is the modular inverse of a
// (Fermat's little theorem, mod being prime).
ll qpow(ll a, ll k = mod - 2) {
    ll acc = 1;
    for (; k; k >>= 1) {
        if (k & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
// Modular inverses of the NTT root _G, of _i, and of 2.
const int invG = qpow(_G), invi = qpow(_i), inv2 = qpow(2);
// Bit-reversal permutation table for NTT; rev_len is the length it was last built for.
int rev[N << 1], rev_len;
void rev_init(int n) {
    if (rev_len != n) {
        const int high_bit = n >> 1;
        for (int i = 0; i < n; i ++ ) {
            int r = rev[i >> 1] >> 1;
            if (i & 1) r |= high_bit;
            rev[i] = r;
        }
        rev_len = n;
    }
}
// Iterative in-place number-theoretic transform of g[0, n) modulo `mod`.
// n is expected to be a power of two (required by the bit-reversal table).
// op == -1: inverse transform (root invG, result multiplied by n^{-1});
// any other op (e.g. 1): forward transform with root _G.
void NTT(int *g, int op, int n) {
rev_init(n);
// f: 64-bit working copy reduced lazily; Gk[j]: j-th power of the current level's root (Gk[0] = 1).
static ull f[N << 1], Gk[N << 1] = {1};
// Load the input in bit-reversed order.
for (int i = 0; i < n; i ++ ) f[i] = g[rev[i]];
for (int k = 1; k < n; k <<= 1) {
// (2k)-th root of unity for this level (its inverse for the inverse transform).
int G1 = qpow(~op ? _G : invG, (mod - 1) / (k << 1));
for (int i = 1; i < k; i ++ ) Gk[i] = Gk[i - 1] * G1 % mod;
for (int i = 0; i < n; i += k << 1) {
for (int j = 0; j < k; j ++ ) {
int tmp = Gk[j] * f[i | j | k] % mod;
// f[i | j] is not reduced; adding mod keeps the difference non-negative.
f[i | j | k] = f[i | j] + mod - tmp;
f[i | j] += tmp;
}
}
// Each level lets entries grow by less than about one `mod`, so a single full
// reduction after the k == 1 << 10 level keeps Gk[j] * f[...] inside 64 bits.
// NOTE(review): only this one reduction is done; the headroom assumes n stays
// within the buffer sizes used here -- re-check if N is raised substantially.
if (k == (1 << 10)) for (int i = 0; i < n; i ++ ) f[i] %= mod;
}
// Reduce and store back; the inverse transform also divides by n.
if (~op) for (int i = 0; i < n; i ++ ) g[i] = f[i] % mod;
else {
int invn = qpow(n);
for (int i = 0; i < n; i ++ ) g[i] = f[i] % mod * invn % mod;
}
}
void px(int *f, int *g, int n) {
for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * g[i] % mod;
}
int inv[N << 1], inv_len;
void inv_init(int n) {
if (n <= inv_len) return;
if (!inv_len) inv[0] = inv[1] = 1, inv_len = 1;
for (int i = inv_len + 1; i <= n; i ++ ) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
inv_len = n;
}
void Poly_d(int *f, int n) {
for (int i = 1; i < n; i ++ ) f[i - 1] = 1ll * f[i] * i % mod;
f[n - 1] = 0;
}
void Poly_int(int *f, int n) {
for (int i = n; i; i -- ) f[i] = 1ll * f[i - 1] * inv[i] % mod;
f[0] = 0;
}
// Multiplies f by g and keeps the result in f, truncated to lim coefficients.
// The NTT length n is the smallest power of two >= 2 * len; callers choose len
// so that the full product fits in n coefficients (no cyclic wrap-around).
// Both f and g must already be zero from their last meaningful coefficient up to
// n, because the transform reads that far. g is copied first, so f and g may
// alias; g itself is not modified.
// On return f[lim, n) is zero (when lim < n).
void covolution(int *f, int *g, int len, int lim) {
    static int sav[N << 1];
    int n = 1;
    while (n < len << 1) n <<= 1;
    cpy(sav, g, n);   // overwrites all n entries, so no pre-clear is needed
    NTT(sav, 1, n); NTT(f, 1, n);
    px(f, sav, n); NTT(f, -1, n);
    // With lim >= n there is nothing to clear; n - lim would be negative and
    // memset would receive an enormous size_t.
    if (lim < n) clr(f + lim, n - lim);
    clr(sav, n);      // leave the static scratch buffer zeroed for the next call
}
// Power-series inverse: replaces f[0, m) with f^{-1} mod x^m (needs f[0] != 0).
// Newton iteration that doubles the precision len each round; b1 holds the
// current inverse. Reads f up to the next power of two >= m (extra terms do not
// affect the result mod x^m).
// The transforms are only `len` long (not 2 * len): wrapped-around terms land in
// the low half, which is discarded by clr(b2, len >> 1) because b1 * f == 1
// (mod x^{len/2}) already holds there.
void Poly_inv(int *f, int m) {
static int b1[N << 1], b2[N << 1], sav[N << 1];
int n; for (n = 1; n < m; n <<= 1);
b1[0] = qpow(f[0]);
for (int len = 2; len <= n; len <<= 1) {
// b2 = b1 * f (cyclic, length len)
cpy(b2, b1, len >> 1), cpy(sav, f, len);
NTT(b2, 1, len), NTT(sav, 1, len);
px(b2, sav, len); NTT(b2, -1, len);
// keep only the x^{len/2} .. x^{len-1} part of b1 * f, then multiply by b1
clr(b2, len >> 1); cpy(sav, b1, len);
NTT(sav, 1, len); NTT(b2, 1, len);
px(b2, sav, len); NTT(b2, -1, len);
// b1 <- b1 - b1 * (b1 * f - 1); only the new high half of b1 changes
for (int i = len >> 1; i < len; i ++ ) b1[i] = (2ll * b1[i] - b2[i] + mod) % mod;
}
// Copy the result out and reset the static scratch buffers for the next call.
cpy(f, b1, m), clr(b1, n), clr(b2, n), clr(sav, n);
}
// Power-series square root: replaces f[0, m) with sqrt(f) mod x^m.
// The iteration starts from b1 = 1, which presumes f[0] == 1.
// Newton step per doubling of len: b1 <- (b1^2 + f) * (2 * b1)^{-1} mod x^len.
void Poly_sqrt(int *f, int m) {
static int b1[N << 1], b2[N << 1];
int n; for (n = 1; n < m; n <<= 1);
b1[0] = 1;
for (int len = 2; len <= n; len <<= 1) {
// b2 = (2 * b1)^{-1} mod x^len
for (int i = 0; i < len >> 1; i ++ ) b2[i] = (b1[i] << 1) % mod;
Poly_inv(b2, len);
// b1 = b1^2; b1 has len/2 terms, so a length-len cyclic square does not wrap
NTT(b1, 1, len); px(b1, b1, len); NTT(b1, -1, len);
// b1 = b1^2 + f
for (int i = 0; i < len; i ++ ) b1[i] = (f[i] + b1[i]) % mod;
// b1 = (b1^2 + f) / (2 * b1) mod x^len
covolution(b1, b2, len, len);
}
// Copy the result out and reset the static scratch buffers.
cpy(f, b1, m); clr(b1, n << 1); clr(b2, n << 1);
}
// Polynomial division with remainder.
// f: dividend with n coefficients; g: divisor with m coefficients.
// On return f[0, n - m + 1) holds the quotient (rest of f[0, n) zeroed) and
// g[0, m - 1) holds the remainder (g[m - 1, n) zeroed).
// g must be zero-padded beyond index m: the final multiplication reads it up to
// the NTT length used for n coefficients.
void Poly_div(int *f, int *g, int n, int m) {
    static int b1[N << 1], b2[N << 1];
    if (n < m) {
        // deg f < deg g: the quotient is 0 and the remainder is f itself.
        cpy(g, f, n); clr(g + n, m - n);
        clr(f, n);
        return;
    }
    int len = n - m + 1;   // number of quotient coefficients
    // Reversal trick: rev(q) = rev(f) * rev(g)^{-1} mod x^len.
    rev(f, n); cpy(b2, f, len); rev(f, n);
    rev(g, m); cpy(b1, g, len); rev(g, m);
    Poly_inv(b1, len); covolution(b1, b2, len, len); rev(b1, len);
    // Remainder r = f - g * q, which has at most m - 1 coefficients.
    covolution(g, b1, n, n);
    for (int i = 0; i < m - 1; i ++ ) g[i] = (f[i] - g[i] + mod) % mod;
    clr(g + m - 1, len); cpy(f, b1, len); clr(f + len, n - len);
    // b1/b2 are static scratch buffers and only [0, len) is non-zero now. Clear
    // them so a later call with a smaller len does not feed stale coefficients
    // into covolution (they could wrap into the low result coefficients).
    clr(b1, len); clr(b2, len);
}
void Poly_ln(int *f, int n) {
static int sav[N << 1];
inv_init(n); cpy(sav, f, n);
Poly_d(sav, n); Poly_inv(f, n);
covolution(f, sav, n, n); Poly_int(f, n - 1);
clr(sav, n);
}
// Power-series exponential: replaces f[0, m) with exp(f) mod x^m.
// The iteration starts from b1 = 1, which presumes f[0] == 0.
// Newton step per doubling of len: b1 <- b1 * (1 - ln(b1) + f) mod x^len.
// Reads f up to the next power of two >= m.
void Poly_exp(int *f, int m) {
static int b1[N << 1], b2[N << 1];
int n; for (n = 1; n < m; n <<= 1);
b1[0] = 1;
for (int len = 2; len <= n; len <<= 1) {
// b2 = ln(b1) mod x^len (b1 currently has len/2 meaningful terms)
cpy(b2, b1, len >> 1); Poly_ln(b2, len);
// b2 = f - ln(b1) + 1
for (int i = 0; i < len; i ++ ) b2[i] = (f[i] - b2[i] + mod) % mod;
b2[0] = (b2[0] + 1) % mod;
// b1 = b1 * b2 mod x^len
covolution(b1, b2, len, len);
}
// Copy the result out and reset the static scratch buffers.
cpy(f, b1, m); clr(b1, n); clr(b2, n);
}
void Poly_qpow(int *f, int n, ll k) {
static int sav[N << 1];
int len1 = n, len2 = 1;
sav[0] = 1;
while (k) {
if (k & 1) {
covolution(f, sav, len1 + len2 >> 1, len1 + len2 - 1);
len1 = len1 + len2 - 1;
}
covolution(sav, sav, len2, (len2 << 1) - 1);
len2 = (len2 << 1) - 1;
k >>= 1;
}
clr(sav, len2);
}
// f <- f^k mod x^n, where the exponent k is a decimal string (may be huge).
// Uses f^k = (c x^p)^k * exp(k * ln(f / (c x^p))), with c = f[p] the lowest
// non-zero coefficient.
void Poly_pow(int *f, int n, string k) {
    if (n <= 0) return;
    // k == 0 (or an empty string): f^0 = 1. Handled up front so an all-zero f is
    // well defined and the shift below never reads past f[n - 1].
    if (k.find_first_not_of('0') == string::npos) {
        clr(f, n);
        f[0] = 1;
        return;
    }
    int k1 = 0, k2 = 0, p = 0, c;
    // Lowest non-zero coefficient; bounded so an all-zero f cannot run off the array.
    while (p < n && !f[p]) p ++ ;
    if (p == n) return;   // f == 0 and k >= 1: the result is 0, already in f
    for (int i = 0; k[i]; i ++ ) {
        k1 = (10ll * k1 + k[i] - '0') % mod;         // k mod `mod` (scales ln)
        k2 = (10ll * k2 + k[i] - '0') % (mod - 1);   // k mod (mod - 1) (Fermat exponent)
        // x^{p*k} already exceeds the precision, so the answer is 0. The prefix of k
        // passes n / p long before it could wrap modulo `mod`, so k1 is exact here.
        if (1ll * k1 * p >= n) return clr(f, n), void();
    }
    n -= p * k1; c = qpow(f[p]);
    // Normalise: f <- f / (c x^p), so f[0] == 1.
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i + p] * c % mod;
    clr(f + n, p * k1); Poly_ln(f, n);
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * k1 % mod;
    // c^{mod - 1 - k2} = f[p]^{k2} = f[p]^k (mod `mod`).
    Poly_exp(f, n); c = qpow(c, mod - 1 - k2);
    // Multiply back by f[p]^k x^{p*k}; iterate downwards so the shift is in place.
    for (int i = n - 1; i >= 0; i -- ) f[p * k1 + i] = 1ll * f[i] * c % mod;
    clr(f, p * k1);
}
void Poly_sin(int *f, int n) {
static int sav[N << 1];
for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * _i % mod;
for (int i = 0; i < n; i ++ ) sav[i] = (mod - f[i]) % mod;
Poly_exp(f, n); Poly_exp(sav, n);
for (int i = 0; i < n; i ++ ) f[i] = 1ll * (f[i] - sav[i] + mod) * invi % mod * inv2 % mod;
clr(sav, n);
}
void Poly_cos(int *f, int n) {
static int sav[N << 1];
for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * _i % mod;
for (int i = 0; i < n; i ++ ) sav[i] = (mod - f[i]) % mod;
Poly_exp(f, n); Poly_exp(sav, n);
for (int i = 0; i < n; i ++ ) f[i] = 1ll * (f[i] + sav[i]) * inv2 % mod;
clr(sav, n);
}
// Divide-and-conquer ("online") convolution on [l, r]: sets f[0] = 1 and adds
// into f[i] (0 < i <= r) the terms sum_{x < i} f[x] * g[i - x] (mod `mod`), so that
// f[i] = sum_{j=1}^{i} g[j] * f[i - j]. Presumably f[1..r] must start at zero.
// The left half is finished first; its contribution to the right half is then
// added with one cyclic NTT of length n > r - l + 1.
void CDQ(int *f, int *g, int l, int r) {
    static int b1[N << 1], b2[N << 1];
    if (l == r) {
        if (!l) f[l] = 1;
        return;
    }
    int mid = (l + r) >> 1;
    CDQ(f, g, l, mid);
    int n = 1;
    while (n <= r - l + 1) n <<= 1;
    // b1 = f[l..mid], b2 = g[0..r-l], each zero-padded exactly up to index n - 1
    // (pad length = n minus the number of copied entries).
    cpy(b1, f + l, mid - l + 1); clr(b1 + mid - l + 1, n - (mid - l + 1));
    cpy(b2, g, r - l + 1); clr(b2 + r - l + 1, n - (r - l + 1));
    NTT(b1, 1, n); NTT(b2, 1, n); px(b1, b2, n); NTT(b1, -1, n);
    // Wrapped-around terms only reach indices below mid + 1 - l, which are not read.
    for (int i = mid + 1; i <= r; i ++ ) f[i] = (f[i] + b1[i - l]) % mod;
    clr(b1, n); clr(b2, n);
    CDQ(f, g, mid + 1, r);
}
// Writes f[0 .. n - 1] to stdout on one line, space separated and newline terminated.
void Poly_print(int *f, int n) {
    for (int i = 0; i < n; i ++ ) {
        cout << f[i];
        cout << (i + 1 == n ? '\n' : ' ');
    }
}
int n, m;
int a[N], F[N << 1], G[N << 1], ans[N];
// frac[i] = i! and inv_frac[i] = (i!)^{-1} modulo `mod`; filled in by solve().
ll frac[N], inv_frac[N];
// Binomial coefficient "n choose m" modulo `mod`; zero when n < m.
ll C(int n, int m) {
    if (m > n) return 0;
    ll numerator = frac[n];
    return numerator * inv_frac[m] % mod * inv_frac[n - m] % mod;
}
// ans[t] = t * (nm - t)! * sum_{i >= t} a_i * (i - 1)! / (i - t)!
// F[i] = a_i * (i - 1)!, G[j] = 1 / (nm - j)!: F[i] * G[nm + t - i] -> (F * G)[nm + t]
// so ans[t] is read from coefficient nm + t of the product.
// Reads one test case (n, m, then n * m values) and prints the answer modulo `mod`
// via inclusion-exclusion over chosen rows / columns.
void solve() {
cin >> n >> m;
// Factorials and inverse factorials up to n * m.
frac[0] = inv_frac[0] = 1;
for (int i = 1; i <= n * m; i ++ ) frac[i] = frac[i - 1] * i % mod;
inv_frac[n * m] = qpow(frac[n * m]);
for (int i = n * m - 1; i; i -- ) inv_frac[i] = inv_frac[i + 1] * (i + 1) % mod;
for (int i = 1; i <= n * m; i ++ ) cin >> a[i];
// Sorted ascending, so a[i] is the i-th smallest value (1-based).
sort(a + 1, a + n * m + 1);
// F[0] stays 0. NOTE(review): a[i] is not reduced mod `mod` first; the product fits
// in long long as long as a[i] fits in int.
for (int i = 0; i <= n * m; i ++ ) {
if (i) F[i] = a[i] * frac[i - 1] % mod;
G[i] = inv_frac[n * m - i];
}
covolution(F, G, n * m + 1, n * m * 2 + 1);
for (int i = 1; i <= n * m; i ++ ) ans[i] = F[n * m + i] * frac[n * m - i] % mod * i % mod;
// Reset the global buffers for the next test case.
clr(F, n * m * 2 + 1), clr(G, n * m * 2 + 1);
ll res = 0;
// Choose i rows and j columns; together they cover t = i*m + j*n - i*j cells.
// Sign is (-1)^{i+j+1}; (i, j) = (0, 0) adds ans[0], which is never written (0).
for (int i = 0; i <= n; i ++ ) {
for (int j = 0; j <= m; j ++ ) {
int t = i * m + j * n - i * j;
if ((i + j) & 1) res = (res + ans[t] * C(n, i) % mod * C(m, j) % mod) % mod;
else res = (res - ans[t] * C(n, i) % mod * C(m, j) % mod + mod) % mod;
}
}
cout << res << "\n";
}
// Entry point: reads the number of test cases and solves each one.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int T = 1;
    cin >> T;
    for (; T != 0; -- T) solve();
    return 0;
}

浙公网安备 33010602011771号