【洛谷P2258】子矩阵

题目大意:给定一个 N*M 的矩阵,现从 N 行中选出 R 行,M 列中选出 C 列,构成一个 R*C 子矩阵,求这个子矩阵相邻元素差的绝对值之和的最小值是多少。

题解:
发现是对行和列的组合生成,若直接暴力的话,时间复杂度为 \(O({n \choose r}{m \choose c}nm)\)
代码如下

#include <bits/stdc++.h>
using namespace std;
const int maxn=20;

int n,m,r,c,ans,mp[maxn][maxn];
vector<int> row,col;

inline void calc(){
    int ret=0;
    for(int i=0;i<r;i++)for(int j=0;j<c-1;j++)ret+=abs(mp[row[i]][col[j]]-mp[row[i]][col[j+1]]);
    for(int i=0;i<r-1;i++)for(int j=0;j<c;j++)ret+=abs(mp[row[i]][col[j]]-mp[row[i+1]][col[j]]);
    ans=min(ans,ret);	
}
void dfsc(int now){
    if(col.size()>c||col.size()+m-now+1<c)return;
    if(now==m+1){
        calc();
        return;
    }
    col.push_back(now);
    dfsc(now+1);
    col.pop_back();
    dfsc(now+1);
}
void dfsr(int now){
    if(row.size()>r||row.size()+n-now+1<r)return;
    if(now==n+1){
        dfsc(1);
        return;	
    }
    row.push_back(now);
    dfsr(now+1);
    row.pop_back();
    dfsr(now+1);
}

void read_and_parse(){
    scanf("%d%d%d%d",&n,&m,&r,&c);
    for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)scanf("%d",&mp[i][j]);
}
void solve(){
    ans=1<<30;
    dfsr(1);
    printf("%d\n",ans);
}
int main(){
    read_and_parse();
    solve();
    return 0;
}

进一步考虑,发现若枚举出了 r 行,那么对于每一列来说,可以抽象成下列问题,即:给定一个长度为 N 的序列,现从序列中选出 M 个元素组成的子序列,使得这 M 个元素中相邻两个元素差的绝对值之和最小。发现是一个 dp,对于矩阵来说,将矩阵转化成序列即可,dp 的时间复杂度为 \(O(n^3)\)。总的时间复杂度为 \(O({n \choose r}m^3)\)

代码如下

#include <bits/stdc++.h>
#define cls(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=20;

int n,m,r,c,ans,mp[maxn][maxn];
vector<int> row;
int dp[maxn][maxn],extra[maxn],cost[maxn][maxn];

inline void calc(){
    cls(dp,0x3f),cls(extra,0),cls(cost,0);
    for(int i=1;i<=m;i++)for(int j=i+1;j<=m;j++)for(auto ro:row)cost[i][j]+=abs(mp[ro][i]-mp[ro][j]);
    for(int co=1;co<=m;co++)for(int i=0;i<row.size()-1;i++)extra[co]+=abs(mp[row[i]][co]-mp[row[i+1]][co]);
    for(int i=0;i<=m;i++)dp[i][0]=0;
    for(int i=1;i<=m;i++)
        for(int j=1;j<=i;j++)
            for(int k=j-1;k<i;k++)
                dp[i][j]=min(dp[i][j],dp[k][j-1]+cost[k][i]+extra[i]);
    for(int i=1;i<=m;i++)ans=min(ans,dp[i][c]);
}
void dfs(int now){
    if(row.size()>r||row.size()+n-now+1<r)return;
    if(now==n+1){
        calc();
        return;	
    }
    row.push_back(now);
    dfs(now+1);
    row.pop_back();
    dfs(now+1);
}

void read_and_parse(){
    scanf("%d%d%d%d",&n,&m,&r,&c);
    for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)scanf("%d",&mp[i][j]);
}
void solve(){
    ans=1<<30;
    dfs(1);
    printf("%d\n",ans);
}
int main(){
    read_and_parse();
    solve();
    return 0;
}
posted @ 2019-05-09 20:57  shellpicker  阅读(234)  评论(0编辑  收藏  举报