mskitten

又惘又怠

dp方法论——由矩阵相乘问题学习dp解题思路

前篇戳:dp入门——由分杆问题认识动态规划

导语

刷过一些算法题,就会十分珍惜“方法论”这种东西。Leetcode上只有题目、讨论和答案,没有方法论。往往答案看起来十分切中要害,但是从看题目到得到思路的那一段,就是绕不过去。楼主有段时间曾把这个过程归结于智商和灵感的结合,直到有天为了搞懂Leetcode上一位老兄的题型总结,花两天时间学习了回溯法,突然有种惊为天人的感觉——原来真正掌握一个算法是应该触类旁通的,而不是将题中一个细节换掉就又成了新题……

掌握方法论绝对是一种很爽的感觉。看起来好像很花费时间,其实是一种“因为慢,所以快”的方法。以前可能你学习一个dp题目要大半天;当你花了半个周时间,学会了dp的套路,你会发现,有些medium的dp题甚至不需要半个小时就能做完,而且从头到尾不需提示,全靠自己!

方法论

那么,怎么从一个看起来毫无头绪的问题出发,找到解题的思路并用dp将问题解出来呢?本文以矩阵相乘问题为例,给出dp问题的一般解题思路。

当然,按照思路解题的前提是你已经知道这道题要用dp去解,如何确定一个问题可以用dp去解,则是下一篇要讨论的话题。

下面就是动态规划的一般解题思路:

  1. 分析最优解的特征。
  2. 递归地定义最优解的值。
  3. 计算最优解的值。
  4. 根据计算好的信息构造最优解。

看起来非常抽象是吧?在这里不需要完全理解。等你看完全文再回来,保你会有不一样的感受。

矩阵相乘问题

问题

这是一个看起来可能有点抽象的数学问题,但请你耐心往下看。当你看完解法时,你会惊异于动态规划的魔力。

题目:给出一个由n个矩阵组成的矩阵链<A1,A2,...,An>,矩阵Ai的秩为pi-1×pi。将A1A2...An这个乘积全括号化,使得计算这个乘积所需要的的标量乘法最少。

全括号化是以一种递归的形式定义的:

一个全括号化的乘积只有两种可能:一是一个单个矩阵;二是两个全括号化的乘积的乘积。

天啦也太绕了,举个例子吧。对于矩阵链<A1,A2,A3,A4>的乘积,共有五种全括号化的方法:

(A1(A2(A3A4))),

(A1((A2A3)A4)),

((A1A2)(A3A4)),

(((A1A2)A3)A4),

((A1(A2A3))A4)

我们知道矩阵乘法是满足结合律的,所以以上五个式子的乘积相等,但是它们的运算时间是否相等呢?

矩阵乘法的运算时间

我们知道,矩阵乘法的定义是:

两个互相兼容的矩阵A,B可以相乘。互相兼容是指A的列数与B的行数相等。假如A是一个p×q的矩阵,而B是一个q×r的矩阵,则乘积C是一个p×r的矩阵且有

cij = ∑ aik·bkj, k = 1,...,q.

由于要对C中的每一个元素进行计算(共q·r个元素),而每次运算要做q次乘法,所以总的运算时间为pqr。

来看看让乘积中的不同因子结合对运算时间有什么影响。假设我们有 <A1,A2,A3>这个矩阵链,三个矩阵的秩分别为10×100, 100×5和5×50。则

  • ((A1A2)A3)的运算时间为10×100×5+10×5×50=7500;
  • (A1(A2A3))的运算时间为100×5×50+10×100×50=75000。

按照不同的顺序做矩阵乘法,所需要的乘法次数竟相差10倍。

初步分析

按照惯例,我们来感受一下穷举的算法复杂度。

假设有一个长度为n的矩阵链,我们通过遍历所有的全括号化的可能性来解题。设全括号化的可能性数目为P(n)。当n为1时,矩阵链只有一个矩阵,符合全括号化的定义;当n>=2时,全括号化后为两个矩阵的乘积,即((...)(...))的形式。用递归的思路去分析,则中间两个括号的分界位置有n-1种可能,如下面竖线所示

A1|A2|A3|...|An

当分界线将矩阵链分为长度为k和n-k的两个子矩阵链时,全括号化可能性为P(k)P(n-k)。我们对所有的k值求和,就得出给整个矩阵链全括号化的数目:

P(n) = ∑ P(k)P(n-k), k=1...n-1   (n>=2)

这是一个卡塔兰数(Catalan Number),它的增长速率为Ω(4n/n3/2),它的渐进值为Ω(2n)

对渐进值还不太熟,如果有小伙伴明白“增长速率”和“渐进值”之间的关系,欢迎指教。

总的来说,如果对这个题目使用穷举法,算法复杂度是指数的。后面我们分析了dp的算法复杂度,再来比较。

用dp方法论解题

算法的学习永远没有“手把手”这一说。如果你在认真学习这篇文章,希望你能做到比你看到的小节思路提前一点。比如,在看第一步前,先对这个题目有一点大致思路,明白让自己迷茫的点在哪里;看第x步前,对第x步的内容在心中有一个猜测。这样做比起完全放弃思考,只是跟着文章的思路走,收获会大很多。

第一步:分析最优解的特征

这一步的精髓是分析最优子解如何构成最优解

在上一节中已经提到,对于n>=2的情况,全括号化后为((chain_1)(chain_2))的形式。这样,问题自然而然地分成了两个子问题:求前后两个子括号中的最优解。

假设对于某种特定的分割(即chain_1chain_2之间的分界线位置固定),chain_1的秩为m×p,其内部的标量乘法数目为x;chain_2的秩为p×n,其内部的标量乘法数目为y。则整个矩阵链的乘法次数为x+y+mpn。由于m,p,n是固定的,我们需要让x和y为最小值从而使整个矩阵链的乘法次数最小。即,对于某种特定的分割,两个子括号中的最优解构成整个问题的最优解的一个选项

总结来说,我们将矩阵乘积简略地看成两个子矩阵的乘积,这两个子矩阵的分界有n-1种可能。对每一种可能,问题被分割成两个子问题,即求左右两个子矩阵链的最优解。如果遍历这n-1种可能并选出最好的一个,那就是整个问题的最优解。

第二步:递归地定义最优解的值

第二步非常关键,是我们将前后思路打通的一步。

第一步中提出了一个比较简单的思路,即把矩阵链分割成左右两个子矩阵链。既然有了这个初步思路,我们就来涂鸦一番,看看这个思路是否可行。

对于递归性的问题,一个很好的方法是画递归树,这样会使得问题看起来比较具象,而且也会暴露一些算法上的问题,比如重叠子树等。画递归树的时候,最好举一个实际的例子。这里我们假设有一个长度为4的矩阵链<A1,A2,A3,A4>,简单地画一下它的子问题分割:

 

上图中的数字表示子矩阵链的长度,根为4,即初始矩阵链;它可以分为1+3,2+2,3+1三种情况,这三种情况又可以各自细分。

这里暴露了一个问题,请看图中的两个涂色的子树。两个子树的节点数字是一样的。但是左边这个子树的根节点3代表的是A2A3A4这个乘积;而右边这个代表的是A1A2A3这个乘积。由于A1,A2,A3,A4四个矩阵的秩是未知的,它们很可能不相同,则A1A2A3A2A3A4的最优解也很有可能不同。换言之,它们并不是同一个子问题,它们的子子树也并不相同。

这个问题意味着我们对子问题的定义不够严谨——子问题不能只用长度这个变量来确定。也就是说,如果在bottom-up的dp中用一个数组记录子问题的值,那么这个数组应该是一个二维数组。子问题不仅应该由子矩阵链的长度确定,还要加上起始index这样的信息。

为了更通用一些,我们不用起始index+长度,而选用起始index+结束index的定义方法,这是二维dp的惯用套路,在许多字符串和数组有关的问题中都有用到。

设用一个二位矩阵dp[][]存取子问题的解。定义dp[i][j](1<=i<=j<=n)的值为Ai...Aj的最小乘法次数。则按照以上的思路,可以把Ai...Aj再递归细分为子问题Ai...AkAk+1...Aj(i<=k<j),则Ai...Aj的最优解值为两个子问题最优解的和+两个子矩阵链相乘的乘法次数。即有

i==j时,dp[i][j] = 0;

i <j时,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p为各个矩阵的秩,见题目一节)

到此为止,最关键的一步顺利完成啦(楼主写得好累,击掌╭(○`∀´○)╯╰(○'◡'○)╮)。在这一步中,我们递归地定义了子问题最优解的值,完成了算法最核心的设计部分。在后面两步中,我们只要把上面这两个式子翻译成代码,再注意一些实现细节就可以了。

第三步:计算最优解的值

细节一

从第二步顺理成章,我们会在一个二维数组里记录子问题的解。但是按照什么顺序去填这个二维数组是个问题。

还是举例子,在<A1,A2,A3,A4>这个矩阵链中,我们会有一个5×5的二维数组,随便挑选dp[1][4]这个元素举例。根据第二步中的状态转移方程,有

dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}

省略号表示我们此处不需关注pi-1pkpj这一项,只需要看这个格子对其它格子的依赖是什么样子。

由上图可以看出,要计算某一个元素(粉色边框),我们需要其左边下面的元素(同样深度的蓝色表示一组数据)。

所以,我们的遍历方向是从下到上,从左到右

细节二

细心的读者可能注意到还有一个问题,就是我们一直在求“最优解的值”,也就是“最小的乘法次数”,可是题目中要求的是“最优解”,也就是“加括号的方式”。

这两者并不矛盾,专注于求解前者可以让我们先思考相对简单的问题,通常在求解前者的过程中,我们也找出了后者,只是没有将它记录下来。

在此题中,我们可以选择用一个同样的二维矩阵s[][]来记录后者,其中s[i][j]中记录Ai...Aj的分割分界线k。

代码

 1     int matrixChain(int[] p){
 2         int n = p.length - 1; //number of matrices
 3         int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n]
 4         int[][] s = new int[n + 1][n + 1];    //for storing of k
 5         for(int[] row : dp)
 6             Arrays.fill(row, Integer.MAX_VALUE);
 7 
 8         for(int i = 1; i <= n; i++)
 9             dp[i][i] = 0;    //dp[i][j] = 0 when i == j
10         
11         for(int i = n; i >= 1; i--)
12             for(int j = i; j <= n; j++){
13                 if(i == j){
14                     dp[i][j] = 0;
15                 }else{
16                     for(int k = i; k < j; k++){
17                         int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j];
18                         if(count < dp[i][j]){
19                             dp[i][j] = count; //record optimal solution value
20                             s[i][j] = k;      //record splitting point k
21                         }
22                     }
23                 }
24             }
25         return dp[1][n];
26     }

运行一个例子:

即输入的数组p为{30,35,15,5,10,20,25}。

如果在return之前打印出dp[][]和s[][]的值,结果为:

      

从左图可看出最优解为dp[1][6] = 15,125,即最少可以进行一万五千多次乘法。右图记录了对于每一个[i,j]决定的子矩阵链如何进行括号分割。

顺便分享一个ArrayPrinter的util,可以直接用,能打印出上图那样的二维int数组。

 1 public class ArrayPrinter {
 2     public static void print(int[] arr){
 3         printReplacing(false, arr, 0,"");
 4     }
 5     
 6     public static void print(int[][] matrix){
 7         printReplacing(false, matrix, 0,"");
 8     }
 9     
10     public static void printReplacing(int[] arr, int before, String after){
11         printReplacing(true, arr, before, after);
12     }
13     
14     public static void printReplacing(int[][] matrix, int before, String after){
15         printReplacing(true, matrix, before, after);
16     }
17     
18     /*--------------------------private utils-------------------------------*/
19     
20     private static void printReplacing(boolean replace, int[] arr, int before, String after){
21         int maxLen = maxLength(arr);
22         if(replace){
23             for(int i : arr)
24                 print(((i==before)?after:number(i)), maxLen);
25         }else{
26             for(int i : arr)
27                 print(number(i), maxLen);
28         }
29         print("\n", maxLen);
30     }
31     
32     public static void printReplacing(boolean replace, int[][] matrix, int before, String after){
33         int maxLen = maxLength(matrix);
34         if(replace){
35             for(int[] row : matrix){
36                 for(int i : row)
37                     print(((i==before)?after:number(i)), maxLen);
38                 print("\n", maxLen);
39             }
40         }else{
41             for(int[] row : matrix){
42                 for(int i : row)
43                     print(number(i), maxLen);
44                 print("\n", maxLen);
45             }
46         }
47     }
48 
49     private static int maxLength(int[] arr){
50         int maxLen = 0;
51         for(int aint : arr)
52             maxLen = Math.max(Integer.toString(aint).length(), maxLen);
53         return maxLen;
54     }
55     
56     private static int maxLength(int[][] matrix){
57         int maxLen = 0;
58         for(int row[] : matrix)
59             maxLen = Math.max(maxLength(row), maxLen);
60         return maxLen;
61     }
62     
63     //actual printing 
64     private static void print(String s, int length){
65         System.out.print(String.format("%1$"+(length+1)+"s", s));
66     }
67     
68     //formatting of number
69     private static String number(int i){
70         return NumberFormat.getNumberInstance(Locale.US).format(i);
71     } 
72 }
ArrayPrinter

使用方法:

1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/");
2 ArrayPrinter.print(s);

第四步:根据计算好的信息构造最优解

还差一步就大功告成。这一步我们要拿着上一步计算出的矩阵s把最终的全括号矩阵乘积打印出来。递归打印即可。

 1     private void printParenthesis(int[][] s, int i, int j) {
 2         if(i == j)
 3             print("A"+i);
 4         else{
 5             print("(");
 6             printParenthesis(s, i, s[i][j]);
 7             printParenthesis(s, s[i][j]+1, j);
 8             print(")");
 9         }
10     }

打印结果:

复杂度

前面说过,穷举法的复杂度大概是O(2n)。在以上的dp算法中,主算法需要填满一个(n+1)×(n+1)的二维数组的上半部分,每填一个元素需要一个长度为j-i的循环,可通过这个思路对j-i进行求和(i=0...n, j=i...n),也可以通过大概估算得到时间复杂度为O(n3),远好于穷举法。

空间复杂度主要由二维数组决定,为O(n2)。

总结

本文主要介绍了解一个dp问题的思路。

dp问题一般有两个显著特点,这一点下一篇会详细讲述:

  • 问题的最优解由子问题的最优解构成
  • 子问题互相重叠

也再复习一下解题的四个步骤,看你现在有没有更深刻的理解:

  1. 分析最优解的特征。               (分析最优子解如何构成最优解)
  2. 递归地定义最优解的值。               (画递归树,定义子问题,写状态转移方程)
  3. 计算最优解的值。                        (写代码求出最优解,如果有要求的话,记录额外信息,为第4步作准备)
  4. 根据计算好的信息构造最优解。       (从第3步记录的信息中构建最优解,在本题中就是括号的写法)

参考资料

算法导论(英文版)3rd Ed. 15.2

posted @ 2018-09-12 14:19  mskitten  阅读(...)  评论(...编辑  收藏