P9237 [蓝桥杯 2023 省 A] 像素放置

暴力

一开始还想先分析哪个点会引出最少的分支数量,然后优化顺序。

但是转念一想,既然都写暴力了,就别考虑这么复杂的问题了()

因为题目给的是九宫格类型的拓展,所以最开始写的暴搜是从左上角开始,每次扩展九宫格的搜索方式,搜索完毕之后再进行check

但是这样只有15分

稍作思考一下,便能发现是这个完全没有必要的九宫格扩展方式导致的,它只会白白地增加大量的分支数目。

然后把九宫格方式改成向上下左右四个方向扩展,果然,变成了四十分。

但就连上下左右四个方向扩展也是没有必要的——只要一个个,一行行地填过去就好了。

果然,修改之后分数来到了50分。

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <string>
#include <cmath>
#define R(x) x = read()
#define For(i, j, n) for (int i = j; i <= n; ++i)
using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

const int N = 15;

char a[N][N];
int n, m;
int g[N][N];
int lim[N][N], now[N][N];
bool vis[N][N];

void init()
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            if (a[i][j] != '_')
                lim[i][j] = a[i][j] - '0';
            else
                lim[i][j] = 100;
}

bool check()
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
        {
            if (lim[i][j] == 100)
                continue;
            if (now[i][j] != lim[i][j])
                return false;
        }
    return true;
}

int dx[4] = {-1, 1, 0, 0};
int dy[4] = {0, 0, -1, 1};

bool dfs(int x, int y, int cnt, int pat)
{
    int pre = g[x][y];
    g[x][y] = pat;
    if (pat)
    {
        for (int i = x - 1; i <= x + 1; i++)
            for (int j = y - 1; j <= y + 1; j++)
            {
                if (i < 0 || i >= n || j < 0 || j >= m)
                    continue;
                now[i][j]++;
            }
    }
    vis[x][y] = 1;
    if (cnt == 1)
    {
        if (check())
            return true;
        else
        {
            if (pat)
            {
                for (int i = x - 1; i <= x + 1; i++)
                    for (int j = y - 1; j <= y + 1; j++)
                    {
                        if (i < 0 || i >= n || j < 0 || j >= m)
                            continue;
                        now[i][j]--;
                    }
            }
            g[x][y] = pre;
            vis[x][y] = 0;
            return false;
        }
    }
    int ny = y + 1, nx = x;
    if (ny == m)
    {
        ny = 0;
        nx = x + 1;
    }
    if (dfs(nx, ny, cnt - 1, 0))
        return true;
    if (dfs(nx, ny, cnt - 1, 1))
        return true;
    if (pat)
    {
        for (int i = x - 1; i <= x + 1; i++)
            for (int j = y - 1; j <= y + 1; j++)
            {
                if (i < 0 || i >= n || j < 0 || j >= m)
                    continue;
                now[i][j]--;
            }
    }
    g[x][y] = pre;
    vis[x][y] = 0;
    return false;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++)
        scanf("%s", a[i]);
    init();
    for (int i = 0; i < 2; i++)
        if (dfs(0, 0, n * m, i))
            break;
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < m; j++)
            printf("%d", g[i][j]);
        puts("");
    }
    return 0;
}

一点点优化

做完才检查太慢了,可以尝试每次填入1之后,都检查是否合法

#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <string>
#include <cmath>
#include <string.h>
#define R(x) x = read()
#define For(i, j, n) for (int i = j; i <= n; ++i)
using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

const int N = 15;

char a[N][N];
int n, m;
int g[N][N];
int lim[N][N], now[N][N], nowB[N][N];
bool vis[N][N];

void init()
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            if (a[i][j] != '_')
                lim[i][j] = a[i][j] - '0';
            else
                lim[i][j] = 100;
}

bool check()
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
        {
            if (lim[i][j] == 100)
                continue;
            if (now[i][j] != lim[i][j])
                return false;
        }
    return true;
}

void reset(int x, int y)
{
    for(int i = x - 1; i <= x + 1; i++)
        for(int j = y - 1; j <= y + 1; j++)
        {
            if (i < 0 || i >= n || j < 0 || j >= m)
                continue;
            now[i][j] = nowB[i][j];
        }
}

bool dfs(int x, int y, int cnt, int pat)
{
    int pre = g[x][y];
    g[x][y] = pat;
    if (pat)
    {
        memcpy(nowB, now, sizeof(now));
        for (int i = x - 1; i <= x + 1; i++)
            for (int j = y - 1; j <= y + 1; j++)
            {
                if (i < 0 || i >= n || j < 0 || j >= m)
                    continue;
                now[i][j]++;
                if(now[i][j] > lim[i][j])
                {
                    reset(x, y);
                    g[x][y] = 0;
                    return false;
                }
            }
    }
    vis[x][y] = 1;
    if (cnt == 1)
    {
        if (check())
            return true;
        else
        {
            if (pat)
            {
                for (int i = x - 1; i <= x + 1; i++)
                    for (int j = y - 1; j <= y + 1; j++)
                    {
                        if (i < 0 || i >= n || j < 0 || j >= m)
                            continue;
                        now[i][j]--;
                    }
            }
            g[x][y] = pre;
            vis[x][y] = 0;
            return false;
        }
    }
    int ny = y + 1, nx = x;
    if (ny == m)
    {
        ny = 0;
        nx = x + 1;
    }
    if (dfs(nx, ny, cnt - 1, 1))
        return true;
    if (dfs(nx, ny, cnt - 1, 0))
        return true;
    if (pat)
    {
        for (int i = x - 1; i <= x + 1; i++)
            for (int j = y - 1; j <= y + 1; j++)
            {
                if (i < 0 || i >= n || j < 0 || j >= m)
                    continue;
                now[i][j]--;
            }
    }
    g[x][y] = pre;
    vis[x][y] = 0;
    return false;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++)
        scanf("%s", a[i]);
    init();
    for(int i = 1; i >= 0; i--)
        if (dfs(0, 0, n * m, i))
        {
            break;
        }
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < m; j++)
            printf("%d", g[i][j]);
        puts("");
    }
    return 0;
}

但是加入这种即时检查之后,只多过了一个点,而且还是洛谷加的额外测试点,如果在考试当中的话,分数是不会有任何变化的。

那么为什么这个即时检查会形同虚设呢?

分析之后发现,我们每次插入1之后,都会对它所在的九宫格进行检查——但是对于这个九宫格里的很多点,他们的情况甚至还没有确定下来,这个时候检查是没有意义的,那要怎么优化我们的检查方式呢?

正解

一开始在洛谷翻题解的时候,看到了插头DP这样高大上的标签,顿时望而却步,但是想清楚之后,发现这道题只是借用了插头DP的思想来进行记忆化搜索:

前面我们提到,每次插入1之后在相应的九宫格做check是没有意义的,那么我们应该在哪里check呢?

我们应该检查那些状态已经确定的点

即如下函数所示:

bool can(int x,int y,int v){
    for(int i=-1;i<=1;++i){
        for(int j=-1;j<=1;++j){
            int nx=x+i,ny=y+j;
            if(in(nx,ny) && cur[nx][ny]+v>a[nx][ny])return 0;
        }
    }
    if(in(x-1,y-1) && cur[x-1][y-1]+v<a[x-1][y-1])return 0;
    if(x==n-1 && in(x,y-1) && cur[x][y-1]+v<a[x][y-1])return 0;
    if(y==m-1 && in(x-1,y) && cur[x-1][y]+v<a[x-1][y])return 0;
    return 1;
}

对于九宫格内的点还是要进行检查,因为它们如果无法通过检查的话,就连插入操作都无法进行了

 

除了优化检查之外,我们还要加入记忆化来进一步提高效率

bitset<M>dp[N][N];

dp[i][j][st]表示当前进行到(i,j)位置,且前两行加上(i,j)左边两个点的状态为st的方案是否被搜索过

如果前面不足两行,或者左边不足两格,那么就是0(默认情况下bitset就是0)

这里用bitset的原因是int存不下

#include<bits/stdc++.h>
using namespace std;
const int S=15,N=10,M=1<<22,INF=0x3f3f3f3f;
int n,m,a[S][S],b[S][S],cur[S][S];
bitset<M>dp[N][N];
char s[S][S];
bool ok;
bool in(int x,int y){
    return 0<=x && x<n && 0<=y && y<m && a[x][y]!=INF;
}
bool can(int x,int y,int v){
    for(int i=-1;i<=1;++i){
        for(int j=-1;j<=1;++j){
            int nx=x+i,ny=y+j;
            if(in(nx,ny) && cur[nx][ny]+v>a[nx][ny])return 0;
        }
    }
    if(in(x-1,y-1) && cur[x-1][y-1]+v<a[x-1][y-1])return 0;
    if(x==n-1 && in(x,y-1) && cur[x][y-1]+v<a[x][y-1])return 0;
    if(y==m-1 && in(x-1,y) && cur[x-1][y]+v<a[x-1][y])return 0;
    return 1;
}
void dfs(int x,int y,int w){
    //cnt++;
    cout << x << " " << y << " " << w << endl;
    if(x==n && y==0){
        ok=1;
        for(int i=0;i<n;++i){
            for(int j=0;j<m;++j){
                printf("%d",b[i][j]);
            }
            puts("");
        }
        return;
    }
    if(ok)return;
    if(dp[x][y][w])return;
    dp[x][y][w]=1;
    for(int v=1;v>=0;--v){
        if(ok)return;
        if(can(x,y,v)){
            b[x][y]=v;
            if(v){
                for(int i=-1;i<=1;++i){
                    for(int j=-1;j<=1;++j){
                        int nx=x+i,ny=y+j;
                        if(!in(nx,ny))continue;
                        cur[nx][ny]++;
                    }
                }
            }
            int nw=(w<<1)|v;
            if(y==m-1)dfs(x+1,0,nw&((1<<(2*m))-1));
            else dfs(x,y+1,nw&((1<<(2*m+2))-1));
            if(v){
                for(int i=-1;i<=1;++i){
                    for(int j=-1;j<=1;++j){
                        int nx=x+i,ny=y+j;
                        if(!in(nx,ny))continue;
                        cur[nx][ny]--;
                    }
                }
            }
        }
    }
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<n;++i){
        scanf("%s",s[i]);
        for(int j=0;j<m;++j){
            a[i][j]=INF;
            if(s[i][j]!='_')a[i][j]=s[i][j]-'0';
        }
    }
    dfs(0,0,0);
    //printf("cnt:%d\n",cnt);
    return 0;
}

然后就能愉快的AC了

状压为什么要压两行?

这道题每个像素的影响范围和AcWing 1064. 小国王当中一样,都是九宫格,可为什么小国王这道题只需要压一行,但是这里却要压两行?

因为小国王那道题只有“影响范围”一个限制,我只要考虑会不会相互攻击到即可。

但是这道题像素之间在同一个九宫格内不仅会“相互影响”,而且这个影响是具体的:这个九宫格内存在一个像素的数量限制,对于被夹在中间的那一行来说,它既要向前看也要向后看,所以要压两行。

类似地,AcWing 292. 炮兵阵地这道题的影响范围有两行,但是无数量限制,所以这道题也只需要压两行。

一点点细节

if(y==m-1)dfs(x+1,0,nw&((1<<(2*m))-1));
            else dfs(x,y+1,nw&((1<<(2*m+2))-1));

这里为什么要和一个全是1的数字进行&运算?不是没有用吗

因为在我们dfs的过程中,w会一直左移,那些高位的1就一直往左移动,如果不和这个数字取余的话,高位的、已经离开两行以内范围的1就一直消不掉,就会导致我们状态转移不断积累,一直变大。

另外,“全是1的数字”这个说法本身就是不严谨的,因为这两个由1<<若干位减一得到的数字,也只有右边一串全是1,左边高位仍然是全0的

这里也有类似的用法:

for (int i = 0; i < n; ++ i) f[1 << i][0] = 0;//开局免费选一个起点(初始状态)
    for (int cur = 1, cost; cur < 1 << n; ++ cur)
        for (int pre = cur - 1 & cur; pre; pre = pre - 1 & cur)
            if (~(cost = get_cost(cur, pre)))
                for (int k = 1; k < n; ++ k)
                    f[cur][k] = min(f[cur][k], f[pre][k - 1] + cost * k);

作者:一只野生彩色铅笔
链接:https://www.acwing.com/solution/content/59439/
来源:AcWing
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

pre通过和cut取余,保证pre得到的是cur的子集。

 

最后,像是这种带有是否成功标准的dfs函数,既可以写成bool函数,也可以维护一个全局变量,来判断是否找到答案。

posted @ 2024-04-09 15:22  Gold_stein  阅读(40)  评论(0编辑  收藏  举报