luogu P2331 [SCOI2005]最大子矩阵

传送门

\[\huge\mathit{warning} \]

\[\small\text{以下说明文字高能,请心脏病,,,,,,人士谨慎观看,请未成年人在家长陪同下观看} \]


皮这一下很开心

其实是代码很丑而已

不要在意那些奇怪的变量名,和那四个布尔函数

看到\(k\)很小,\(m\leq2\),很爽有没有,设\(f_{i,j,k}\)表示第\(i\)行的二进制状态为\(j\)(0不放,1放),选了\(k\)个矩阵的最大值.转移时枚举当前放的状态,记为\(o\),然后和上一行状态作比较,如果j不等于当前状态o,并且o不为0,k就加1

观察样例,我们注意到选出的两个子矩阵是两条竖着的,而如果用上述方法,如果要选右下角的3,那么得出来最少需要3个子矩阵

继续观察,可以发现如果上一行状态为3(二进制11),且当前行为1或2,那么这连下来的一部分可以接在上面,例如\(\begin{matrix}0&1\\1&1\\1&0\end{matrix}\)以及\(\begin{matrix}0&1\\1&1\\0&1\end{matrix}\),这两种情况都至少只有2个子矩阵.

所以,转移时,如果当前状态o不是j的子集,并且o不为0,k就加1

其实还是错的,因为有这种情况\(\begin{matrix}1&1\\1&1\\0&1\end{matrix}\),这种情况子矩阵个数为2,但是上述算法会得到1

综合上述三种情况,我们可以发现如果上一行状态为3,这一行状态为1或2,如果上一行所在的1连通块中每行状态全是3,那么k是要加1的

所以,转移时,如果当前状态o不是j的子集,或者o是j子集并且o不为0并且j所在的1连通块中每行状态全是3,k就加1

这时需要多开一维,表示并且j所在的1连通块中每行状态是否全是3

好了,剩下的详见代码

对了,注意不一定要选k个非空子矩阵

#include<bits/stdc++.h>
#define LL long long
#define il inline
#define re register
#define db double
#define eps (1e-5)

using namespace std;
il LL rd()
{
    LL x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x*w;
}
int n,nn,m,kk,a[4],f[2][4][2][12];
il bool o1(int o){return o>0;}
il bool o2(int j,int o){return o==3&&(j==1||j==2)&&((o&j)==j);}
il bool o3(int j,int o){return (o&j)!=o;}
il bool o4(int j,int k,int o){return j==3&&(o==1||o==2)&&(!k);}

int main()
{
  n=rd(),m=rd(),kk=rd();nn=1<<m;
  memset(f,-63,sizeof(f));
  int O=f[0][0][0][0],inf=-23333333;
  f[0][0][0][0]=0;
  int nw=1,la=0;
  for(int i=1;i<=n;i++)
    {
      for(int j=1;j<=m;j++) a[j]=rd();
      if(m==2) a[3]=a[1]+a[2];
      for(int j=0;j<nn;j++)
        for(int k=0;k<=1;k++)
          for(int l=0;l<=kk;l++)
            {
              if(f[la][j][k][l]<=inf) continue;
              for(int o=0;o<nn;o++)
                {
                  int nk=((k&o1(o))|o2(j,o)),dl=(o3(j,o)|o4(j,k,o));
                  f[nw][o][nk][l+dl]=max(f[nw][o][nk][l+dl],f[la][j][k][l]+a[o]);
                }
              f[la][j][k][l]=O;
            }
      nw^=1,la^=1;
    }
  int ans=inf;
  for(int j=0;j<nn;j++)
    for(int k=0;k<=1;k++)
      for(int l=0;l<=kk;l++)
        ans=max(ans,f[la][j][k][l]);
  printf("%d\n",ans);
  return 0;
}


posted @ 2018-10-16 19:55  ✡smy✡  阅读(127)  评论(0编辑  收藏  举报