The 2024 ICPC Asia Nanjing Regional Contest - I. Bingo

题目链接

考虑题目所提到的 Bingo 数,它是满足下列两个条件之一的最小的 \(k\):

  • 存在某一行,满足 \(\max\limits_{x \in row}\{x\} \le k\)
  • 存在某一列,满足 \(\max\limits_{y \in col}\{y\} \le k\)

显然矩阵中必然存在一个数恰好是 Bingo 数 \(k\),且该数所在的第 \(i\) 行或第 \(j\) 列至少满足下列之一:

  • 该行满足 \(\max\limits_{x \in row_i}\{x\} = k\)
  • 该列满足 \(\max\limits_{y \in col_j}\{y\} = k\)

也就是说我们的 Bingo 数为所有行、列各自最大值中的最小值:

\[\min\left(\min_{1 \le i \le n}\max_{x \in row_i}\{x\}, \, \min_{1 \le j \le m}\max_{y \in col_j}\{y\}\right) \]

考虑 \(\min\text{-}\max\) 反演,记 \(R = \{row_1, \, row_2, \, \cdots, \, row_n\}\),\(C = \{col_1, \, col_2, \, \cdots, \, col_m\}\),对所有不同时为空的 \(S \subseteq R\)、\(T \subseteq C\) 求和,有:

\[\min_{l \in R \cup C}\left\{\max_{x \in l}\{x\}\right\} = \sum_{\substack{S \subseteq R, \, T \subseteq C \\ S \cup T \ne \varnothing}}(-1)^{|S| + |T| - 1}\max_{l \in S \cup T}\left\{\max_{x \in l}\{x\}\right\} \]

记 \(E(S) = \bigcup\limits_{l \in S} \{x \mid x \in l\}\) 为 \(S\) 中所有行(列)覆盖到的元素集合,则上式显然可写作(其中 \(S\)、\(T\) 不同时为空):

\[\sum_{S \subseteq R}\sum_{T \subseteq C}(-1)^{|S| + |T| - 1}\max_{x \in E(S) \cup E(T)}\{x\} \]

记 \(t = |S|m + |T|n - |S||T|\) 为 \(E(S) \cup E(T)\) 的大小,则上式 \(\max\) 部分的组合意义为选取 \(t\) 个格子中元素的最大值。由于结果需要求 \((nm)!\) 个排列的 Bingo 数之和,所以我们不关心这 \(t\) 个格子的具体位置,于是我们可以将给出的 \(\{a_1, \, a_2, \, \cdots, \, a_{nm}\}\) 从小到大排序,不妨钦定最大值是 \(a_i\),则其余 \(t - 1\) 个数有 \(\dbinom{i - 1}{t - 1}\) 种选择方法,因此在全部排列中,这 \(t\) 个格子的最大值为 \(a_i\) 的贡献为:

\[a_i\binom{i - 1}{t - 1}t!(nm - t)! \]

考虑拆为卷积的形式有:

\[a_i(i - 1)! \times \frac{1}{(i - t)!} \times t(nm - t)! \]

由于答案和 \(t\) 相关,所以我们需要将 \(i\) 消掉,不妨设 \(F(x) = \sum\limits_{i = 1}^{nm}a_i(i - 1)!\,x^i\),\(G(x) = \sum\limits_{i = 0}^{nm}\dfrac{1}{i!}x^i\)。由于下标之差 \(i - t\) 不是卷积的形式,只需要将 \(G(x)\) 的系数翻转得到 \(G'(x) = \sum\limits_{i = 0}^{nm}\dfrac{1}{(nm - i)!}x^i\),则 \(\sum\limits_{i}a_i(i - 1)!\,\dfrac{1}{(i - t)!}\) 恰好是 \([x^{nm+t}](F*G')\),再乘上系数 \(t\,(nm - t)!\) 后记为 \([x^t]h(x)\),我们有:

\[\sum_{S \subseteq R}\sum_{T \subseteq C}(-1)^{|S| + |T| - 1}[x^t]h(x) \]

由于 \([x^t]h(x)\) 只与 \(|S|\)、\(|T|\) 有关,而选取 \(i\) 行 \(j\) 列的方案数为 \(\dbinom{n}{i}\dbinom{m}{j}\),并约定 \([x^0]h(x) = 0\)(对应 \(i = j = 0\) 的空集),故答案即为:

\[\sum_{i = 0}^{n}\sum_{j = 0}^{m}(-1)^{i + j - 1}\binom{n}{i}\binom{m}{j}[x^{im+jn-ij}]h(x) \]

时间复杂度 \(O(nm\log{nm})\),可以通过。

#include<bits/stdc++.h>
using namespace std;
// Integer type shorthands used throughout the file.
#define ll long long
#define i128 __int128
#define ull unsigned long long
// clr(f, n): zero the first n ints starting at f (memset).
#define clr(f, n) memset(f, 0, sizeof(int) * (n))
// cpy(f, g, n): copy the first n ints from g to f (memcpy, so the two ranges must not overlap).
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n))
// rev(f, n): reverse the n elements starting at f in place.
// NOTE: this function-like macro shares its name with the `rev` array declared further down;
// the array is unaffected only because it is never followed directly by '('.
#define rev(f, n) reverse(f, f + (n))
// _G is the base from which NTT derives its roots of unity, mod is the modulus of all arithmetic,
// and _i plays the role of the imaginary unit in Poly_sin / Poly_cos
// (presumably a square root of -1 modulo mod — TODO confirm).
const int _G = 3, _i = 86583718, mod = 998244353, INF = 1e9;
// Base size for the global arrays; transform buffers are declared with N << 1 entries.
const int N = 3e5 + 10;

// Modular exponentiation by square-and-multiply: returns a^k modulo `mod`.
//   a : base; callers in this file pass values already reduced into [0, mod)
//   k : exponent; the default mod - 2 is what the file uses to obtain modular inverses
ll qpow(ll a, ll k = mod - 2) {
    ll acc = 1;
    for (; k; k >>= 1) {
        // Fold in the current power of the base whenever the low exponent bit is set.
        if (k & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}

// Values of qpow(x) with its default exponent (mod - 2) for x = _G, _i and 2, computed once at start-up.
const int invG = qpow(_G), invi = qpow(_i), inv2 = qpow(2);
// rev[] holds the index permutation used by NTT for the transform length cached in rev_len (see rev_init).
int rev[N << 1], rev_len;

// Fills rev[0..n) with the bit-reversal permutation for a transform of length n.
// The table is cached: nothing is recomputed when the previous call used the same n.
void rev_init(int n) {
    if (rev_len != n) {
        const int top_bit = n >> 1;
        for (int idx = 0; idx < n; ++idx) {
            // Reuse the already computed entry for idx / 2 and move idx's lowest bit to the top.
            int mirrored = rev[idx >> 1] >> 1;
            if (idx & 1) mirrored |= top_bit;
            rev[idx] = mirrored;
        }
        rev_len = n;
    }
}

// In-place number-theoretic transform of g[0..n) modulo `mod`.
//   g  : coefficient array, overwritten with the transformed values
//   op : any value with ~op != 0 (callers pass 1) runs the forward transform built on _G;
//        op == -1 runs the inverse transform built on invG and additionally scales by n^{-1}
//   n  : transform length; NOTE(review): presumably a power of two that fits the N << 1
//        buffers — every caller in this file passes one, confirm before reusing elsewhere
void NTT(int *g, int op, int n) {
    rev_init(n);
    // f is a 64-bit working copy of the data; Gk caches the powers of the current stage's root (Gk[0] stays 1).
    static ull f[N << 1], Gk[N << 1] = {1};
    // Load the input in the permuted order given by rev[].
    for (int i = 0; i < n; i ++ ) f[i] = g[rev[i]];
    for (int k = 1; k < n; k <<= 1) {
        // G1 is the root of unity of order 2k for this stage (built from invG for the inverse transform).
        int G1 = qpow(~op ? _G : invG, (mod - 1) / (k << 1));
        for (int i = 1; i < k; i ++ ) Gk[i] = Gk[i - 1] * G1 % mod;
        for (int i = 0; i < n; i += k << 1) {
            for (int j = 0; j < k; j ++ ) {
                // Butterfly: tmp is fully reduced, while f[] is only reduced lazily.
                // Adding mod before subtracting keeps the unsigned values from wrapping below zero.
                int tmp = Gk[j] * f[i | j | k] % mod;
                f[i | j | k] = f[i | j] + mod - tmp;
                f[i | j] += tmp;
            }
        }
        // Single full reduction part-way through, so the lazily accumulated values
        // stay small enough for the 64-bit products in the remaining stages.
        if (k == (1 << 10)) for (int i = 0; i < n; i ++ ) f[i] %= mod;
    }
    // Write back reduced values; the inverse transform also multiplies by the inverse of n.
    if (~op) for (int i = 0; i < n; i ++ ) g[i] = f[i] % mod;
    else {
        int invn = qpow(n);
        for (int i = 0; i < n; i ++ ) g[i] = f[i] % mod * invn % mod;
    }
}

// Pointwise product: f[i] <- f[i] * g[i] (mod `mod`) for every i in [0, n).
// f and g may be the same array (Poly_sqrt squares a buffer this way).
void px(int *f, int *g, int n) {
    int idx = 0;
    while (idx < n) {
        const long long prod = 1ll * f[idx] * g[idx];
        f[idx] = prod % mod;
        ++idx;
    }
}

// inv[i] caches the modular inverse of i for 1 <= i <= inv_len; the table is grown on demand by inv_init.
int inv[N << 1], inv_len;

// Extends the table of modular inverses so that inv[1..n] is valid.
// Entries that were already computed by an earlier call are kept, so repeated calls are cheap.
void inv_init(int n) {
    if (inv_len >= n) return;
    if (inv_len == 0) {
        // First use: seed the recurrence.
        inv[0] = 1;
        inv[1] = 1;
        inv_len = 1;
    }
    int i = inv_len + 1;
    while (i <= n) {
        // Linear recurrence: inv[i] = -(mod / i) * inv[mod % i]  (mod `mod`).
        inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
        ++i;
    }
    inv_len = n;
}

// Replaces f[0..n) with the coefficients of its formal derivative (modulo `mod`).
//   f : coefficient array, modified in place; the top coefficient becomes 0
//   n : number of coefficients; n <= 0 is a no-op (previously this wrote to f[-1])
void Poly_d(int *f, int n) {
    if (n <= 0) return;
    for (int i = 1; i < n; i ++ ) f[i - 1] = 1ll * f[i] * i % mod;
    f[n - 1] = 0;
}

// Replaces f with the coefficients of its formal integral (constant term 0).
//   f : coefficient array; f[0..n] is written, so it needs n + 1 entries
//   n : highest index written
// Relies on inv[1..n] being valid: the caller must have run inv_init(n) or larger beforehand.
void Poly_int(int *f, int n) {
    int i = n;
    while (i != 0) {
        // Walk downwards so that f[i - 1] is still the original coefficient when it is read.
        f[i] = 1ll * f[i - 1] * inv[i] % mod;
        --i;
    }
    f[0] = 0;
}

// Multiplies f by g via NTT and keeps the first `lim` coefficients of the product in f.
//   f   : left operand and output; transformed in place over the padded length
//   g   : right operand; it is copied into a scratch buffer and left untouched
//   len : operand length used to pick the transform size (smallest power of two >= 2 * len)
//   lim : number of product coefficients to keep; f is zeroed from lim up to the padded length
// Both f and g are read over the whole padded length, so they must be zero beyond their real data.
void covolution(int *f, int *g, int len, int lim) {
    static int sav[N << 1];
    int n = 1;
    while (n < (len << 1)) n <<= 1;
    clr(sav, n);
    cpy(sav, g, n);
    NTT(sav, 1, n);
    NTT(f, 1, n);
    px(f, sav, n);
    NTT(f, -1, n);
    clr(f + lim, n - lim);
    // Leave the scratch buffer clean for the next call.
    clr(sav, n);
}

// Replaces f[0..m) with its multiplicative inverse modulo x^m, built by repeatedly doubling
// the length of the known prefix.
//   f : coefficient array; the iteration starts from qpow(f[0]), so f[0] must be invertible.
//       Entries up to the next power of two >= m are read; only the first m are written back.
//   m : number of coefficients wanted
// The static scratch buffers are zeroed again before returning.
void Poly_inv(int *f, int m) {
    static int b1[N << 1], b2[N << 1], sav[N << 1];
    int n; for (n = 1; n < m; n <<= 1);
    b1[0] = qpow(f[0]);
    for (int len = 2; len <= n; len <<= 1) {
        // b2 <- b1 * f as a cyclic product of length len (b1 holds the inverse for the first len / 2 terms).
        cpy(b2, b1, len >> 1), cpy(sav, f, len);
        NTT(b2, 1, len), NTT(sav, 1, len);
        px(b2, sav, len); NTT(b2, -1, len);
        // Drop the low half of that product, then multiply what remains by b1 once more.
        clr(b2, len >> 1); cpy(sav, b1, len);
        NTT(sav, 1, len); NTT(b2, 1, len);
        px(b2, sav, len); NTT(b2, -1, len);
        // Only the upper half of b1 is updated; the lower half is left as computed in earlier rounds.
        for (int i = len >> 1; i < len; i ++ ) b1[i] = (2ll * b1[i] - b2[i] + mod) % mod;
    }
    cpy(f, b1, m), clr(b1, n), clr(b2, n), clr(sav, n);
}

// Replaces f[0..m) with a square root of f modulo x^m, doubling the known prefix each round.
//   f : coefficient array; entries up to the next power of two >= m are read
//   m : number of coefficients wanted
// The iteration starts from b1[0] = 1, so this presumably requires f[0] == 1 — TODO confirm with callers.
void Poly_sqrt(int *f, int m) {
    static int b1[N << 1], b2[N << 1];
    int n; for (n = 1; n < m; n <<= 1);
    b1[0] = 1;
    for (int len = 2; len <= n; len <<= 1) {
        // b2 <- 1 / (2 * b1) modulo x^len
        for (int i = 0; i < len >> 1; i ++ ) b2[i] = (b1[i] << 1) % mod;
        Poly_inv(b2, len);
        // b1 <- b1^2 + f, then b1 <- (b1^2 + f) / (2 * b1) truncated to len terms
        NTT(b1, 1, len); px(b1, b1, len); NTT(b1, -1, len);
        for (int i = 0; i < len; i ++ ) b1[i] = (f[i] + b1[i]) % mod;
        covolution(b1, b2, len, len);
    }
    // Copy the result out and zero the scratch buffers for the next call.
    cpy(f, b1, m); clr(b1, n << 1); clr(b2, n << 1);
}

// Polynomial division with remainder: on return f holds the quotient and g the remainder of f / g.
//   f : dividend with n coefficients; overwritten with the n - m + 1 quotient coefficients, rest zeroed
//   g : divisor with m coefficients; overwritten with the m - 1 remainder coefficients, rest zeroed.
//       g[m - 1] must be invertible, because the reversed divisor is passed to Poly_inv.
//       g is used as a convolution operand, so it must be zero beyond m over the padded transform length.
//   n, m : coefficient counts of dividend and divisor
// Fixes: (1) the static scratch buffers are now zeroed on exit like in every other routine of this
// file; before, a later call with a shorter quotient picked up stale coefficients through the cyclic
// product. (2) n < m (quotient 0, remainder f) no longer reaches memcpy with a negative length.
void Poly_div(int *f, int *g, int n, int m) {
    static int b1[N << 1], b2[N << 1];
    if (n < m) {
        // Degree of f is below that of g: f = 0 * g + f.
        cpy(g, f, n); clr(g + n, m - n); clr(f, n);
        return;
    }
    int len = n - m + 1;
    // Work on the reversed polynomials, truncated to the quotient length.
    rev(f, n); cpy(b2, f, len); rev(f, n);
    rev(g, m); cpy(b1, g, len); rev(g, m);
    // Reversed quotient = reversed f / reversed g modulo x^len; reverse it back afterwards.
    Poly_inv(b1, len); covolution(b1, b2, len, len); rev(b1, len);
    // g <- g * quotient, then remainder = f - g * quotient on the low m - 1 coefficients.
    covolution(g, b1, n, n);
    for (int i = 0; i < m - 1; i ++ ) g[i] = (f[i] - g[i] + mod) % mod;
    clr(g + m - 1, len); cpy(f, b1, len); clr(f + len, n - len);
    // Restore the all-zero invariant of the scratch buffers.
    clr(b1, len); clr(b2, len);
}

void Poly_ln(int *f, int n) {
    static int sav[N << 1];
    inv_init(n); cpy(sav, f, n);
    Poly_d(sav, n); Poly_inv(f, n);
    covolution(f, sav, n, n); Poly_int(f, n - 1);
    clr(sav, n);
}

// Replaces f[0..m) with exp(f) modulo x^m, doubling the known prefix each round.
//   f : coefficient array; entries up to the next power of two >= m are read
//   m : number of coefficients wanted
// The iteration starts from b1[0] = 1, so this presumably requires f[0] == 0 — TODO confirm with callers.
void Poly_exp(int *f, int m) {
    static int b1[N << 1], b2[N << 1];
    int n; for (n = 1; n < m; n <<= 1);
    b1[0] = 1;
    for (int len = 2; len <= n; len <<= 1) {
        // b2 <- ln(b1) modulo x^len
        cpy(b2, b1, len >> 1); Poly_ln(b2, len);
        // b2 <- 1 + f - ln(b1), then b1 <- b1 * b2 truncated to len terms
        for (int i = 0; i < len; i ++ ) b2[i] = (f[i] - b2[i] + mod) % mod;
        b2[0] = (b2[0] + 1) % mod;
        covolution(b1, b2, len, len);
    }
    // Copy the result out and zero the scratch buffers for the next call.
    cpy(f, b1, m); clr(b1, n); clr(b2, n);
}

void Poly_qpow(int *f, int n, ll k) {
    static int sav[N << 1];
    int len1 = n, len2 = 1;
    sav[0] = 1;
    while (k) {
        if (k & 1) {
            covolution(f, sav, len1 + len2 >> 1, len1 + len2 - 1);
            len1 = len1 + len2 - 1;
        }
        covolution(sav, sav, len2, (len2 << 1) - 1);
        len2 = (len2 << 1) - 1;
        k >>= 1;
    }
    clr(sav, len2);
}

// Replaces f[0..n) with f^k modulo x^n, where k is a non-negative integer given as a decimal string.
//   f : coefficient array, modified in place
//   n : number of coefficients
//   k : exponent digits; it is reduced modulo mod (k1, applied to the logarithm) and modulo
//       mod - 1 (k2, applied to the leading constant)
// Method: strip the p leading zero coefficients, normalise the first non-zero coefficient to 1,
// take exp(k * ln(.)), then restore the constant factor and the shift by p * k.
// Fix: the scan for the first non-zero coefficient is now bounded by n; an all-zero input used to
// run past the end of the array.
void Poly_pow(int *f, int n, string k) {
    int k1 = 0, k2 = 0, p = 0, c;
    while (p < n && !f[p]) p ++ ;
    if (p >= n) {
        // f is zero on [0, n): the result is 0 for a positive exponent. For exponent 0 the result
        // is 1, matching what the general path below produces for k = 0.
        bool zero_exp = true;
        for (int i = 0; k[i]; i ++ ) if (k[i] != '0') { zero_exp = false; break; }
        if (zero_exp && n > 0) f[0] = 1;
        return;
    }
    for (int i = 0; k[i]; i ++ ) {
        k1 = (10ll * k1 + k[i] - '0') % mod;
        k2 = (10ll * k2 + k[i] - '0') % (mod - 1);
        // Checked on every prefix of k: once p * k reaches n the whole result is zero.
        if (1ll * k1 * p >= n) return clr(f, n), void();
    }
    // Shift out the leading zeros and scale so that the constant term becomes 1.
    n -= p * k1; c = qpow(f[p]);
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i + p] * c % mod;
    clr(f + n, p * k1); Poly_ln(f, n);
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * k1 % mod;
    // c is the inverse of the original leading coefficient, so c^(mod - 1 - k2) is that coefficient to the k2.
    Poly_exp(f, n); c = qpow(c, mod - 1 - k2);
    // Scale back and shift right by p * k1, walking downwards so nothing is overwritten early.
    for (int i = n - 1; i >= 0; i -- ) f[p * k1 + i] = 1ll * f[i] * c % mod;
    clr(f, p * k1);
}

void Poly_sin(int *f, int n) {
    static int sav[N << 1];
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * _i % mod;
    for (int i = 0; i < n; i ++ ) sav[i] = (mod - f[i]) % mod;
    Poly_exp(f, n); Poly_exp(sav, n);
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * (f[i] - sav[i] + mod) * invi % mod * inv2 % mod;
    clr(sav, n);
}

void Poly_cos(int *f, int n) {
    static int sav[N << 1];
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * f[i] * _i % mod;
    for (int i = 0; i < n; i ++ ) sav[i] = (mod - f[i]) % mod;
    Poly_exp(f, n); Poly_exp(sav, n);
    for (int i = 0; i < n; i ++ ) f[i] = 1ll * (f[i] + sav[i]) * inv2 % mod;
    clr(sav, n);
}

// Divide-and-conquer convolution on the index range [l, r]: after the left half f[l..mid] has
// been finalised, its product with g[0..r - l] is added into the right half f[mid + 1..r].
//   f : array being built; f[0] is seeded with 1 when the recursion reaches index 0
//   g : fixed multiplier array, read only
//   l, r : inclusive bounds of the current range
// NOTE(review): this looks like the usual scheme for f[i] = sum_j f[i - j] * g[j] — confirm with callers.
void CDQ(int *f, int *g, int l, int r) {
    static int b1[N << 1], b2[N << 1];
    if (l == r) {
        if (!l) f[l] = 1;
        return;
    }
    int mid = l + r >> 1;
    CDQ(f, g, l, mid);
    // Transform length: smallest power of two strictly greater than r - l + 1.
    int n; for (n = 1; n <= r - l + 1; n <<= 1);
    // Load the operands and zero their tails. Both clears run up to index n inclusive,
    // i.e. one int past the transform length.
    cpy(b1, f + l, mid - l + 1); clr(b1 + mid - l + 1, n - (mid - l));
    cpy(b2, g, r - l + 1); clr(b2 + r - l + 1, n - (r - l));
    NTT(b1, 1, n); NTT(b2, 1, n); px(b1, b2, n); NTT(b1, -1, n);
    // Add the left half's contribution to every position of the right half.
    for (int i = mid + 1; i <= r; i ++ ) f[i] = (f[i] + b1[i - l]) % mod;
    clr(b1, n); clr(b2, n);
    CDQ(f, g, mid + 1, r);
}

// Prints f[0..n) on one line: values separated by single spaces, terminated by a newline.
// Nothing is printed when n <= 0.
void Poly_print(int *f, int n) {
    int idx = 0;
    while (idx < n) {
        const bool is_last = (idx == n - 1);
        cout << f[idx] << (is_last ? '\n' : ' ');
        ++idx;
    }
}

// Dimensions read by solve for the current test case (n * m values follow in the input).
int n, m;
// a: input values, 1-indexed and sorted by solve; F, G: operands of the convolution in solve;
// ans[t]: per-t coefficient used by the alternating sum at the end of solve (ans[0] is never written).
int a[N], F[N << 1], G[N << 1], ans[N];
// frac[i] = i! and inv_frac[i] = (i!)^{-1} modulo mod, rebuilt by solve for every test case.
ll frac[N], inv_frac[N];

// Binomial coefficient "n choose m" modulo `mod`, read from the factorial tables built in solve.
//   n, m : must not exceed the range for which frac / inv_frac have been filled
// Returns 0 when m is negative or larger than n. The m < 0 check is new: a negative m used to
// index inv_frac out of bounds.
ll C(int n, int m) {
    if (m < 0 || n < m) return 0;
    return frac[n] * inv_frac[m] % mod * inv_frac[n - m] % mod;
}

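// Contribution of a_i as the maximum of t chosen cells: a_i * (i - 1)! / (i - t)! * (nm - t)! * t
// Convolving F[i] = a_i * (i - 1)! with G[nm - (i - t)] = 1 / (i - t)! lands at index nm + t,
// so after the convolution ans[t] = F[nm + t] * (nm - t)! * t.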

void solve() {
    cin >> n >> m;
    frac[0] = inv_frac[0] = 1;
    for (int i = 1; i <= n * m; i ++ ) frac[i] = frac[i - 1] * i % mod;
    inv_frac[n * m] = qpow(frac[n * m]);
    for (int i = n * m - 1; i; i -- ) inv_frac[i] = inv_frac[i + 1] * (i + 1) % mod;
    for (int i = 1; i <= n * m; i ++ ) cin >> a[i];
    sort(a + 1, a + n * m + 1);
    for (int i = 0; i <= n * m; i ++ ) {
        if (i) F[i] = a[i] * frac[i - 1] % mod;
        G[i] = inv_frac[n * m - i];
    }
    covolution(F, G, n * m + 1, n * m * 2 + 1);
    for (int i = 1; i <= n * m; i ++ ) ans[i] = F[n * m + i] * frac[n * m - i] % mod * i % mod;
    clr(F, n * m * 2 + 1), clr(G, n * m * 2 + 1);
    ll res = 0;
    for (int i = 0; i <= n; i ++ ) {
        for (int j = 0; j <= m; j ++ ) {
            int t = i * m + j * n - i * j;
            if ((i + j) & 1) res = (res + ans[t] * C(n, i) % mod * C(m, j) % mod) % mod;
            else res = (res - ans[t] * C(n, i) % mod * C(m, j) % mod + mod) % mod;
        }
    }
    cout << res << "\n";
}

// Entry point: switches to fast iostream mode, reads the number of test cases and runs solve for each.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int T = 1;
    cin >> T;
    for (; T != 0; --T) solve();
    return 0;
}
posted @ 2025-10-21 16:12  YipChip  阅读(2)  评论(0)    收藏  举报