【题解】【THUSC 2016】成绩单 LOJ 2292 区间dp

Prelude

快THUWC了,所以补一下以前的题。
真的是一道神题啊,网上的题解没几篇,而且还都看不懂,我做了一天才做出来。

传送到LOJ:(>人<;)


Solution

直接切入正题。
我们考虑区间dp,第一件事是离散化。
然后用\(g(i,j)\)表示消除完闭区间\([i,j]\)的最小费用。
然后呢?怎么转移?exm???
这时候会有一个非常自然的想法。
计算\(g(i,j)\)的时候,我们枚举两个数\(l,r\),然后保留下值在闭区间\([l,r]\)之内的所有数,先消除掉其他的数字,就只剩\([l,r]\)之内的数字了,再一次性消除掉她们。
时间复杂度\(O(n^5)\),但是显然是错的。
错在哪里呢?大概是错在下面这种情况,我懒得构造具体的反例了。
对于一组数字\(abcabca\),我们可以先消除掉中间的\(a\),再消除掉\(bcbc\),最后再消除掉\(aa\),在我们的dp里面似乎并没有考虑到这种情况。
因为\(aa\)是最后消除掉的,因此如果我们选择保留\(a\)的话,会保留下来所有的\(a\)
我们太仁慈了,保留下来了\([l,r]\)之间的所有的数字,其实不一定要保留所有数字。
怎么办呢?
脑洞大开!
我们用\(f(i, j, l, r)\)表示,消除完在闭区间\([i,j]\)之内的,除了值在\([l,r]\)之间的所有数字。
注意,在\([l,r]\)之间的数字,可以消除,也可以不消除。
然后显然有这个东西:

\(\Large g(i, j) = \min f(i, j, l, r)\)

实际上就是枚举\(l,r\)嘛。
然后我们考虑\(f(i, j, l, r)\)如何转移。
当闭区间\([i,j]\)内元素全部在\([l,r]\)之间的时候,显然\(f(i, j, l, r)=0\)
当闭区间\([i,j]\)内元素全部不在\([l,r]\)之间的时候,显然\(f(i, j, l, r)=g(i, j)\)
\(f(i, j, l, r)=g(i, j)\)似乎构成了循环依赖?
那么,我们枚举\(l,r\)的时候,必须保证区间\([i,j]\)内存在至少一个数字在\([l,r]\)内,这样就不会有循环依赖了。
解决了\(f(i, j, l, r)\)的边界问题,接下来看如何转移。
像普通的区间dp一样,我们枚举区间的分裂点\(k\),然后把区间\([i,j]\)分裂成\([i,k]\)\([k+1,j]\)两部分,递归做下去。
有式子:
\(\Large f(i, j, l, r) = \min f(i, k, l, r) + f(k+1, j, l, r)\)

感受一下,感觉似乎是能处理各种情况的?
但是实际上和刚刚的做法没有任何区别。
因为对于状态\(f(i, j, l, r)\),我们仍然保留了\([l,r]\)之间的所有数字,仍然是那么的仁慈。
我们需要加一种暴力斩掉所有数字的情况。
有式子:
\(\Large f(i, j, l, r) = \min g(i, k) + f(k+1, j, l, r)\)

仔细感受一下,这两个\(f(i, j, l, r)\)的转移式结合起来之后,就可以处理掉所有情况了!
时间复杂度仍然是\(O(n^5)\)
实现采用记忆化搜索,效果棒棒哒~
真是一道神题啊。。。


Code

#include <cstring>
#include <algorithm>
#include <cstdio>
#include <iostream>

using namespace std;
const int N = 52;
const int W = 1010;
const int INF = 0x3f3f3f3f;
int _w;

int bmin( int &a, int b ) {
    return a = b < a ? b : a;
}

int n, a, b, w[N];
int vis[W], num[N], m;
int f[N][N][N][N], g[N][N];
int F( int, int, int, int );
int G( int, int );

void discrete() {
    for( int i = 1; i <= n; ++i )
        vis[w[i]] = 1;
    m = 1;
    for( int i = 1; i < W; ++i )
        if( vis[i] )
            vis[i] = m, num[m++] = i;
    --m;
    for( int i = 1; i <= n; ++i )
        w[i] = vis[w[i]];
}

bool contain( int i, int j, int l, int r ) {
    for( int p = i; p <= j; ++p )
        if( w[p] >= l && w[p] <= r )
            return true;
    return false;
}

bool all( int i, int j, int l, int r ) {
    for( int p = i; p <= j; ++p )
        if( w[p] < l || w[p] > r )
            return false;
    return true;
}

int F( int i, int j, int l, int r ) {
    int &now = f[i][j][l][r];
    if( now != -1 ) return now;
    if( all(i, j, l, r) ) return now = 0;
    if( !contain(i, j, l, r) ) return now = G(i, j);
    now = INF;
    for( int k = i; k < j; ++k ) {
        bmin( now, F(i, k, l, r) + F(k+1, j, l, r) );
        bmin( now, G(i, k) + F(k+1, j, l, r) );
    }
    // printf( "f[%d][%d][%d][%d] = %d\n", i, j, l, r, now );
    return now;
}

int G( int i, int j ) {
    int &now = g[i][j];
    if( now != -1 ) return now;
    now = INF;
    for( int l = 1; l <= m; ++l )
        for( int r = l; r <= m; ++r )
            if( contain(i, j, l, r) ) {
                int u = num[l], v = num[r];
                bmin( now, F(i, j, l, r) + (v-u)*(v-u)*b + a );
            }
    // printf( "g[%d][%d] = %d\n", i, j, now );
    return now;
}

int main() {
    cin >> n >> a >> b;
    for( int i = 1; i <= n; ++i )
        cin >> w[i];
    discrete();
    memset(f, -1, sizeof f);
    memset(g, -1, sizeof g);
    printf( "%d\n", G(1, n) );
    return 0;
}
posted @ 2018-01-19 12:43 mlystdcall 阅读(...) 评论(...) 编辑 收藏