strassen 算法实现矩阵相乘,当输入为偶数,奇数,任意数的实现

1.使用随机数生成两个矩阵,并实现输出方法。

public void makeMatrix(int[][] matrixA, int[][] matrixB, int length){  //生成矩阵
        Random random=new Random();
        for(int i=0;i<length;i++){
            for (int j=0;j<length;j++){
                matrixA[i][j]=random.nextInt(5);
                matrixB[i][j]=random.nextInt(5);
            }
        }
    }
    public  void printMatrix(int[][] matrixA,int length){ //输出
        for(int i=0;i<length;i++){
            for (int j=0;j<length;j++){
                System.out.print(matrixA[i][j]+" ");
                if((j+1)%length==0)
                    System.out.println();
            }
        }
    }

 

2.使用Strassen算法需要涉及到矩阵的加减。所以先准备好方法。

public void add(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]= matrixA[i][j]+ matrixB[i][j];
            }
        }
    }
    public void jian(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]= matrixA[i][j] - matrixB[i][j];
            }
        }
    }

    public void cheng(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]=0;
                for(int k=0;k<length;k++){
                    matrixC[i][j] = matrixC[i][j]+ matrixA[i][k] * matrixB[k][j];
                }

            }
        }
    }

3.当矩阵的阶数为 2的k次方

 //阶数为 2 的 K 次方的时候 Strassen算法
    public void strassen(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        int newsize=N/2;
        if(N==2){
            cheng(matrixA,matrixB,matrixC,N);
            return;
        }
        int[][] A11=new int[newsize][newsize];
        int[][] A12=new int[newsize][newsize];
        int[][] A21=new int[newsize][newsize];
        int[][] A22=new int[newsize][newsize];

        int[][] B11=new int[newsize][newsize];
        int[][] B12=new int[newsize][newsize];
        int[][] B21=new int[newsize][newsize];
        int[][] B22=new int[newsize][newsize];

        int[][] C11=new int[newsize][newsize];
        int[][] C12=new int[newsize][newsize];
        int[][] C21=new int[newsize][newsize];
        int[][] C22=new int[newsize][newsize];

        int[][] M1=new int[newsize][newsize];
        int[][] M2=new int[newsize][newsize];
        int[][] M3=new int[newsize][newsize];
        int[][] M4=new int[newsize][newsize];
        int[][] M5=new int[newsize][newsize];
        int[][] M6=new int[newsize][newsize];
        int[][] M7=new int[newsize][newsize];

        int[][] Aresult=new int[newsize][newsize];
        int[][] Bresult=new int[newsize][newsize];

        //分别给 A11 A12 A21 A22赋值
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                A11[i][j]=matrixA[i][j];
                A12[i][j]=matrixA[i][j+N/2];
                A21[i][j]=matrixA[i+N/2][j];
                A22[i][j]=matrixA[i+N/2][j+N/2];

                B11[i][j]=matrixB[i][j];
                B12[i][j]=matrixB[i][j+N/2];
                B21[i][j]=matrixB[i+N/2][j];
                B22[i][j]=matrixB[i+N/2][j+N/2];
            }
        }

        //计算M1 到M7
        add(A11,A22,Aresult,newsize);
        add(B11,B22,Bresult,newsize);
        strassen(Aresult,Bresult,M1,newsize);

        //M2
        add(A21,A22,Aresult,newsize);
        strassen(Aresult,B11,M2,newsize);

        //M3
        jian(B12,B22,Bresult,newsize);
        strassen(A11,Bresult,M3,newsize);

        //M4
        jian(B21,B11,Bresult,newsize);
        strassen(A22,Bresult,M4,newsize);

        //M5
        add(A11,A12,Aresult,newsize);
        strassen(Aresult,B22,M5,newsize);

        //M6
        jian(A21,A11,Aresult,newsize);
        add(B11,B12,Bresult,newsize);
        strassen(Aresult,Bresult,M6,newsize);

        //M7
        jian(A12,A22,Aresult,newsize);
        add(B21,B22,Bresult,newsize);
        strassen(Aresult,Bresult,M7,newsize);

        //C11
        add(M1,M4,Aresult,newsize);
        jian(M5,M7,Bresult,newsize);
        jian(Aresult,Bresult,C11,newsize);

        //C12
        add(M3,M5,C12,newsize);

        //C21
        add(M2,M4,C21,newsize);

        //C22
        add(M1,M3,Aresult,newsize);
        jian(M2,M6,Bresult,newsize);
        jian(Aresult,Bresult,C22,newsize);
        //把C的值填充
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                matrixC[i][j]=C11[i][j];
                matrixC[i][j+N/2]=C12[i][j];
                matrixC[i+N/2][j]=C21[i][j];
                matrixC[i+N/2][j+N/2]=C22[i][j];
            }
        }
    }

4.当阶数为偶数的时候

假设阶数为  n   ,所以 n=m*2^k    。此时,可以将偶数的矩阵拆分为 m*m个2*k的矩阵。大矩阵使用传统方法,小矩阵使用Strassen算法。

//为阶数为偶数时的 Strassen算法
    public void evenNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        int[] splits=getK(N);
        int m = splits[1];
        int k = splits[0];
        int jie=(int) Math.pow(2,k);
        //可以拆分为 m*m 个 2^^k 阶矩阵
        Object[][] TA = new Object[m][m];
        Object[][] TB = new Object[m][m];
        Object[][] TC = new Object[m][m];
        for(int hang=0;hang<m;hang++){
            for (int lie=0;lie<m;lie++){
                int[][] matrixMA=new int[jie][jie];
                int[][] matrixMB=new int[jie][jie];
                //给矩阵MA ,MB 赋值
                for(int i=0;i<jie;i++){
                    for(int j=0;j<jie;j++){
                        matrixMA[i][j]=matrixA[hang*jie+i][lie*jie+j];
                        matrixMB[i][j]=matrixB[hang*jie+i][lie*jie+j];
                    }
                }
                TA[hang][lie]=matrixMA;
                TB[hang][lie]=matrixMB;
            }
        }
        //Object 数组中存放好了 m*m 的2^k 阶矩阵  所以 TA TB 看做两个矩阵做乘法
        for(int i=0;i<m;i++){
            for(int j=0;j<m;j++){
                int[][] juzhenC = new int[jie][jie];
                for(int p=0;p<m;p++){
                    int[][] juzhenA = (int[][])TA[i][p];
                    int[][] juzhenB = (int[][])TB[p][j];
                    int[][] chengres =new int[jie][jie];
                    int[][] addres =new int[jie][jie];
                    strassen(juzhenA,juzhenB,chengres,jie);
                    add(juzhenC,chengres,addres,jie);
                    juzhenC=addres;
                }
                TC[i][j]=juzhenC;
            }
        }

        //给矩阵C 结果矩阵进行赋值
        for(int hang=0;hang<m;hang++){
            for (int lie=0;lie<m;lie++){
                int[][] matrixMC=(int[][])TC[hang][lie];
                //给矩阵MA ,MB 赋值
                for(int i=0;i<jie;i++){
                    for(int j=0;j<jie;j++){
                        matrixC[hang*jie+i][lie*jie+j]=matrixMC[i][j];

                    }
                }
            }
        }
    }

5.当阶数为奇数的时候,添加一行一列,放在第一行第一列,并吧 [0][0]位置设为1  ,这样不影响矩阵相乘结果。计算出结果再去掉第一行第一列即可。

 //为阶数为奇数时的 Strassen算法
    public void oddNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        //扩容矩阵,第一行,第一列增加,A[0,0] 位置的值为0
        N=N+1;
        int[][] newA=new int[N][N];
        int[][] newB=new int[N][N];
        int[][] newC=new int[N][N];
        for(int i=0;i<N;i++){
            for(int j=0;j<N;j++){
                if(i==0||j==0){
                    newA[i][j]=0;
                    newB[i][j]=0;
                    continue;
                }
                newA[i][j]=matrixA[i-1][j-1];
                newB[i][j]=matrixB[i-1][j-1];
            }
        }
        newA[0][0]=1;
        newB[0][0]=1;
        evenNumber(newA,newB,newC,N);
        for(int i=0;i<N-1;i++) {
            for (int j = 0; j < N-1; j++) {
                matrixC[i][j]=newC[i+1][j+1];
            }
        }
    }

6.所有代码如下

import java.util.Random;
import java.util.Scanner;

/**
 * 主要有四个方法:
 * 第一问,阶数为2^k 的时候矩阵相乘    strassen
 * 第二问  阶数为偶数次矩阵相乘      evenNumber
 * 第三问  阶数为技数次矩阵相乘      oddNumber
 *
 * 将这些方法封装到一个方法内  allType    这个方法不限制矩阵阶数,任意维度都可以相乘。
 * */
public class Matrix {
    public void add(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]= matrixA[i][j]+ matrixB[i][j];
            }
        }
    }
    public void jian(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]= matrixA[i][j] - matrixB[i][j];
            }
        }
    }

    public void cheng(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
        for(int i=0;i<length;i++) {
            for (int j = 0; j < length; j++) {
                matrixC[i][j]=0;
                for(int k=0;k<length;k++){
                    matrixC[i][j] = matrixC[i][j]+ matrixA[i][k] * matrixB[k][j];
                }

            }
        }
    }

    public void makeMatrix(int[][] matrixA, int[][] matrixB, int length){
        Random random=new Random();
        for(int i=0;i<length;i++){
            for (int j=0;j<length;j++){
                matrixA[i][j]=random.nextInt(5);
                matrixB[i][j]=random.nextInt(5);
            }
        }
    }
    public  void printMatrix(int[][] matrixA,int length){
        for(int i=0;i<length;i++){
            for (int j=0;j<length;j++){
                System.out.print(matrixA[i][j]+" ");
                if((j+1)%length==0)
                    System.out.println();
            }
        }
    }
    /**
     * 计算阶数N   N=m*2^k
     * 返回值为数组   k,m
     * */
    public int[] getK(int N){
        int k=0;
        if(N%2==0){
            k++;
            N=N/2;
        }
        return new int[]{k,N};
    }

    /**
     * ============================================================================
     * 此分界线以上为辅助用的方法,下边三个分别为 三种情况对应的算法。
     * 最后一个是 参数为任意数时的方法。
     * */


    //阶数为 2 的 K 次方的时候 Strassen算法
    public void strassen(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        int newsize=N/2;
        if(N==2){
            cheng(matrixA,matrixB,matrixC,N);
            return;
        }
        int[][] A11=new int[newsize][newsize];
        int[][] A12=new int[newsize][newsize];
        int[][] A21=new int[newsize][newsize];
        int[][] A22=new int[newsize][newsize];

        int[][] B11=new int[newsize][newsize];
        int[][] B12=new int[newsize][newsize];
        int[][] B21=new int[newsize][newsize];
        int[][] B22=new int[newsize][newsize];

        int[][] C11=new int[newsize][newsize];
        int[][] C12=new int[newsize][newsize];
        int[][] C21=new int[newsize][newsize];
        int[][] C22=new int[newsize][newsize];

        int[][] M1=new int[newsize][newsize];
        int[][] M2=new int[newsize][newsize];
        int[][] M3=new int[newsize][newsize];
        int[][] M4=new int[newsize][newsize];
        int[][] M5=new int[newsize][newsize];
        int[][] M6=new int[newsize][newsize];
        int[][] M7=new int[newsize][newsize];

        int[][] Aresult=new int[newsize][newsize];
        int[][] Bresult=new int[newsize][newsize];

        //分别给 A11 A12 A21 A22赋值
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                A11[i][j]=matrixA[i][j];
                A12[i][j]=matrixA[i][j+N/2];
                A21[i][j]=matrixA[i+N/2][j];
                A22[i][j]=matrixA[i+N/2][j+N/2];

                B11[i][j]=matrixB[i][j];
                B12[i][j]=matrixB[i][j+N/2];
                B21[i][j]=matrixB[i+N/2][j];
                B22[i][j]=matrixB[i+N/2][j+N/2];
            }
        }

        //计算M1 到M7
        add(A11,A22,Aresult,newsize);
        add(B11,B22,Bresult,newsize);
        strassen(Aresult,Bresult,M1,newsize);

        //M2
        add(A21,A22,Aresult,newsize);
        strassen(Aresult,B11,M2,newsize);

        //M3
        jian(B12,B22,Bresult,newsize);
        strassen(A11,Bresult,M3,newsize);

        //M4
        jian(B21,B11,Bresult,newsize);
        strassen(A22,Bresult,M4,newsize);

        //M5
        add(A11,A12,Aresult,newsize);
        strassen(Aresult,B22,M5,newsize);

        //M6
        jian(A21,A11,Aresult,newsize);
        add(B11,B12,Bresult,newsize);
        strassen(Aresult,Bresult,M6,newsize);

        //M7
        jian(A12,A22,Aresult,newsize);
        add(B21,B22,Bresult,newsize);
        strassen(Aresult,Bresult,M7,newsize);

        //C11
        add(M1,M4,Aresult,newsize);
        jian(M5,M7,Bresult,newsize);
        jian(Aresult,Bresult,C11,newsize);

        //C12
        add(M3,M5,C12,newsize);

        //C21
        add(M2,M4,C21,newsize);

        //C22
        add(M1,M3,Aresult,newsize);
        jian(M2,M6,Bresult,newsize);
        jian(Aresult,Bresult,C22,newsize);
        //把C的值填充
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                matrixC[i][j]=C11[i][j];
                matrixC[i][j+N/2]=C12[i][j];
                matrixC[i+N/2][j]=C21[i][j];
                matrixC[i+N/2][j+N/2]=C22[i][j];
            }
        }
    }

    //为阶数为偶数时的 Strassen算法
    public void evenNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        int[] splits=getK(N);
        int m = splits[1];
        int k = splits[0];
        int jie=(int) Math.pow(2,k);
        //可以拆分为 m*m 个 2^^k 阶矩阵
        Object[][] TA = new Object[m][m];
        Object[][] TB = new Object[m][m];
        Object[][] TC = new Object[m][m];
        for(int hang=0;hang<m;hang++){
            for (int lie=0;lie<m;lie++){
                int[][] matrixMA=new int[jie][jie];
                int[][] matrixMB=new int[jie][jie];
                //给矩阵MA ,MB 赋值
                for(int i=0;i<jie;i++){
                    for(int j=0;j<jie;j++){
                        matrixMA[i][j]=matrixA[hang*jie+i][lie*jie+j];
                        matrixMB[i][j]=matrixB[hang*jie+i][lie*jie+j];
                    }
                }
                TA[hang][lie]=matrixMA;
                TB[hang][lie]=matrixMB;
            }
        }
        //Object 数组中存放好了 m*m 的2^k 阶矩阵  所以 TA TB 看做两个矩阵做乘法
        for(int i=0;i<m;i++){
            for(int j=0;j<m;j++){
                int[][] juzhenC = new int[jie][jie];
                for(int p=0;p<m;p++){
                    int[][] juzhenA = (int[][])TA[i][p];
                    int[][] juzhenB = (int[][])TB[p][j];
                    int[][] chengres =new int[jie][jie];
                    int[][] addres =new int[jie][jie];
                    strassen(juzhenA,juzhenB,chengres,jie);
                    add(juzhenC,chengres,addres,jie);
                    juzhenC=addres;
                }
                TC[i][j]=juzhenC;
            }
        }

        //给矩阵C 结果矩阵进行赋值
        for(int hang=0;hang<m;hang++){
            for (int lie=0;lie<m;lie++){
                int[][] matrixMC=(int[][])TC[hang][lie];
                //给矩阵MA ,MB 赋值
                for(int i=0;i<jie;i++){
                    for(int j=0;j<jie;j++){
                        matrixC[hang*jie+i][lie*jie+j]=matrixMC[i][j];

                    }
                }
            }
        }
    }

    //为阶数为奇数时的 Strassen算法
    public void oddNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        //扩容矩阵,第一行,第一列增加,A[0,0] 位置的值为0
        N=N+1;
        int[][] newA=new int[N][N];
        int[][] newB=new int[N][N];
        int[][] newC=new int[N][N];
        for(int i=0;i<N;i++){
            for(int j=0;j<N;j++){
                if(i==0||j==0){
                    newA[i][j]=0;
                    newB[i][j]=0;
                    continue;
                }
                newA[i][j]=matrixA[i-1][j-1];
                newB[i][j]=matrixB[i-1][j-1];
            }
        }
        newA[0][0]=1;
        newB[0][0]=1;
        evenNumber(newA,newB,newC,N);
        for(int i=0;i<N-1;i++) {
            for (int j = 0; j < N-1; j++) {
                matrixC[i][j]=newC[i+1][j+1];
            }
        }
    }


    //综合所有情况
    public void allType(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
        if(N%2==1){
            oddNumber(matrixA,matrixB,matrixC,N);
        }else {
            int[] split=getK(N);
            if(split[1]==1){
                strassen(matrixA,matrixB,matrixC,N);
            }else {
                evenNumber(matrixA,matrixB,matrixC,N);
            }
        }
    }


    public static void main(String[] args) {
        System.out.println("请输入矩阵的阶数");
        Scanner input = new Scanner(System.in);
        int matrixSize=input.nextInt();
        int[][] matrixA=new int[matrixSize][matrixSize];
        int[][] matrixB=new int[matrixSize][matrixSize];
        int[][] matrixC=new int[matrixSize][matrixSize];
        int[][] matrixD=new int[matrixSize][matrixSize];
        Matrix t=new Matrix();
        //为矩阵 A B 赋值,填充矩阵
        t.makeMatrix(matrixA,matrixB,matrixSize);

        //输出 A B 矩阵
        System.out.println("矩阵A:");
        t.printMatrix(matrixA,matrixSize);
        System.out.println("矩阵B:");
        t.printMatrix(matrixB,matrixSize);


        //传统乘法,并输出
        t.cheng(matrixA,matrixB,matrixC,matrixSize);
        System.out.println("传统乘法:");
        t.printMatrix(matrixC,matrixSize);

        //strassen 乘法,并输出
        t.allType(matrixA,matrixB,matrixD,matrixSize);
        System.out.println("Strassen乘法:");
        t.printMatrix(matrixD,matrixSize);

    }

 

posted @ 2020-12-09 21:50  星际毁灭  阅读(598)  评论(0编辑  收藏  举报