GarsiaWachs

GarsiaWachs算法可以把时间复杂度压缩到O(nlogn)

具体的算法及证明可以参见《The Art of Computer Programming》第3卷6.2.2节Algorithm G和Lemma W,Lemma X,Lemma Y,Lemma Z。
现在说一个概要吧:
设一个序列是A[0..n-1],每次寻找最小的一个满足A[k-1]<=A[k+1]的k,(方便起见设A[-1]和A[n]等于正无穷大)
那么我们就把A[k]与A[k-1]合并,之后找最大的一个满足A[j]>A[k]+A[k-1]的j,把合并后的值A[k]+A[k-1]插入A[j]的后面。
有定理保证,如此操作后问题的答案不会改变。
举个例子:
186 64 35 32 103
因为35<103,所以最小的k是3,我们先把35和32删除,得到他们的和67,并向前寻找一个第一个超过67的数,把67插入到他后面
186 64(k=3,A[3]与A[2]都被删除了) 103
186 67(遇到了从右向左第一个比67大的数,我们把67插入到他后面) 64 103
186 67 64 103 (有定理保证这个序列的答案加上67就等于原序列的答案)
现在由5个数变为4个数了,继续!
186 (k=2,67和64被删除了)103
186 131(就插入在这里) 103
186 131 103
现在k=2(别忘了,设A[-1]和A[n]等于正无穷大)
234 186
420
最后的答案呢?就是各次合并的重量之和呗。420+234+131+67=852。
 
证明嘛,基本思想是通过树的最优性得到一个节点间深度的约束,之后
证明操作一次之后的解可以和原来的解一一对应,并保证节点移动之后他所在的
深度不会改变。详见TAOCP。
 
具体实现这个算法需要一点技巧,精髓在于不停快速寻找最小的k,即维护一个“2-递减序列”
朴素的实现的时间复杂度是O(n*n),但可以用一个平衡树来优化(好熟悉的优化方法),使得最终复杂度为O(nlogn)
 
 
解题思路:(这是我找到的一个关于GarsiaWachs算法的解释)
 
      1. 这类题目一开始想到是DP, 设dp[i][j]表示第i堆石子到第j堆石子合并最小得分.
 
         状态方程: dp[i][j] = min(dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]);
 
         sum[i]表示第1到第i堆石子总和. 递归记忆化搜索即可.
 
      2. 不过此题有些不一样, 1<=n<=50000范围特大, dp[50000][50000]开不到这么大数组.
 
         问题分析:
 
         (1). 假设我们只对3堆石子a,b,c进行比较, 先合并哪2堆, 使得得分最小.
 
              score1 = (a+b) + ( (a+b)+c )
 
              score2 = (b+c) + ( (b+c)+a )
 
              再次加上score1 <= score2, 化简得: a <= c, 可以得出只要a和c的关系确定,
 
              合并的顺序也确定.
 
         (2). GarsiaWachs算法, 就是基于(1)的结论实现.找出序列中满足stone[i-1] <=
 
              stone[i+1]最小的i, 合并temp = stone[i]+stone[i-1], 接着往前面找是否
 
              有满足stone[j] > temp, 把temp值插入stone[j]的后面(数组的右边). 循环
 
              这个过程一直到只剩下一堆石子结束.
 
        (3). 为什么要将temp插入stone[j]的后面, 可以理解为(1)的情况
 
             从stone[j+1]到stone[i-2]看成一个整体 stone[mid],现在stone[j],
 
             stone[mid], temp(stone[i-1]+stone[i-1]), 情况因为temp < stone[j],
 
             因此不管怎样都是stone[mid]和temp先合并, 所以讲temp值插入stone[j]
 
             的后面是不影响结果.
 
以POJ1738 为例
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
#define MAX 50005
int n;
int a[MAX];
int num, result;
void combine(int k){
 int i, j;
 int temp = a[k]+a[k-1];
 result += temp;
 for(i = k; i < num-1; ++i)
  a[i] = a[i+1];
 num--;
 for(j = k-1; j > 0 && a[j-1] < temp; --j)
  a[j] = a[j-1];
 a[j] = temp;
 while(j >= 2 && a[j] >= a[j-2]){
  int d = num-j;
  combine(j-1);
  j = num-d;
 }
}
int main(){
 int i;
 while(scanf("%d", &n) != EOF){
  if(n == 0) break;
  for(i = 0; i < n; ++i)
   scanf("%d", &a[i]);
  num = 1;
  result = 0;
  for(i = 1; i < n; ++i){
   a[num++] = a[i];
   while(num >= 3 && a[num-3] <= a[num-1])
    combine(num-2);
  }
  while(num > 1) combine(num-1);
  printf("%d\n", result);
 }
 return 0;
}

 

posted @ 2017-07-26 17:45  浅忆~  阅读(590)  评论(0)    收藏  举报