矩阵连乘求解优化

前言

旭东的博客 看到一篇博文:矩阵连乘最优结合 动态规划求解,挺有意思的,这里做个转载【略改动】。

问题

矩阵乘法是这样的,比如\[ A_{ab} B_{bc} = C_{ac} \]

两个矩阵,一个a行,一个c列,行列乘法次数为a*c。一行乘以一列得到C中的一个元素,乘法次数为b,故矩阵乘法AB需要的乘法次数是a*c*b。

我们把b称为接口,那么矩阵连乘的次数是乘积的尺寸乘以中间的接口。中间的接口是矩阵高度,如果尽快能把长得高的矩阵通过乘法消化掉,这些大接口发生作用的机会就少,最终乘法次数就少了。

采用一维数组存储各矩阵高度。每次遍历找到最大值,和左边的矩阵相乘即可,直到最后只剩下一个矩阵。

 

输入参数是数组arr:存储各矩阵高度,最后一个元素为最后一个矩阵的列数,这个数组包含矩阵连乘表达式的所有信息。

loopTimes 是循环次数,循环次数为矩阵个数减1。
arrMaxId函数用于获取数组最大值索引,跳过第一个矩阵,因为第一个矩阵左边没有其他矩阵作为乘数。
由于最后要输出计算式,我们为每个矩阵设置一个名称,这个名称随着乘法的进行发生变化。最终会留下第一个矩阵,其名称就是最终运算式。

代码如下

int matrixMulTimes(vector<int> &arr) {
	int maxId;
	int mulTimes = 0;
	int pre = 0, next=0;
	string str="";
	vector<string>  matrixName(arr.size() - 1);
	string mulLeftStr, mulRightStr;
	int loops = 0;
	int loopTimes = arr.size() - 2;
	while (loops++ < loopTimes) {
		maxId = arrMaxId(arr, 1, arr.size() - 2);
		pre = maxId - 1;
		next = maxId + 1;
		while (arr[pre--] == -1);
		while (arr[next++] == -1);
		mulTimes += arr[++pre] * arr[maxId] * arr[--next];
		arr[maxId] = -1;
		mulLeftStr = matrixName[pre] == "" ? string(1, 'A' + pre) : matrixName[pre];
		mulRightStr = matrixName[maxId] == "" ? string(1, 'A' + maxId) : matrixName[maxId];
		matrixName[pre] = "(" + mulLeftStr +"*" + mulRightStr + ")";
	}
	cout << matrixName[0] << endl;
	return mulTimes;
}  

函数arrMaxId代码如下

int arrMaxId(vector<int> &arr, int begin, int end) {
	if (arr.size() == 0 || arr.size()<end) {
		return -1;
	}
	int maxId = begin;
	for (int i = begin+1; i <= end; ++i) {
		if (arr[i] > arr[maxId]) {
			maxId = i;
		}
	}
	return maxId;
}

上面只能是一个近似最优解,因为每次消去最高的矩阵,可能参与乘法的另一个矩阵也比较高,导致其存活更久更多地参与到运算中去,最后得不偿失。

如果非要得到最优解,可以运算并存储所有子式的运算量,自底向上,直到算出整个乘法算式。这个叫做动态规划,不过复杂度比较高。如果设置1000个矩阵相乘,动态规划可能算十几分钟都不一定有结果,但第一个算法几秒钟就能给出答案。

动态规划细节
每一个子式均由起点和跨度唯一决定。对于n个矩阵相乘,起点至多有n个,最小为0,最大为n-1。跨度至多n种,最小为0,最大为n-1。
怎么用之前的子式得到后面的子式呢?这个还挺麻烦的,得进行遍历,遍历子式中所有肯能的分割点。用两个n阶方阵分别存储每个子式的最少计算次数及分割点。

首先确定跨度,再确定起点,构成一个二级嵌套循环,通过起点的平移确定子式。子式确定后,再在内部嵌套一级循环,遍历子式的可能的分割点,并保存子式的最少计算次数及对应分割点。最后我们得到最大的子式,也就是连乘本身的最少运算次数及分割方案。

代码如下:

//根据记录的分割点,生成最后的矩阵相乘表达式
string make_result(vector<vector<int> > &points, int t1, int t2) {
	if (t1 == t2)
		return string(1, 'A' + t1 - 1);
	int split = points[t1][t2];
	return "(" + make_result(points, t1, split) + "*" + make_result(points, split + 1, t2) + ")";
}

int calculate_M(vector<int> &arr) {

	int matrixNum = arr.size() - 1;

	vector<vector<int> > num(matrixNum + 1, vector<int>(matrixNum + 1));
	vector<vector<int> > points(matrixNum + 1, vector<int>(matrixNum + 1));

	int span;
	int start;
	int end;
	int spiltPoint;
	int mulTimes;
	int rows, columns, interfaces;

	for (span = 1; span < arr.size() - 1; span++) {
		for (start = 1; start + span < arr.size(); start++) {
			end = start + span;
			num[start][end] = INT_MAX;
			for (spiltPoint = start; spiltPoint < end; spiltPoint++) {
				rows = arr[start - 1];
				columns = arr[end];
				interfaces = arr[spiltPoint];

				mulTimes = num[start][spiltPoint] + num[spiltPoint + 1][end] + rows * interfaces * columns;
				if (mulTimes < num[start][end]) {
					points[start][end] = spiltPoint;
					num[start][end] = mulTimes;
				}
			}
		}
	}

	cout << make_result(points, 1, matrixNum) << "\t最少乘法次数为:" << num[1][matrixNum] << endl;
	return 0;
}

 

代码中用到的一些知识

C++提供模版类string,其中一个构造方法可将字符转化为字符串。如 string(1, 'A'+1),第一个参数是源字符延拓次数,这个构造函数将‘B’转化为"B"。

posted @ 2019-10-07 00:07  谷谷非鼠  阅读(802)  评论(0编辑  收藏  举报