题解:AT_abc379_g [ABC379G] Count Grid 3-coloring
前情提要:C 题喜提 \(6\) 发罚时,这题没调完了。
发现 \(\min(h,w) \leq 14\),可以状压。
很经典的套路,发现枚举点 \((i,j)\) 时,只需记下将所有同一行中在 \((i,j)\) 左边的格子和上一行中在 \((i-1,j)\) 右边的格子,也就是下图的黄色格子。
可以发现有效状态为 \(2^n\) 的,但用 \(3^n\) 也能过,而且更好写,下面是一些卡常技巧:
-
开
-O3
。 -
取模优化。
-
把常数
const
掉。 -
可以发现枚举时有一些位是固定为一些数的,只枚举其他位。
具体实现看代码。时间复杂度 \(O(hw3^{\min(h,w)})\)。
Code:
#pragma GCC optimize(3)
#include<iostream>
#include<cstdlib>
#include<ctime>
#include<cassert>
#include<vector>
#include<cmath>
#include<cstring>
#include<set>
#include<climits>
#include<queue>
#include<stack>
#include<bitset>
#include<map>
#include<algorithm>
using namespace std;
const int N = 205, MOD = 998244353, pw[] = {1,3,9,27,81,243,729,2187,6561,19683,59049,177147,531441,1594323,4782969};
int h, w, la, f[2][4782969], mp[N][N];
inline int gt(int mask, int k){
return (mask % pw[k]) / pw[k - 1];
}
inline void upd(int &x, int y){
((x += y) >= MOD)? (x -= MOD) : x;
}
inline void dp(int pos, int x){
int t = x * pw[pos - 1];
if(pos == 1){
int a = (x + 1) % 3, b = (a + 1) % 3;
for(int mask = 0; mask < pw[h]; mask += 3){
upd(f[la ^ 1][mask + t], f[la][mask + a]);
upd(f[la ^ 1][mask + t], f[la][mask + b]);
}
}
else{
int a = (x + 1) % 3 * pw[pos - 2], b = (x + 2) % 3 * pw[pos - 2];
int aa = 3 * a + a, ab = 3 * a + b, ba = 3 * b + a, bb = 3 * b + b;
for(int x = 0; x < pw[h - pos]; x++){
for(int y = 0, z = x * pw[pos]; y < pw[pos - 2]; y++, z++){
upd(f[la ^ 1][z + t + a], f[la][z + aa]);
upd(f[la ^ 1][z + t + a], f[la][z + ba]);
upd(f[la ^ 1][z + t + b], f[la][z + ab]);
upd(f[la ^ 1][z + t + b], f[la][z + bb]);
}
}
}
}
main(){
// freopen("input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
ios::sync_with_stdio(0);
cin.tie(0);
cin >> h >> w;
for(int i = 1; i <= h; i++){
for(int j = 1; j <= w; j++){
char ch;
cin >> ch;
mp[i][j] = ((ch == '?')? 0 : (ch - '0'));
}
}
if(h > w){
swap(h, w);
for(int i = 1; i <= w; i++)
for(int j = i + 1; j <= w; j++)
swap(mp[i][j], mp[j][i]);
}
for(int i = 0; i < pw[h]; i++){
f[la][i] = 1;
for(int j = 1; j <= h; j++){
if((mp[j][1] && mp[j][1] != gt(i, j) + 1) || (j > 1 && gt(i, j) == gt(i, j - 1))){
f[la][i] = 0;
break;
}
}
}
for(int i = 2; i <= w; i++){
for(int j = 1; j <= h; j++){
memset(f[la ^ 1], 0, sizeof(f[la ^ 1]));
if(mp[j][i])
dp(j, mp[j][i] - 1);
else{
dp(j, 0);
dp(j, 1);
dp(j, 2);
}
la ^= 1;
}
}
int sum = 0;
for(int i = 0; i < pw[h]; i++)
upd(sum, f[la][i]);
cout << sum << "\n";
}