分治与递归-Strassen矩阵乘法

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 代码实现:

  1 /**
  2  * 矩阵乘法求解
  3  * @author Administrator
  4  *
  5  */
  6 public class Strassen {
  7     public static final int NUMBER = 4;
  8     private static int[][] A;
  9     private static int[][] B;
 10     
 11     public Strassen() {
 12         A = new int[NUMBER][NUMBER];
 13         B = new int[NUMBER][NUMBER];
 14     }
 15     
 16     public int[][] starssen(int[][] A, int[][] B) {
 17         int divide_length = A.length / 2;
 18         // 定义一些中间变量
 19         int[][] result = new int[A.length][A.length];
 20         
 21         int[][] M1 = new int[divide_length][divide_length];
 22         int[][] M2 = new int[divide_length][divide_length];
 23         int[][] M3 = new int[divide_length][divide_length];
 24         int[][] M4 = new int[divide_length][divide_length];
 25         int[][] M5 = new int[divide_length][divide_length];
 26         int[][] M6 = new int[divide_length][divide_length];
 27         int[][] M7 = new int[divide_length][divide_length];
 28         
 29         int[][] C11 = new int[divide_length][divide_length];
 30         int[][] C12 = new int[divide_length][divide_length];
 31         int[][] C21 = new int[divide_length][divide_length];
 32         int[][] C22 = new int[divide_length][divide_length];
 33         
 34         int[][] A11 = new int[divide_length][divide_length];
 35         int[][] A12 = new int[divide_length][divide_length];
 36         int[][] A21 = new int[divide_length][divide_length];
 37         int[][] A22 = new int[divide_length][divide_length];
 38         
 39         int[][] B11 = new int[divide_length][divide_length];
 40         int[][] B12 = new int[divide_length][divide_length];
 41         int[][] B21 = new int[divide_length][divide_length];
 42         int[][] B22 = new int[divide_length][divide_length];
 43         
 44         if (A.length == 2) {
 45             result = multi(A, B, A.length);
 46         } else {
 47             // 首先将矩阵A,B分为4块
 48             for (int i = 0; i < divide_length; ++i) {
 49                 for (int j = 0; j < divide_length; ++j) {
 50                     A11[i][j] = A[i][j];
 51                     A12[i][j] = A[i][j + divide_length];
 52                     A21[i][j] = A[i + divide_length][j];
 53                     A22[i][j] = A[i + divide_length][j + divide_length];
 54                     
 55                     B11[i][j] = B[i][j];
 56                     B12[i][j] = B[i][j + divide_length];
 57                     B21[i][j] = B[i + divide_length][j];
 58                     B22[i][j] = B[i + divide_length][j + divide_length];
 59                 }
 60             }
 61             
 62             // 计算M1
 63             M1 = starssen(A11, sub(B12, B22, divide_length));
 64             // 计算M2
 65             M2 = starssen(add(A11, A12, divide_length), B22);
 66             // 计算M3
 67             M3 = starssen(add(A21, A22, divide_length), B11);
 68             // 计算M4
 69             M4 = starssen(A22, sub(B21, B11, divide_length));
 70             // 计算M5
 71             M5 = starssen(add(A11, A22, divide_length), add(B11, B22, divide_length));
 72             // 计算M6
 73             M6 = starssen(sub(A12, A22, divide_length), add(B21, B22, divide_length));
 74             // 计算M7
 75             M7 = starssen(sub(A11, A21, divide_length), add(B11, B12, divide_length));
 76             
 77             // 计算C11,C12,C21,C22
 78             C11 = add(sub(add(M5, M4, divide_length), M2, divide_length), M6, divide_length);
 79             C12 = add(M1, M2, divide_length);
 80             C21 = add(M3, M4, divide_length);
 81             C22 = sub(sub(add(M5, M1, divide_length), M3, divide_length), M7, divide_length);
 82             
 83             // 合并C11,C12,C21,C22到C
 84             for (int i = 0; i < divide_length; ++i) {
 85                 for (int j = 0; j < divide_length; ++j) {
 86                     result[i][j] = C11[i][j];
 87                     result[i][j + divide_length] = C12[i][j];
 88                     result[i + divide_length][j] = C21[i][j];
 89                     result[i + divide_length][j + divide_length] = C22[i][j];
 90                 }
 91             }
 92         }
 93         return result;
 94     }
 95     
 96     public static int[][] initial() {
 97         int [][] result = new int[NUMBER][NUMBER];
 98         for (int i = 0; i < NUMBER; ++i) {
 99             for (int j = 0; j < NUMBER; ++j) {
100                 // 采用Math生成1~10之间的随机数
101                 result[i][j] = (int)(Math.random()*10);
102             }
103         }
104         return result;
105     }
106     
107     public void output(int[][] result) {
108         for (int b[] :result) {
109             for (int temp : b) {
110                 System.out.print(temp + " ");
111             }
112             System.out.println();
113         }
114     }
115     
116     /**
117      * 蛮力求解矩阵乘法
118      * @param a:矩阵a  n*n
119      * @param b:矩阵b  n*n
120      * @param n: 矩阵大小
121      */
122     public int[][] multi(int a[][], int b[][], int n) {
123         int result[][] = new int[n][n];
124         for (int i = 0; i < n; ++i) {
125             for (int j = 0; j < n; ++j) {
126                 result[i][j] = 0;
127                 for (int k = 0; k < n; ++k) {
128                     result[i][j] += a[i][k] * b[k][j];
129                 }
130             }
131         }
132         return result;
133     }
134     
135     /**
136      * 矩阵加法
137      * @param a
138      * @param b
139      * @param n
140      * @return
141      */
142     public int[][] add(int a[][], int b[][], int n) {
143         int result[][] = new int[n][n];
144         for (int i = 0; i < n; ++i) {
145             for (int j = 0; j < n; ++j) {
146                 result[i][j] = a[i][j] + b[i][j];
147             }
148         }
149         return result;
150     }
151     
152     /**
153      * 矩阵减法
154      * @param a
155      * @param b
156      * @param n
157      * @return
158      */
159     public int[][] sub(int a[][], int b[][], int n) {
160         int result[][] = new int[n][n];
161         for (int i = 0; i < n; ++i) {
162             for (int j = 0; j < n; ++j) {
163                 result[i][j] = a[i][j] - b[i][j];
164             }
165         }
166         return result;
167     }
168     
169     public static void main(String[] args) {
170         Strassen s = new Strassen();
171         A = initial();
172         B = initial();
173         s.output(A);
174         System.out.println("----------------------");
175         s.output(B);
176         System.out.println("----------------------");
177         
178         s.output(s.multi(A, B, NUMBER));
179         System.out.println("----------------------");
180         
181         int K[][] = new int[2][2];
182         K = s.starssen(A, B);
183         s.output(K);
184     }
185 }

 

posted @ 2019-09-29 15:39  cherry0408  阅读(358)  评论(0)    收藏  举报