动态规划(2)——算法导论(17)

写在前面

在上一篇博客中,学习了钢条切割问题。这一篇博客再来学习另一个典型的动态规划问题:矩阵乘法链问题

提出问题

我们知道,矩阵的乘法是满足结合律的,即对于矩阵A,B,C 满足(A B) C = A (B C) 。不同的结合方式会导致最终所作的乘法总次数不同。

例如:对于矩阵 A(规模为10 x 100),B(100 x 5),C(5 x 50),如果按照( ( A B ) C )的结合方式,计算D = ( A B )将需要作10 x 100 x 5 = 5 000次乘法,计算 D x C 要做 10 x 5 x 50 = 2 500次乘法,总共要做7 500 次乘法;若按照 ( A (B C) )的结合方式,同理可算出一共需要计算75 000次乘法!二者相差1个数量级。

由上可见,找出最优的结合方式(最优括号化方案),使总的乘法数最少能极大的加快矩阵乘法链的计算速度。这便是矩阵乘法链问题:

给定n个矩阵的链\(A_1A_2...A_n\),其中矩阵\(A_i\)的规模为\(p_{i-1}×p_i\),求完全括号化方案,使得计算乘积\(A_1A_2...A_n\)所需的标量乘法次数最少。

暴力求解

最容易想到的是采用暴力求解的方法来找出最优结合方式。但遗憾的是,这不是一个高效的算法。因为,当n = 1时,只有唯一一种结合方式;当n > 1时,可以将总的结合方式的数目看做是两部分结合方式的数目的乘积,即:

\[ A_1A_2...A_n = ((A_1A_2...A_k)(A_{k+1}A_{k+2}...A_n)),k = 1,2..n-1 \]

因此,对于长度为n的矩阵乘法链,总的结合方式的数目P(n)可以用如下递归公式表示:

\[P(n) = \begin {cases} 1 & n = 1\\ \sum\limits_{k=1}^{n-1}[P(k)P(n-k)] & n \geq 1 \end{cases} \]

可以证明该公式的结果为 \(\Omega(2^n)\)

动态规划

下面介绍采用动态规划的方法来解决这个问题。

最优括号化方案的结构特征

动态规划方法的第一步是要寻找最优子结构,然后利用该子结构,从子问题的最优解中构造出原问题的最优解。

考虑对\(A_iA_{i+1}...A_j\)的任意一种括号化方案,其最终一定是对某两部分求积,即有如下的形式:

\[(A_iA_{i+1}...A_k)(A_{k+1}A_{k+2}...A_j) \]

至于 \(k\) 为多少,以及 \(A_iA_{i+1}...A_k\)\(A_{k+1}A_{k+2}...A_j\) 的内部又该如何括号化暂且不管。于是该问题就变成了先求解矩阵 \(A_iA_{i+1}...A_k\)\(A_{k+1}A_{k+2}...A_j\) 的乘积,然后再计算两乘积的乘积,即为最终矩阵相乘的结果。

2. 一个递归求解方案

\(m[i, j]\) 表示计算矩阵乘法链 \(A_iA_{i+1}...A_j\) 所需的标量乘法数目。

当 $i = j $ 时,矩阵链只包含唯一一个矩阵,因此 \(m[i, j] = 0\)

\(i < j\) 时,假设最优括号化方案在 \(k (i <= k < j)\) 时取得。设矩阵 \(A_n (n=i...j)\) 的维数是 \(p_{n-1}\) x $ p_n$,则:

\[m[i, j] = m[i, k] + m[k+1, j] + p_{i-1}p_kp_j \]

于是最优括号化方案可用如下公式描述:

\[m[i, j] = \begin {cases} 0 & i = j\\ \min\limits_{i\leq k = j}\{m[i, k] + m[k+1. j] + p_{i-1} p_k p_{j} \} & i < j \end{cases} \]

3. 计算最优代价

可以很容易地根据上述递归公式写出一个递归算法,来计算 \(A_1A_2...A_n\) 相乘的最小时间代价 \(m[1, n]\)。但该递归算法显然不比暴力搜索的方案好,时间复杂度仍为指数时间。

注意到,在该问题中,问题规模较大的子问题中又包含了问题规模较小的子问题,即所有子问题在求解时有重叠。这让我们想到了钢条切割问题中的处理办法:使用备忘录记录下每个子问题的结果,以便再次求解其时,可以直接得到答案。

不难看出,每对满足 \(1 <= i <= j <= n\) 的任意 \((i, j)\) 组合,都对应者一个唯一的子问题,故子问题的总个数为:$ C_n^2 + n$。( \(C_n^2\) 表示中 \(n\) 个中选 2 个的组合数)。

像这种子问题重叠的性质是应用动态规划的另一个标识。(第一种标识是最优子结构)

下面给出带备忘的自底向上方法的Java实现:

/**
 * p[i](0 <= i < p.lenth) 表示第(i + 1)个矩阵的行数,因此第(i + 1)个矩阵的 规模为p[i] × p[i+1]
 */
public static int[][] matrixChainOrder(int[] p) {
	int n = p.length - 1;// n为待求矩阵链的总长度,待求矩阵链为A0A1...A(n-1)
	int[][] record = new int[n][n];// record[i][j]表示AiA(i+1)...Aj最优括号化方案的结果
	// l为子矩阵链的长度,l = 2 to n(长度为1只包含一个矩阵,不需要作乘积,因此不考虑)
	for (int l = 2; l <= n; l++) {
		// i = 0 to n - l,表示起始矩阵的下标
		for (int i = 0; i <= n - l; i++) {
			int j = i + l - 1; // j 表示结束矩阵的下标
			record[i][j] = Integer.MAX_VALUE;
			// k = i to j-1,表示分割点
			for (int k = i; k < j; k++) {
				int q = record[i][k] + record[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
				if (q < record[i][j]) {
					record[i][j] = q;
				}
			}
		}
	}
	return record;
}

可以看出,上述算法的时间复杂度为 \(O(n^3)\),并且还需要 \(O(n^2)\) 的内存空间来保存 record 数组。但这比暴力求解的指数时间复杂度高效的多。

下面做一个测试,对于一个长度为 6 的矩阵链,求其最优括号化方案所需要作的标量乘法次数。其中每个矩阵的规模如下:

\[A_1 : 30 × 35,A_2 : 35× 15, A_3 : 15 × 5,A_4 : 5× 10, A_5 : 10 × 20,A_6 : 20× 25 \]

此时,输入参数 int p[] = {30, 35, 15, 5, 10, 20, 25},代入到上面的 matrixChainOrder() 方法中,求得结果为 record[0][p.length - 2] = 15125

4. 构造最优解

上述 matrixChainOrder() 方法只能求出(只记录了)各子链问题的最优方案需要进行的标量乘法的数目,而未记录其最优方案的分割方法,即k值。因此,我们可以改进一下,把k值也保存下来,并且把最终的最优括号化方案“友好的”打印出来。下面是改进后的实现代码:

/**
 * p[i](0 <= i < p.lenth) 表示第(i + 1)个矩阵的行数,因此第(i + 1)个矩阵的 规模为p[i] × p[i+1]
 */
public static int[][] matrixChainOrder(int[] p) {
	int n = p.length - 1;// n为待求矩阵链的总长度,待求矩阵链为A0A1...A(n-1)
	int[][] record = new int[n][n]; // record[i][j]表示AiA(i+1)...Aj最优括号化方案的结果
	int[][] cut = new int[n][n]; // cut[i][j]表示AiA(i+1)...Aj最优分割点
	// l为子矩阵链的长度,l = 2 to n(长度为1只包含一个矩阵,不需要作乘积,因此不考虑)
	for (int l = 2; l <= n; l++) {
		// i = 0 to n - l,表示起始矩阵的下标
		for (int i = 0; i <= n - l; i++) {
			int j = i + l - 1; // j 表示结束矩阵的下标
			record[i][j] = Integer.MAX_VALUE;
			// k = i to j-1,表示分割点
			for (int k = i; k < j; k++) {
				int q = record[i][k] + record[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
				if (q < record[i][j]) {
					record[i][j] = q;
					cut[i][j] = k;
				}
			}
		}
	}
	// 打印 最终括号化方案
	print(cut, 0, n - 1);
	return record;
}

// 打印 括号化方案
public static void print(int[][] cut, int i, int j) {
	if (i == j) {
		System.out.print("A" + i);
		return;
	}
	System.out.print("(");
	print(cut, i, cut[i][j]);
	print(cut, cut[i][j] + 1, j);
	System.out.print(")");
}
posted @ 2016-04-28 15:37  学数学的程序猿  阅读(1362)  评论(0编辑  收藏  举报