Click to Visit Homepage : zzyzz.top


Algorithms - Strassen's algorithm for matrix multiplication

 

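Problem:
    Compute the matrix product C = A * B, where A, B and C are all N x N square matrices and, to keep the problem simple, N is a power of 2. Partition each matrix into four N/2 x N/2 blocks: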
        A = [[A11, A12], [A21, A22]]
        B = [[B11, B12], [B21, B22]]
        C = [[C11, C12], [C21, C22]]
        
        Then, by the rules of matrix multiplication:
            C11 = A11 * B11 + A12 * B21
            C12 = A11 * B12 + A12 * B22
            C21 = A21 * B11 + A22 * B21
            C22 = A21 * B12 + A22 * B22
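To make the block equations concrete, here is a minimal sketch that splits a list-of-lists matrix into its four quadrants and checks all four equations on a 4 x 4 example. The helpers `split`, `mat_mul` and `mat_add` and the sample matrices are hypothetical, introduced only for this illustration, and n is assumed to be even.

```python
def split(M):
    """Return the quadrants (M11, M12, M21, M22) of an n x n matrix, n even."""
    h = len(M) // 2
    M11 = [row[:h] for row in M[:h]]
    M12 = [row[h:] for row in M[:h]]
    M21 = [row[:h] for row in M[h:]]
    M22 = [row[h:] for row in M[h:]]
    return M11, M12, M21, M22


def mat_mul(X, Y):
    """Plain row-by-column product, used here only as a reference."""
    n, m, p = len(X), len(Y), len(Y[0])
    return [[sum(X[i][k] * Y[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


def mat_add(X, Y):
    """Element-wise sum of two matrices of the same shape."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[1, 0, 2, 0], [0, 1, 0, 2], [3, 0, 1, 0], [0, 3, 0, 1]]

A11, A12, A21, A22 = split(A)
B11, B12, B21, B22 = split(B)
C11, C12, C21, C22 = split(mat_mul(A, B))

# Each block of C is the sum of two block products
assert C11 == mat_add(mat_mul(A11, B11), mat_mul(A12, B21))
assert C12 == mat_add(mat_mul(A11, B12), mat_mul(A12, B22))
assert C21 == mat_add(mat_mul(A21, B11), mat_mul(A22, B21))
assert C22 == mat_add(mat_mul(A21, B12), mat_mul(A22, B22))
print("all four block equations hold")
```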
The conventional method for multiplying N x N square matrices:

def squre_matrix_multiply(A, B):
    n = len(A)
    # let c be a new n x n zero matrix
    c = [[0 for y in range(n)] for x in range(n)]
    for i in range(n):
        for j in range(n):
            # c[i][j] is the dot product of row i of A and column j of B
            for k in range(n):
                c[i][j] = c[i][j] + A[i][k] * B[k][j]
    return c

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]
    print(squre_matrix_multiply(A, B))

Result:
    [[8, 10], [21, 24]]
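For reference, the cost of the triple loop follows directly from the definition of the product: each of the n² entries of C is a dot product needing n scalar multiplications (2³ = 8 multiplications for the 2 x 2 example above):

```latex
c_{ij} = \sum_{k=1}^{n} a_{ik}\, b_{kj}, \qquad
\#\text{multiplications} = n^2 \cdot n = n^3 \;\Rightarrow\; T(n) = \Theta(n^3)
```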
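Solving with divide and conquer:
Divide and conquer: partition each N x N matrix into four N/2 x N/2 submatrices as above. Each block of C is the sum of two products of submatrices, so one N x N multiplication becomes 8 multiplications of N/2 x N/2 matrices plus 4 matrix additions. The implementation below hard-codes this partition for the 2 x 2 case, where every submatrix is a single number.

def squre_matrix_multiply_recursive(A, B):
    # A and B are either 2 x 2 matrices or one-element lists [x] wrapping a scalar
    try:
        n = len(A[0])
    except TypeError:
        # A[0] is a scalar: this is the 1 x 1 base case
        n = 1
    # let c be a new n x n matrix
    c = [[0 for x in range(n)] for y in range(n)]
    if n == 1:
        c = [[0], [0]]
        c[0][0] = A[0] * B[0]
    else:
        # partition A, B and C (only valid for 2 x 2 matrices)
        # each "+" concatenates two 1 x 1 results, e.g. [[6]] + [[2]] -> [[6], [2]]
        c[0][0] = squre_matrix_multiply_recursive([A[0][0]], [B[0][0]]) \
                + squre_matrix_multiply_recursive([A[0][1]], [B[1][0]])
        c[0][1] = squre_matrix_multiply_recursive([A[0][0]], [B[0][1]]) \
                + squre_matrix_multiply_recursive([A[0][1]], [B[1][1]])
        c[1][0] = squre_matrix_multiply_recursive([A[1][0]], [B[0][0]]) \
                + squre_matrix_multiply_recursive([A[1][1]], [B[1][0]])
        c[1][1] = squre_matrix_multiply_recursive([A[1][0]], [B[0][1]]) \
                + squre_matrix_multiply_recursive([A[1][1]], [B[1][1]])
    # process the res: collapse each entry of c into a single number
    res = [[0 for x in range(n)] for y in range(n)]
    for i in range(n):
        for j in range(n):
            res[i][j] = sum_list(c[i][j])
    return res

def sum_list(A):
    # sums a list of one-element lists, e.g. [[6], [2]] -> 8;
    # a bare scalar makes the loop raise TypeError and is returned as-is
    res = 0
    try:
        for i in A:
            res += i[0]
    except TypeError:
        res += A
    return res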

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]

    print(squre_matrix_multiply_recursive(A,B))
Result:
    [[8, 10], [21, 24]]
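The function above only handles 2 x 2 inputs. Below is a minimal sketch of the same divide-and-conquer scheme for any N that is a power of 2 (the assumption from the problem statement): it recurses on N/2 x N/2 quadrants and makes 8 recursive calls per level. The names `dc_multiply`, `add`, `split` and `join` are hypothetical helpers for this illustration, not part of the original post.

```python
def add(X, Y):
    """Element-wise sum of two equally sized square matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Split an n x n matrix (n even) into quadrants M11, M12, M21, M22."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def join(C11, C12, C21, C22):
    """Reassemble four h x h quadrants into one 2h x 2h matrix."""
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


def dc_multiply(A, B):
    """Divide-and-conquer product of two n x n matrices, n a power of 2."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # 8 recursive multiplications of n/2 x n/2 matrices, then 4 additions
    C11 = add(dc_multiply(A11, B11), dc_multiply(A12, B21))
    C12 = add(dc_multiply(A11, B12), dc_multiply(A12, B22))
    C21 = add(dc_multiply(A21, B11), dc_multiply(A22, B21))
    C22 = add(dc_multiply(A21, B12), dc_multiply(A22, B22))
    return join(C11, C12, C21, C22)


if __name__ == '__main__':
    print(dc_multiply([[2, 1], [3, 6]], [[3, 4], [2, 2]]))  # [[8, 10], [21, 24]]
```

Working on plain lists keeps the sketch dependency-free; the slicing in `split` copies data at every level, which is fine for illustration but not for performance.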
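Strassen's algorithm:
    Strassen's algorithm makes only 7 recursive multiplications of N/2 x N/2 matrices, versus 8 for the divide-and-conquer method; in exchange it performs more N/2 x N/2 matrix additions and subtractions.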
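Saving one multiplication per level matters because the saving compounds at every level of the recursion. Counting the block additions and subtractions as Θ(N²) work per call, the master theorem gives:

```latex
T_{\text{divide-and-conquer}}(N) = 8\,T(N/2) + \Theta(N^2) = \Theta(N^{\log_2 8}) = \Theta(N^3)
T_{\text{Strassen}}(N) = 7\,T(N/2) + \Theta(N^2) = \Theta(N^{\log_2 7}) \approx \Theta(N^{2.81})
```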
    
    1.  Create 10 matrices S1, S2, …, S10, each of size N/2 x N/2:
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        
    2.  Use S1 … S10 to build the seven products P1, P2, …, P7 (each is one recursive N/2 x N/2 multiplication):
        P1 = A11 * S1 = A11 * B12 - A11 * B22
        P2 = S2 * B22 = A11 * B22 + A12 * B22
        P3 = S3 * B11 = A21 * B11 + A22 * B11
        P4 = A22 * S4 = A22 * B21 - A22 * B11
        P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
        P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
        P7 = S9 * S10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
        
    3.  Combine P1 … P7 from the previous step to compute the blocks of C:
        C11 = P4 + P5 + P6 - P2
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P1 + P5 - P3 - P7
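Because the three steps only use addition, subtraction and multiplication, they can be sanity-checked with plain numbers standing in for the blocks. The sketch below evaluates S1 … S10 and P1 … P7 for a 2 x 2 matrix of scalars and compares the result with the row-by-column definition on random integer inputs; the names `strassen_2x2` and `naive_2x2` are hypothetical. Every product keeps the A-side factor on the left and the B-side factor on the right, which is why the same identities remain valid when the entries are matrices.

```python
import random


def strassen_2x2(A, B):
    """Multiply two 2 x 2 matrices of scalars using Strassen's 7 products."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B
    # Step 1: ten sums and differences
    s1, s2, s3, s4, s5 = b12 - b22, a11 + a12, a21 + a22, b21 - b11, a11 + a22
    s6, s7, s8, s9, s10 = b11 + b22, a12 - a22, b21 + b22, a11 - a21, b11 + b12
    # Step 2: seven products
    p1, p2, p3, p4 = a11 * s1, s2 * b22, s3 * b11, a22 * s4
    p5, p6, p7 = s5 * s6, s7 * s8, s9 * s10
    # Step 3: combine into the four entries of C
    return [[p5 + p4 - p2 + p6, p1 + p2],
            [p3 + p4, p5 + p1 - p3 - p7]]


def naive_2x2(A, B):
    """Reference product straight from C_ij = A_i1 * B_1j + A_i2 * B_2j."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


# Compare against the definition on random integer matrices
rng = random.Random(0)
for _ in range(1000):
    A = [[rng.randint(-9, 9) for _ in range(2)] for _ in range(2)]
    B = [[rng.randint(-9, 9) for _ in range(2)] for _ in range(2)]
    assert strassen_2x2(A, B) == naive_2x2(A, B)

print(strassen_2x2([[2, 1], [3, 6]], [[3, 4], [2, 2]]))  # [[8, 10], [21, 24]]
```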

def strassn(A, B):
    try:
        n = len(A[0])
    except TypeError:
        n = 1
    # let c be a new n x n matrix
    c = [[0 for x in range(n)] for y in range(n)]
    if n == 1:
        c[0][0] = A[0] * B[0]
    else:
        # partition A, B and C
        # note: only suitable for 2 x 2 matrices, where every block is a scalar
        # step 1: the ten sums and differences S1 ... S10
        s1 = B[0][1] - B[1][1]
        s2 = A[0][0] + A[0][1]
        s3 = A[1][0] + A[1][1]
        s4 = B[1][0] - B[0][0]
        s5 = A[0][0] + A[1][1]
        s6 = B[0][0] + B[1][1]
        s7 = A[0][1] - A[1][1]
        s8 = B[1][0] + B[1][1]
        s9 = A[0][0] - A[1][0]
        s10 = B[0][0] + B[0][1]

        # step 2: the seven products P1 ... P7 (only 7 multiplications)
        p1 = A[0][0] * s1
        p2 = s2 * B[1][1]
        p3 = s3 * B[0][0]
        p4 = A[1][1] * s4
        p5 = s5 * s6
        p6 = s7 * s8
        p7 = s9 * s10

        # step 3: combine the products into the blocks of C
        c[0][0] = p5 + p4 - p2 + p6
        c[0][1] = p1 + p2
        c[1][0] = p3 + p4
        c[1][1] = p5 + p1 - p3 - p7

    return c

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]

    print(strassn(A, B))
Result:
    [[8, 10], [21, 24]]
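The `strassn` function above only handles the 2 x 2 case, where every block is a scalar. A minimal sketch of the fully recursive version for any N that is a power of 2 follows; at each level it forms S1 … S10 and P1 … P7 from N/2 x N/2 blocks exactly as in steps 1 to 3. The names `strassen`, `add`, `sub`, `split` and `join` are hypothetical helpers for this illustration. For simplicity the recursion goes all the way down to 1 x 1; practical implementations usually switch to the conventional method below some cutoff size.

```python
def add(X, Y):
    """Element-wise sum of two equally sized square matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    """Element-wise difference X - Y of two equally sized square matrices."""
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Split an n x n matrix (n even) into quadrants M11, M12, M21, M22."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def join(C11, C12, C21, C22):
    """Reassemble four h x h quadrants into one 2h x 2h matrix."""
    return ([r1 + r2 for r1, r2 in zip(C11, C12)] +
            [r1 + r2 for r1, r2 in zip(C21, C22)])


def strassen(A, B):
    """Strassen product of two n x n matrices, n a power of 2."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # Step 1: ten n/2 x n/2 sums and differences
    S1, S2, S3 = sub(B12, B22), add(A11, A12), add(A21, A22)
    S4, S5, S6 = sub(B21, B11), add(A11, A22), add(B11, B22)
    S7, S8 = sub(A12, A22), add(B21, B22)
    S9, S10 = sub(A11, A21), add(B11, B12)
    # Step 2: only 7 recursive multiplications
    P1, P2, P3 = strassen(A11, S1), strassen(S2, B22), strassen(S3, B11)
    P4, P5 = strassen(A22, S4), strassen(S5, S6)
    P6, P7 = strassen(S7, S8), strassen(S9, S10)
    # Step 3: combine the products into the quadrants of C
    C11 = add(sub(add(P5, P4), P2), P6)
    C12 = add(P1, P2)
    C21 = add(P3, P4)
    C22 = sub(sub(add(P5, P1), P3), P7)
    return join(C11, C12, C21, C22)


if __name__ == '__main__':
    print(strassen([[2, 1], [3, 6]], [[3, 4], [2, 2]]))  # [[8, 10], [21, 24]]
```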

 


 

 

posted @ 2020-05-10 13:33  zzYzz

