一道模拟赛中的计数好题

\(\text{By hhoppitree.}\)

\(\textbf{Round 3 D. }\) 黑白图

题目大意

有一张 \(n\times m\) 的网格图,求将每个格子染成黑白的方案数,使得每一个 \(N\times M\) 的子矩形内黑色格子个数相同,对 \(998244353\) 取模。
数据范围:\(N\le M\le9,\;n,m\le10^5,\;N\mid n,M\mid m\)\(\texttt{4s/512MB}\)

思路分析

记网格图为 \(a\),其中第 \(i\) 行第 \(j\) 列的元素 \(a_{i,j}=1\) 当且仅当原网格图的第 \(i\) 行第 \(j\) 列的位置为黑格。

发现每个 \(N\times M\) 的子矩形内黑色格子个数相同等价于:

  • 对于 \(1\le i\le n-N\),都有 \(\sum\limits_{j=1}^{M}a_{i,j}=\sum\limits_{j=1}^{M}a_{i+N,j}\)
  • 对于 \(1\le j\le m-M\),都有 \(\sum\limits_{i=1}^{N}a_{i,j}=\sum\limits_{i=1}^{N}a_{i,j+M}\)
  • 对于 \(1\le i\le n-N,1\le j\le m-M\),都有 \(a_{i,j}+a_{i+N,j+M}=a_{i,j+M}+a_{i+N,j}\)

如果前 \(M\) 列的所有元素都确定了,我们可以很容易计算出此时对应的 \(a\) 的个数,具体而言,记 \(c_{0,j},c_{1,j}\) 分别为 \(1\le i\le N\) 中分别满足所有 \(a_{x,j}\left(x\equiv i\pmod N\right)\) 均相等且均为 \(i\) 的个数,则此时对应的 \(a\) 的方案数为

\[\left(\prod\limits_{i=1}^{M}\sum\limits_{j=0}^{\min(c_{0,j},c_{1,j})}\dbinom{c_{0,i}}{j}\dbinom{c_{1,i}}{j}\right)^{\frac{m}{M}-1} \]

考虑动态规划,记 \(f_{i,S}\) 为目前考虑了前 \(M\) 列的所有满足 \(x\equiv t\pmod N,t\le i\)\(a_{x,y}(1\le y\le M)\),且 \((c_{0,j},c_{1,j})\) 构成的序列为 \(S\) 的方案数,转移时考虑枚举当前 \(S\) 的哪些位置会加 \(1\)(显然一个元素对中最多只会有一个元素会被增加 \(1\)),使用容斥计算即可。

具体而言,若设 \(g_i\) 为有一个 \(\dfrac{N}{n}\)\(i\) 列的矩阵,其中每一列的元素不全相同,且每一行的元素和全部相同的方案数,\(g\) 可以使用容斥简单算出。

假设 \(S\) 中有元素被增加 \(1\) 的元素对为 \(T\),容易发现我们不关心具体加的是 \(c_{0,j}\) 还是 \(c_{1,j}\),那么我们有 \(\text{dp}\) 转移式 \(f_{i,S}\times g_{m-|T|}\to f_{i+1,\delta(S,T)}\),其中 \(\delta(S,T)\) 表示所有可能的 \(S\)\(T\) 集合内的对的任意一个元素被增加 \(1\) 过后的结果(有 \(2^{|T|}\) 个,因为可能是 \(c_{0,j}\) 或者是 \(c_{1,j}\))。

这样就可以得到一个暴力算法,不过,在进行了以下两点观察后,我们可以得到一个复杂度更优秀,且代码实现更简洁的算法:

  • \(S\) 的元素顺序不影响答案,也就是说我们可以把 \(S\) 看作是一个可重集;
  • 我们不需要分别记录 \(c_{0,j}\)\(c_{1,j}\),而只需要记录 \(S\)\(c_{0,j}+c_{1,j}\) 构成的集合 所对应的 \(f_{i,S}\),最后若 \(c_{0,j}+c_{1,j}\) 构成的集合为 \(S\),则其实际方案数为 \(f_{n,S}\prod\limits_{j=1}^{m}\dfrac{\binom{c_{0,j}+c_{i,j}}{c_{0,j}}}{2^{c_{0,j}+c_{1,j}}}\)

在代码实现中,可以在转移的时候不乘上 \(2\)(即枚举 \(c_{0,j}\)\(c_{1,j}\) 的两种情况),最终答案就会除掉 \(\prod\limits_{j=1}^{m}2^{c_{0,j}+c_{1,j}}\)

此时得到转移方程:\(f_{i,S}\times g_{m-|T|}\to f_{i,\delta(S,T)}\),为了方便实现,我们在程序实现中对于每个 \(x\),记录了 \(S\)\(x\) 的个数,在转移的时候可以直接使用 \(\text{dfs}\) 进行转移(因为每个 \(x\) 最多只会被“拔高”一级),并且使用了集合哈希来判断相等来优化复杂度。

代码呈现

#include <bits/stdc++.h>

using namespace std;

const int P = 998244353;

int ksm(int x, int y) {
    int res = 1;
    while (y) {
        if (y & 1) res = 1ll * res * x % P;
        x = 1ll * x * x % P;
        y >>= 1;
    }
    return res;
}

int f[10], g[10], C[10][10];
unordered_map<unsigned int, pair< vector<int>, int> > dp[10];
mt19937 rnd;
unsigned int val[10];
vector<int> loc;

void dfs(int id, int x, unsigned int hsh, int y, int w) {
    if (!~x) {
        if (!dp[id].count(hsh)) dp[id][hsh] = {loc, 1ll * y * f[w] % P};
        else dp[id][hsh].second = (dp[id][hsh].second + 1ll * y * f[w]) % P;
        return;
    }
    for (int i = 0; i <= loc[x]; ++i) {
        if (i) loc[x] -= i, loc[x + 1] += i;
        dfs(id, x - 1, hsh + (val[x + 1] - val[x]) * i, 1ll * y * C[loc[x] + i][i] % P, w - i);
        if (i) loc[x] += i, loc[x + 1] -= i;
    }
}

signed main() {
    freopen("map.in", "r", stdin);
    freopen("map.out", "w", stdout);
    int h, w, n, m; scanf("%d%d%d%d", &h, &w, &n, &m);
    for (int i = 1; i <= h; ++i) val[i] = rnd();
    vector<int> tz;
    for (int i = 0; i <= h; ++i) tz.push_back(!i ? w : 0);
    dp[0][0] = {tz, 1}, f[0] = 1;
    for (int i = C[0][0] = 1; i <= w; ++i) {
        for (int j = C[i][0] = 1; j <= i; ++j) C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
        for (int j = 0; j <= i; ++j) f[i] = (f[i] + ksm(C[i][j], n / h)) % P;
        for (int j = 0; j < i; ++j) f[i] = (f[i] - (__int128)f[j] * (C[i][j] << (i - j)) % P + P) % P;
    }
    for (int i = 0; i <= h; ++i) {
        for (int j = 0, k = i; j <= i; ++j, --k) {
            int sum = 0;
            for (int l = 0; l <= j && l <= k; ++l) sum += C[j][l] * C[k][l];
            g[i] = (g[i] + 1ll * C[i][j] * ksm(sum, m / w - 1)) % P;
        }
    }
    for (int i = 0; i < h; ++i) {
        for (auto [ths, ty] : dp[i]) {
            loc = ty.first;
            dfs(i + 1, h, ths, ty.second, w);
        }
    }
    int res = 0;
    for (auto [x, y] : dp[h]) {
        int mul = y.second;
        for (int i = 0; i <= h; ++i) mul = 1ll * mul * ksm(g[i], y.first[i]) % P;
        res = (res + mul) % P;
    }
    printf("%d\n", res);
    return 0;
}
posted @ 2024-10-11 15:55  hhoppitree  阅读(287)  评论(0)    收藏  举报