GarsiaWachs
GarsiaWachs算法可以把时间复杂度压缩到O(nlogn)
具体的算法及证明可以参见《The Art of Computer Programming》第3卷6.2.2节Algorithm G和Lemma W,Lemma X,Lemma Y,Lemma Z。
现在说一个概要吧:
设一个序列是A[0..n-1],每次寻找最小的一个满足A[k-1]<=A[k+1]的k,(方便起见设A[-1]和A[n]等于正无穷大)
那么我们就把A[k]与A[k-1]合并,之后找最大的一个满足A[j]>A[k]+A[k-1]的j,把合并后的值A[k]+A[k-1]插入A[j]的后面。
有定理保证,如此操作后问题的答案不会改变。
举个例子:
186 64 35 32 103
因为35<103,所以最小的k是3,我们先把35和32删除,得到他们的和67,并向前寻找一个第一个超过67的数,把67插入到他后面
186 64(k=3,A[3]与A[2]都被删除了) 103
186 67(遇到了从右向左第一个比67大的数,我们把67插入到他后面) 64 103
186 67 64 103 (有定理保证这个序列的答案加上67就等于原序列的答案)
现在由5个数变为4个数了,继续!
186 (k=2,67和64被删除了)103
186 131(就插入在这里) 103
186 131 103
现在k=2(别忘了,设A[-1]和A[n]等于正无穷大)
234 186
420
最后的答案呢?就是各次合并的重量之和呗。420+234+131+67=852。
证明嘛,基本思想是通过树的最优性得到一个节点间深度的约束,之后
证明操作一次之后的解可以和原来的解一一对应,并保证节点移动之后他所在的
深度不会改变。详见TAOCP。
具体实现这个算法需要一点技巧,精髓在于不停快速寻找最小的k,即维护一个“2-递减序列”
朴素的实现的时间复杂度是O(n*n),但可以用一个平衡树来优化(好熟悉的优化方法),使得最终复杂度为O(nlogn)
解题思路:(这是我找到的一个关于GarsiaWachs算法的解释)
1. 这类题目一开始想到是DP, 设dp[i][j]表示第i堆石子到第j堆石子合并最小得分.
状态方程: dp[i][j] = min(dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]);
sum[i]表示第1到第i堆石子总和. 递归记忆化搜索即可.
2. 不过此题有些不一样, 1<=n<=50000范围特大, dp[50000][50000]开不到这么大数组.
问题分析:
(1). 假设我们只对3堆石子a,b,c进行比较, 先合并哪2堆, 使得得分最小.
score1 = (a+b) + ( (a+b)+c )
score2 = (b+c) + ( (b+c)+a )
再次加上score1 <= score2, 化简得: a <= c, 可以得出只要a和c的关系确定,
合并的顺序也确定.
(2). GarsiaWachs算法, 就是基于(1)的结论实现.找出序列中满足stone[i-1] <=
stone[i+1]最小的i, 合并temp = stone[i]+stone[i-1], 接着往前面找是否
有满足stone[j] > temp, 把temp值插入stone[j]的后面(数组的右边). 循环
这个过程一直到只剩下一堆石子结束.
(3). 为什么要将temp插入stone[j]的后面, 可以理解为(1)的情况
从stone[j+1]到stone[i-2]看成一个整体 stone[mid],现在stone[j],
stone[mid], temp(stone[i-1]+stone[i-1]), 情况因为temp < stone[j],
因此不管怎样都是stone[mid]和temp先合并, 所以讲temp值插入stone[j]
的后面是不影响结果.
以POJ1738 为例
#include <cstdio> #include <iostream> #include <cstring> using namespace std; #define MAX 50005 int n; int a[MAX]; int num, result; void combine(int k){ int i, j; int temp = a[k]+a[k-1]; result += temp; for(i = k; i < num-1; ++i) a[i] = a[i+1]; num--; for(j = k-1; j > 0 && a[j-1] < temp; --j) a[j] = a[j-1]; a[j] = temp; while(j >= 2 && a[j] >= a[j-2]){ int d = num-j; combine(j-1); j = num-d; } } int main(){ int i; while(scanf("%d", &n) != EOF){ if(n == 0) break; for(i = 0; i < n; ++i) scanf("%d", &a[i]); num = 1; result = 0; for(i = 1; i < n; ++i){ a[num++] = a[i]; while(num >= 3 && a[num-3] <= a[num-1]) combine(num-2); } while(num > 1) combine(num-1); printf("%d\n", result); } return 0; }
浙公网安备 33010602011771号