分治策略(2)——算法导论(4)

1. 引言

    这一篇博文主要介绍基于分治策略的矩阵乘法的Strassen算法

2. 矩阵乘法的Strassen算法

(1) 普通矩阵乘法算法

    矩阵乘法的基本算法的计算规则是:

        若A=(aij)和B=(bij)是n×n的方阵(i,j = 1,2,3...),则C = A · B中的元素Cij为:

image

    下面给出Java实现代码:

	public static void main(String[] args) {
		int[][] a = new int[][] { //
				{ 1, 0, 1, 2 }, //
				{ 1, 2, 0, 2 }, //
				{ 0, 2, 1, 0 }, //
				{ 0, 0, 1, 2 },//
		};
		int[][] b = new int[][] { //
				{ 1, 0, 1, 2 }, //
				{ 1, 2, 0, 2 }, //
				{ 0, 2, 1, 0 }, //
				{ 0, 0, 1, 2 },//
		};
		printMatrix(squareMatrixMutiply(a, b));
	}


	/**
	 * 基本矩阵乘法(假定矩阵a和矩阵b都是n×n的矩阵,且n为2的幂)
	 * @param a 矩阵a
	 * @param b 矩阵b
	 * @return
	 */
	private static int[][] squareMatrixMutiply(int[][] a, int[][] b) {
		int[][] c = new int[a.length][a.length];
		for (int i = 0; i < c.length; i++) {
			for (int j = 0; j < c.length; j++) {
				c[i][j] = 0;
				for (int k = 0; k < c.length; k++) {
					c[i][j] += a[i][k] * b[k][j];
				}
			}
		}
		return c;
	}
	
	/**
	 * 打印矩阵
	 * 
	 * @param matrix
	 */
	private static void printMatrix(int[][] matrix) {
		for (int[] is : matrix) {
			for (int i : is) {
				System.out.print(i + "\t");
			}
			System.out.println();
		}
	}

 

结果:image

 

 

 

(2) 一个简单的分治算法

 

    为简单起见,当使用分治法(Divide and Conquer)计算矩阵C=A*B时,假定三个矩阵都是n×n的矩阵,并且n为2的幂。分治法(Divide and Conquer)还是上一篇提到的三个步骤,算法的核心就是这个公式:

image

    其中,Aij,Bij,Cij分别是A,B,C矩阵的n / 2 * n / 2的子矩阵,即:

image

    值得说明的是,我们不必创建子数组,那将浪费θ(n²)的时间来复制数组元素;明智的做法是直接根据下标运算。

下图是原书的伪代码(其中所说的“(4.9)”即为上图所给的三个等式):

image

下面给出Java实现代码:

public static void main(String[] args) {
	int[][] a = new int[][] { //
			{ 1, 0, 1, 2 }, //
			{ 1, 2, 0, 2 }, //
			{ 0, 2, 1, 0 }, //
			{ 0, 0, 1, 2 },//
	};
	int[][] b = new int[][] { //
			{ 1, 0, 1, 2 }, //
			{ 1, 2, 0, 2 }, //
			{ 0, 2, 1, 0 }, //
			{ 0, 0, 1, 2 },//
	};
	printMatrix(squareMatrixMutiplyByRecursive(new ChildMatrix(a, 0, 0, a.length), new ChildMatrix(b, 0, 0, b.length), 0, 0, 0, 0));
}

/**
 * 打印矩阵
 * 
 * @param matrix
 */
private static void printMatrix(int[][] matrix) {
	for (int[] is : matrix) {
		for (int i : is) {
			System.out.print(i + "\t");
		}
		System.out.println();
	}
}

/**
 * 基于分治法的矩阵乘法
 * 
 * @param a
 * @param b
 * @return
 */
private static int[][] squareMatrixMutiplyByRecursive(ChildMatrix matrixA, ChildMatrix matrixB, int lastStartRowA, int lastStartColumnA, int lastStartRowB,
		int lastStartColumnB) {
	int[][] c = new int[matrixA.length][matrixA.length];
	if (matrixA.length == 1) {
		c[0][0] = matrixA.getFromParentMatrix(matrixA.startRow, matrixA.startColumn) * //
				matrixB.getFromParentMatrix(matrixB.startRow, matrixB.startColumn);
		return c;
	}
	int childLength = matrixA.length / 2;
	// 第一步:分解
	ChildMatrix childMatrixA11 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA, lastStartColumnA, childLength);
	ChildMatrix childMatrixA12 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA, lastStartColumnA + childLength, childLength);
	ChildMatrix childMatrixA21 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA + childLength, lastStartColumnA, childLength);
	ChildMatrix childMatrixA22 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA + childLength, lastStartColumnA + childLength, childLength);

	ChildMatrix childMatrixB11 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB, lastStartColumnB, childLength);
	ChildMatrix childMatrixB12 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB, lastStartColumnB + childLength, childLength);
	ChildMatrix childMatrixB21 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB + childLength, lastStartColumnB, childLength);
	ChildMatrix childMatrixB22 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB + childLength, lastStartColumnB + childLength, childLength);
	// 第二步:解决
	int[][] temp1 = squareMatrixMutiplyByRecursive(childMatrixA11, childMatrixB11, 0, 0, 0, 0);
	int[][] temp2 = squareMatrixMutiplyByRecursive(childMatrixA12, childMatrixB21, 0, childLength, childLength, 0);
	int[][] c11 = sumMatrix(temp1, temp2);

	int[][] temp3 = squareMatrixMutiplyByRecursive(childMatrixA11, childMatrixB12, 0, 0, 0, childLength);
	int[][] temp4 = squareMatrixMutiplyByRecursive(childMatrixA12, childMatrixB22, 0, childLength, childLength, childLength);
	int[][] c12 = sumMatrix(temp3, temp4);

	int[][] temp5 = squareMatrixMutiplyByRecursive(childMatrixA21, childMatrixB11, childLength, 0, 0, 0);
	int[][] temp6 = squareMatrixMutiplyByRecursive(childMatrixA22, childMatrixB21, childLength, childLength, childLength, 0);
	int[][] c21 = sumMatrix(temp5, temp6);

	int[][] temp7 = squareMatrixMutiplyByRecursive(childMatrixA21, childMatrixB12, childLength, 0, 0, childLength);
	int[][] temp8 = squareMatrixMutiplyByRecursive(childMatrixA22, childMatrixB22, childLength, childLength, childLength, childLength);
	int[][] c22 = sumMatrix(temp7, temp8);
	// 第三步:合并
	for (int i = 0; i < c.length; i++) {
		for (int j = 0; j < c.length; j++) {
			if (i < childLength && j < childLength) {
				c[i][j] = c11[i][j];
			} else if (i < childLength && j < c.length) {
				int[][] child = c12;
				c[i][j] = child[i][j - childLength];
			} else if (i < c.length && j < childLength) {
				int[][] child = c21;
				c[i][j] = child[i - childLength][j];
			} else {
				int[][] child = c22;
				c[i][j] = child[i - childLength][j - childLength];
			}
		}
	}
	return c;
}

private static int[][] sumMatrix(int[][] a, int[][] b) {
	int[][] c = new int[a.length][b.length];
	for (int i = 0; i < a.length; i++) {
		for (int j = 0; j < a.length; j++) {
			c[i][j] += a[i][j];
			c[i][j] += b[i][j];
		}
	}
	return c;
}

/**
 * ChildMatrix 表示某个矩阵的一个子矩阵
 * 
 * @author D.K
 *
 */
static class ChildMatrix {
	/**
	 * 父矩阵
	 */
	int[][] parentMatrix;
	/**
	 * 子矩阵在父矩阵中的起始行坐标
	 */
	int startRow;
	/**
	 * 子矩阵在父矩阵中的起始列坐标
	 */
	int startColumn;
	/**
	 * 子矩阵长度
	 */
	int length;

	public ChildMatrix(int[][] parentMatrix, int startRow, int startColumn, int length) {
		super();
		this.parentMatrix = parentMatrix;
		this.startRow = startRow;
		this.startColumn = startColumn;
		this.length = length;
	}

	/**
	 * 获取父矩阵的row行,colum列元素
	 * 
	 * @param row
	 * @param colum
	 * @return
	 */
	public int getFromParentMatrix(int row, int colum) {
		return parentMatrix[row][colum];
	}
}

 

结果是:image

 

 

(3) Strassen算法

 

    Strassen算法的核心思想是令递归树稍微不那么茂盛,它只进行7次递归(上面的分治法地递归了8次)。Strassen算法的描述如下:

    ① 分解矩阵A,B,C为image

同样不要创建子数组而只是进行下标计算。

    ② 创建10个n/2 ×n/2的矩阵S1,S2,S3…,S10,其计算公式如下:

QQ截图20150913101504

    ③ 递归地计算7个矩阵积P1, P2…P3,P7,计算公式如下:

image

    ④ 计算Cij,计算公式如下:

 

未标题-1    实现代码就不给出了,与上面类似。

 

3. 算法分析

(1) 普通矩阵乘法

    对于普通的矩阵乘法,3次嵌套循环,每层执行n次,所需时间为θ(n³);

(2) 简单分治算法

    ① 基本情况:T(1) = θ(1);

    ② 递归情况:分解后,矩阵规模变为原来的1/2。递归八次,用时8T(n/2);4次矩阵加法,每个矩阵中的元素个数为n² / 4, 用时θ(n²);其余用时θ(1)。因此共用时8T(n/2) + θ(n²)。

image

    可解得,T(n)  = θ(n³)。可看出分治算法并不优于普通矩阵乘法

(3) Strassen算法

   Strassen算法分析与上面基本一致,不同的是只进行了7次递归,并且额外多了几次n / 2 × n / 2矩阵的加法,但只是常数次。Strassen算法用时为:

image

可解得,T(n) = θ(n^lg7);

posted @ 2015-09-13 13:36  学数学的程序猿  阅读(976)  评论(2编辑  收藏  举报