分治与递归-Strassen矩阵乘法





代码实现:
/**
 * Square matrix multiplication: Strassen's divide-and-conquer algorithm
 * ({@link #starssen}) alongside the brute-force O(n^3) method ({@link #multi})
 * for comparison.
 *
 * @author Administrator
 */
public class Strassen {
    /** Side length of the random demo matrices produced by {@link #initial()}. */
    public static final int NUMBER = 4;

    /** Creates a multiplier; every method works only on its arguments. */
    public Strassen() {
    }

    /**
     * Multiplies two square matrices using Strassen's algorithm.
     *
     * <p>The method name keeps its historical spelling ("starssen") so existing callers
     * keep compiling. Sizes that are not a power of two are zero-padded to the next power
     * of two, multiplied, and cropped back, so any {@code n >= 0} is supported.
     *
     * @param A left operand, an n*n matrix
     * @param B right operand, an n*n matrix
     * @return a new n*n matrix equal to {@code A * B}
     * @throws IllegalArgumentException if A or B is not an n*n matrix with n = A.length
     */
    public int[][] starssen(int[][] A, int[][] B) {
        int n = A.length;
        requireSquare(A, n, "A");
        requireSquare(B, n, "B");
        if (n == 0) {
            return new int[0][0];
        }

        int size = Integer.highestOneBit(n);
        if (size < n) {
            size <<= 1;
        }
        if (size == n) {
            return strassenPowerOfTwo(A, B);
        }

        // Strassen halves the matrix at every level, so pad with zeros up to a power of two.
        // Zero rows/columns do not affect the top-left n*n block of the product.
        int[][] padded = strassenPowerOfTwo(padTo(A, size), padTo(B, size));
        int[][] result = new int[n][n];
        for (int i = 0; i < n; ++i) {
            System.arraycopy(padded[i], 0, result[i], 0, n);
        }
        return result;
    }

    /**
     * One recursive Strassen step. Both inputs must be n*n with n a power of two (or n <= 2).
     */
    private int[][] strassenPowerOfTwo(int[][] a, int[][] b) {
        int n = a.length;
        // For 1x1 and 2x2 blocks fall back to the direct product; this also terminates
        // the recursion for n == 1, which the halving below could never reach.
        if (n <= 2) {
            return multi(a, b, n);
        }
        int h = n / 2;

        // Split A and B into four h*h quadrants each.
        int[][] a11 = quadrant(a, 0, 0, h);
        int[][] a12 = quadrant(a, 0, h, h);
        int[][] a21 = quadrant(a, h, 0, h);
        int[][] a22 = quadrant(a, h, h, h);

        int[][] b11 = quadrant(b, 0, 0, h);
        int[][] b12 = quadrant(b, 0, h, h);
        int[][] b21 = quadrant(b, h, 0, h);
        int[][] b22 = quadrant(b, h, h, h);

        // Strassen's seven half-size products (instead of the naive eight).
        int[][] m1 = strassenPowerOfTwo(a11, sub(b12, b22, h));
        int[][] m2 = strassenPowerOfTwo(add(a11, a12, h), b22);
        int[][] m3 = strassenPowerOfTwo(add(a21, a22, h), b11);
        int[][] m4 = strassenPowerOfTwo(a22, sub(b21, b11, h));
        int[][] m5 = strassenPowerOfTwo(add(a11, a22, h), add(b11, b22, h));
        int[][] m6 = strassenPowerOfTwo(sub(a12, a22, h), add(b21, b22, h));
        int[][] m7 = strassenPowerOfTwo(sub(a11, a21, h), add(b11, b12, h));

        // C11 = M5 + M4 - M2 + M6, C12 = M1 + M2, C21 = M3 + M4, C22 = M5 + M1 - M3 - M7;
        // write each quadrant of C straight into the result.
        int[][] result = new int[n][n];
        place(result, add(sub(add(m5, m4, h), m2, h), m6, h), 0, 0);
        place(result, add(m1, m2, h), 0, h);
        place(result, add(m3, m4, h), h, 0);
        place(result, sub(sub(add(m5, m1, h), m3, h), m7, h), h, h);
        return result;
    }

    /** Rejects matrices that are not n*n, since every method here assumes square input. */
    private static void requireSquare(int[][] m, int n, String name) {
        if (m.length != n) {
            throw new IllegalArgumentException(
                    name + " must have " + n + " rows but has " + m.length);
        }
        for (int i = 0; i < n; ++i) {
            if (m[i].length != n) {
                throw new IllegalArgumentException(
                        name + " row " + i + " must have length " + n + " but has " + m[i].length);
            }
        }
    }

    /** Copies the h*h sub-matrix of m whose top-left corner is (row, col). */
    private static int[][] quadrant(int[][] m, int row, int col, int h) {
        int[][] q = new int[h][h];
        for (int i = 0; i < h; ++i) {
            System.arraycopy(m[row + i], col, q[i], 0, h);
        }
        return q;
    }

    /** Writes the square matrix src into dst with its top-left corner at (row, col). */
    private static void place(int[][] dst, int[][] src, int row, int col) {
        for (int i = 0; i < src.length; ++i) {
            System.arraycopy(src[i], 0, dst[row + i], col, src.length);
        }
    }

    /** Returns a size*size copy of the square matrix m, zero-filled beyond m's bounds. */
    private static int[][] padTo(int[][] m, int size) {
        int[][] p = new int[size][size];
        for (int i = 0; i < m.length; ++i) {
            System.arraycopy(m[i], 0, p[i], 0, m.length);
        }
        return p;
    }

    /**
     * Builds a NUMBER*NUMBER matrix of pseudo-random integers.
     *
     * @return a new matrix whose entries are each in the range 0..9 inclusive
     */
    public static int[][] initial() {
        int[][] result = new int[NUMBER][NUMBER];
        for (int i = 0; i < NUMBER; ++i) {
            for (int j = 0; j < NUMBER; ++j) {
                // Math.random() is in [0, 1), so this yields an integer in 0..9.
                result[i][j] = (int) (Math.random() * 10);
            }
        }
        return result;
    }

    /**
     * Prints a matrix to standard output, one row per line, each entry followed by a space.
     *
     * @param result the matrix to print
     */
    public void output(int[][] result) {
        for (int[] row : result) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Multiplies two matrices by the brute-force triple loop.
     *
     * @param a matrix a, n*n
     * @param b matrix b, n*n
     * @param n matrix size
     * @return a new n*n matrix equal to {@code a * b}
     */
    public int[][] multi(int[][] a, int[][] b, int n) {
        int[][] result = new int[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                int sum = 0;
                for (int k = 0; k < n; ++k) {
                    sum += a[i][k] * b[k][j];
                }
                result[i][j] = sum;
            }
        }
        return result;
    }

    /**
     * Adds two matrices element-wise.
     *
     * @param a matrix a, at least n*n
     * @param b matrix b, at least n*n
     * @param n number of rows/columns to add
     * @return a new n*n matrix equal to {@code a + b}
     */
    public int[][] add(int[][] a, int[][] b, int n) {
        int[][] result = new int[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    /**
     * Subtracts two matrices element-wise.
     *
     * @param a matrix a, at least n*n
     * @param b matrix b, at least n*n
     * @param n number of rows/columns to subtract
     * @return a new n*n matrix equal to {@code a - b}
     */
    public int[][] sub(int[][] a, int[][] b, int n) {
        int[][] result = new int[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                result[i][j] = a[i][j] - b[i][j];
            }
        }
        return result;
    }

    /**
     * Demo: prints two random NUMBER*NUMBER matrices, their brute-force product,
     * and their Strassen product, separated by dashed lines.
     */
    public static void main(String[] args) {
        Strassen s = new Strassen();
        int[][] a = initial();
        int[][] b = initial();
        s.output(a);
        System.out.println("----------------------");
        s.output(b);
        System.out.println("----------------------");

        // Brute-force product, for comparison with Strassen's result below.
        s.output(s.multi(a, b, NUMBER));
        System.out.println("----------------------");

        s.output(s.starssen(a, b));
    }
}

浙公网安备 33010602011771号