[ABC273G] Row Column Sums 2 题解

[ABC273G] Row Column Sums 2 Solution

更好的阅读体验戳此进入

题面

给定 $ n $,给定序列 $ R_n, C_n $,求有多少个 $ n \times n $ 的矩阵满足第 $ i $ 行的数之和为 $ R_i $,第 $ i $ 列的数之和为 $ C_i $,保证 $ 0 \le R_i, C_i, \le 2 $,答案对 $ 998244353 $ 取模。

Solution

提供一种较为好理解且转移较为好实现的做法,核心思路来自 @sssmzy。

设计状态 $ dp(i, j, k) $ 表示考虑前 $ i $ 行剩余 $ j $ 列需要再添 $ 1 \(,\) k $ 列需要再添 $ 2 $ 的方案数,考虑对于每一行可以去消除一些 $ j $ 或 $ k $,考虑转移:

对于 $ R_i = 0 $,显然无法产生贡献。

\[R_i = 0 : dp(i, j, k) \leftarrow dp(i - 1, j, k) \]

对于 $ R_i = 1 $,考虑要么将其补全一个 $ j $,要么将其去补到 $ k $ 使得减少一个 $ k $ 并多一个 $ j $。

\[R_i = 1 : dp(i, j, k) \leftarrow dp(i - 1, j + 1, k) \times (j + 1) + dp(i - 1, j - 1, k + 1) \times (k + 1) \]

对于 $ R_i = 2 $,考虑可以补全两个 $ j $,或补全一个 $ k $,或将两个 $ k $ 转化为两个 $ j $,或补全一个 $ j $ 并将一个 $ k $ 转换为 $ j $。

\[R_i = 2 : dp(i, j, k) \leftarrow dp(i - 1, j + 2, k) \times {j + 2 \choose 2} + dp(i - 1, j, k + 1) \times (k + 1) + dp(i - 1, j - 2, k + 2) \times {k + 2 \choose 2} + dp(i - 1, j, k + 1) \times (k + 1) \times j \]

此为 $ \texttt{3D0D} \(,考虑在每一层转移中,\) j + k $ 的值应为固定的,于是当我们已知 $ j $ 时 $ k $ 就是固定的,具体来说维护一个 $ \sum C_i $,然后每次转移时删去当前次贡献的 $ R_i $ 即可。于是以此可以无需枚举最后一维(同理也可以选择删去 $ i $ 或 $ j $),优化为 $ \texttt{2D0D} $,可以通过。

Tips:注意需要特判 $ \sum R_i \neq \sum C_i $ 时无解,答案为 $ 0 $。

Code

#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW void* Edge::operator new(size_t){static Edge* P = ed; return P++;}
#define ROPNEW_NODE void* Node::operator new(size_t){static Node* P = nd; return P++;}

using namespace std;

mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}

typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;

#define MOD (998244353ll)

template < typename T = int >
inline T read(void);

int N;
ll sumC(0);
ll dp[5100][5100];
int R[5100], C[5100];
int cnt1(0);

int main(){
    // freopen("in.txt", "r", stdin);
    N = read();
    for(int i = 1; i <= N; ++i)R[i] = read();
    for(int i = 1; i <= N; ++i)sumC += (C[i] = read());
    for(int i = 1; i <= N; ++i)if(C[i] == 1)++cnt1;
    dp[0][cnt1] = 1;
    for(int i = 1; i <= N; ++i){
        sumC -= R[i];
        for(int j = 0; j <= N; ++j){
            if((sumC - j) & 1)continue;
            int k = (sumC - j) >> 1;
            switch(R[i]){
                case 0:{
                    dp[i][j] = dp[i - 1][j];
                    break;
                }
                case 1:{
                    (dp[i][j] += dp[i - 1][j + 1] * (j + 1) % MOD) %= MOD;
                    if(j - 1 >= 0)(dp[i][j] += dp[i - 1][j - 1] * (k + 1) % MOD) %= MOD;
                    break;
                }
                case 2:{
                    (dp[i][j] += dp[i - 1][j + 2] * ((j + 2) * (j + 1) / 2 % MOD) % MOD) %= MOD;
                    (dp[i][j] += dp[i - 1][j] * (k + 1) % MOD) %= MOD;
                    if(j - 2 >= 0)(dp[i][j] += dp[i - 1][j - 2] * ((k + 2) * (k + 1) / 2 % MOD) % MOD) %= MOD;
                    (dp[i][j] += dp[i - 1][j] * (k + 1) % MOD * j % MOD) %= MOD;
                    break;
                }
            }
        }
    }
    printf("%lld\n", sumC ? 0 : dp[N][0]);
    fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
    return 0;
}

template < typename T >
inline T read(void){
    T ret(0);
    int flag(1);
    char c = getchar();
    while(c != '-' && !isdigit(c))c = getchar();
    if(c == '-')flag = -1, c = getchar();
    while(isdigit(c)){
        ret *= 10;
        ret += int(c - '0');
        c = getchar();
    }
    ret *= flag;
    return ret;
}

UPD

update-2023_02_20 初稿

posted @ 2023-03-05 12:45  Tsawke  阅读(57)  评论(0)    收藏  举报