小矩阵相乘效率对比:lapack, cblas, 手写函数

我们需要做大量的小矩阵相乘(维数只有几十),而且次数非常多,所以选用哪个矩阵库的函数对我们很重要。这里写了一个很小的测试程序,对比 lapack(其中包含朴素的 blas 实现)、cblas 和手写函数做小矩阵相乘的效率:实数矩阵比较 lapack dgemm、cblas dgemm 和手写函数,复数矩阵比较 cblas zgemm3m、cblas zgemm 和手写函数。
对于给定的维数,实数和复数情形下的每一种办法都各做 1000 次方阵相乘 \(AB = C\),每次相乘用的矩阵 \(A,B\) 都是重新随机生成的。计时用的是 clock(),取的是 CPU 时间。

#include<iostream>
#include<fstream>
#include<cmath>
#include<complex>
#include<cstdlib>
#include<ctime>
#include<vector>
using namespace std;

#include "mkl.h"

// Hand-written prototype of the Fortran BLAS routine dgemm, used by lapack_dgemm() below.
// Following the Fortran calling convention, every argument (including scalars) is passed
// by pointer, and the symbol carries a trailing underscore.
// NOTE(review): mkl.h (included above) may already declare dgemm_ with const-qualified
// parameters and MKL_INT integer types; if so, this redeclaration could conflict or
// mismatch under an ILP64 build — confirm against the MKL version in use.
extern "C" void dgemm_(char *TRANSA, char *TRANSB, int *M, int *N, int *K, double* ALPHA, double *A, int* LDA, double *B, int* LDB, double* BETA, double *C, int* LDC);

/*
 * C := A B for row-major n x n matrices, computed by the Fortran BLAS routine dgemm_().
 * int n                dimension
 * double * A           A[ n*n ], row-major
 * double * B           B[ n*n ], row-major
 * double * C           C[ n*n ], overwritten with A B (row-major)
 *
 * Fortran reads arrays column-major, so it sees the row-major buffers A and B as
 * A^\top and B^\top. Passing them in swapped order (B first, then A) makes dgemm
 * compute B^\top A^\top = (AB)^\top in column-major storage, and that buffer read
 * back row-major is exactly AB.
 */
void lapack_dgemm( int n, double * A, double * B, double * C ){
	// dgemm_ takes every argument by pointer, so each one needs an addressable local.
	// General form: C = alpha * op(first) * op(second) + beta * C.
	char first_op = 'N';   // op(first)  = first,  no transpose
	char second_op = 'N';  // op(second) = second, no transpose
	int rows = n, cols = n, inner = n;        // all dimensions equal n (square matrices)
	int ld_first = n, ld_second = n, ld_out = n;
	double scale_product = 1.0;  // alpha
	double scale_old_c = 0.0;    // beta = 0: the previous contents of C do not contribute

	dgemm_( &first_op, &second_op, &rows, &cols, &inner,
	        &scale_product, B, &ld_first, A, &ld_second,
	        &scale_old_c, C, &ld_out );
}

/*
 * Hand-written product C := A B for row-major n x n matrices (plain triple loop).
 * int n        dimension
 * double * A   A[ n*n ], row-major
 * double * B   B[ n*n ], row-major
 * double * C   C[ n*n ], each entry overwritten with the dot product of a row of A
 *              and a column of B, accumulated in increasing k order
 */
void mtx_multiply( int n, double * A, double * B, double * C ){
	for(int row = 0; row < n; ++row){
		const double * a_row = A + row*n;
		double * c_row = C + row*n;
		for(int col = 0; col < n; ++col){
			double acc = 0;
			for(int k = 0; k < n; ++k) acc += a_row[k] * B[k*n + col];
			c_row[col] = acc;
		}
	}
}

/*
 * Hand-written product cC := cA cB for row-major n x n complex matrices.
 * int n                      dimension
 * complex<double> * cA       cA[ n*n ], row-major
 * complex<double> * cB       cB[ n*n ], row-major
 * complex<double> * cC       cC[ n*n ], overwritten with the product
 * The accumulator lives inside the row loop so the outer loop could be parallelised
 * (e.g. with "#pragma omp parallel for") without sharing state between rows.
 */
void cmtx_multiply( int n, complex<double> * cA, complex<double> * cB, complex<double> * cC ){
	for(int row = 0; row < n; ++row){
		const complex<double> * a_row = cA + row*n;
		complex<double> * c_row = cC + row*n;
		complex<double> acc;
		for(int col = 0; col < n; ++col){
			acc = 0;
			for(int k = 0; k < n; ++k) acc += a_row[k] * cB[k*n + col];
			c_row[col] = acc;
		}
	}
}

/*
 * C := A B for row-major n x n complex matrices via cblas_zgemm3m
 * (alpha = 1, beta = 0, no transposes, all leading dimensions = n).
 */
void cblaszgemm3m( int n, complex<double> * A, complex<double> * B, complex<double> * C ){
	const complex<double> one( 1.0, 0.0 );   // alpha
	const complex<double> zero( 0.0, 0.0 );  // beta: previous contents of C are discarded
	cblas_zgemm3m( CblasRowMajor, CblasNoTrans, CblasNoTrans,
	               n, n, n,
	               &one, A, n,
	               B, n,
	               &zero, C, n );
}

/*
 * C := A B for row-major n x n complex matrices via cblas_zgemm
 * (alpha = 1, beta = 0, no transposes, all leading dimensions = n).
 */
void cblaszgemm( int n, complex<double> * A, complex<double> * B, complex<double> * C ){
	const complex<double> one( 1.0, 0.0 );   // alpha
	const complex<double> zero( 0.0, 0.0 );  // beta: previous contents of C are discarded
	cblas_zgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans,
	             n, n, n,
	             &one, A, n,
	             B, n,
	             &zero, C, n );
}

/*
 * Prints the row-major n x n complex matrix A to cout, one row per line,
 * each entry followed by ", ".
 */
void printmtx( int n, complex<double> * A ){
	const complex<double> * entry = A;
	for(int r = 0; r < n; ++r){
		for(int c = 0; c < n; ++c, ++entry) cout << *entry << ", ";
		cout << endl;
	}
}

/*
 * Fills the n x n complex matrix A with pseudo-random entries: both the real and
 * the imaginary part of every entry are rand()/RAND_MAX, i.e. values in [0, 1].
 * (Previously only the real part was set, so every "random complex" matrix was
 * actually real.)
 * int n                   dimension
 * complex<double> * A     A[ n*n ], overwritten
 */
void randcmtx( int n, complex<double> * A ){
	for(int i=0;i<n*n;i++){
		// Two separate statements: the evaluation order of function arguments is
		// unspecified, so calling rand() twice inside one constructor call would not
		// give a reproducible real/imaginary assignment.
		const double re = static_cast<double>( rand() ) / RAND_MAX;
		const double im = static_cast<double>( rand() ) / RAND_MAX;
		A[i] = complex<double>( re, im );
	}
}
/*
 * Fills the n x n real matrix A (n*n entries) with rand()/RAND_MAX, i.e. values in [0, 1].
 */
void randmtx( int n, double * A ){
	const int total = n*n;
	for(int idx = 0; idx < total; ++idx)
		A[idx] = static_cast<double>( rand() ) / RAND_MAX;
}

int main(){

	/*
	// test: A = [ 0, 1, 0, 0 ], B = [ 0, 1, -1, 0 ]
	// AB = [ -1, 0, 0, 0 ], A^T B^T = [ 0, 0, 0, -1 ]
	int n = 2;
	double A[4] = { 0, 1, 0, 0 };
	double B[4] = { 0, 1, -1, 0 };
	double C[4];

	lapack_dgemm( n, A, B, C );

	cout<<"C: "; for(int i=0;i<4;i++) cout<<C[i]<<","; cout<<endl;
	*/

	vector<int> dim = {10, 20, 30, 40  };
	vector<double> ave_t_lapack_dgemm;
	vector<double> ave_t_cblas_dgemm;
	vector<double> ave_t_hand_dgemm;
	vector<double> ave_t_cblas_zgemm3m;
	vector<double> ave_t_cblas_zgemm;
	vector<double> ave_t_hand_zgemm;
	vector<double> ratio;

	int nrepeat = 1000; double x;

	for(auto n : dim ){
		cout<<" n = "<<n<<endl;
		double * A = new double [ n*n ];
		for(int i=0;i<n*n;i++) A[i] = ((double)rand())/RAND_MAX;
		double * B = new double [ n*n ];
		for(int i=0;i<n*n;i++) B[i] = ((double)rand())/RAND_MAX;
		double * C = new double [ n*n ];

		clock_t t1, t2, t3, t4;

		double alpha = 1, beta = 0;
		double t_lapack_dgemm = 0, t_cblas_dgemm = 0, t_hand_dgemm = 0;
		for(int i=0;i<nrepeat;i++){
			randmtx( n, A ); randmtx( n, B );
			t1 = clock(); lapack_dgemm( n, A, B, C ); t2 = clock(); t_lapack_dgemm += (t2-t1);
			t1 = clock(); 
			cblas_dgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n, alpha, A, n, B, n, beta, C, n );
			t2 = clock(); t_cblas_dgemm += (t2-t1);
			t1 = clock(); mtx_multiply( n, A, B, C ); t2 = clock(); t_hand_dgemm += (t2-t1);
		}

		x = t_lapack_dgemm/CLOCKS_PER_SEC/nrepeat;	
		cout<<" lapack dgemm:  " << x <<" s."<<endl;
		ave_t_lapack_dgemm.push_back( x );

		x = t_cblas_dgemm/CLOCKS_PER_SEC/nrepeat;
		cout<<" cblas dgemm: "<< x <<" s."<<endl;
		ave_t_cblas_dgemm.push_back( x );

		x = t_hand_dgemm/CLOCKS_PER_SEC/nrepeat;
		cout<<" hand written gemm: " << x << " s."<<endl;
		ave_t_hand_dgemm.push_back( x );

		complex<double> * cA = new complex<double> [ n*n ];
		complex<double> * cB = new complex<double> [ n*n ];
		complex<double> * cC = new complex<double> [ n*n ];

		double t_cblas_zgemm3m = 0, t_cblas_zgemm = 0, t_hand_zgemm = 0;
		for(int i=0;i<nrepeat;i++){
			randcmtx(n, cA); randcmtx(n, cB);
			t1 = clock(); cblaszgemm3m( n, cA, cB, cC ); t2 = clock(); t_cblas_zgemm3m += (t2-t1);
			t1 = clock(); cblaszgemm( n, cA, cB, cC ); t2 = clock(); t_cblas_zgemm += (t2-t1);
			t1 = clock(); cmtx_multiply( n, cA, cB, cC ); t2 = clock(); t_hand_zgemm += (t2-t1);
		}
		x = t_cblas_zgemm3m /CLOCKS_PER_SEC/nrepeat;
		cout<<" cblas zgemm3m: " << x <<" s."<<endl;
		ave_t_cblas_zgemm3m.push_back( x );

		x = t_cblas_zgemm /CLOCKS_PER_SEC/nrepeat;
		cout<<" cblas zgemm: " << x <<" s."<<endl;
		ave_t_cblas_zgemm.push_back( x );

		x = t_hand_zgemm / CLOCKS_PER_SEC/nrepeat;
		cout<<" hand zgemm: " << x <<" s."<<endl;
		ave_t_hand_zgemm.push_back( x );

		delete [] A; delete [] B; delete [] C;
		delete [] cA; delete [] cB; delete [] cC;
	}

	cout<<" ave_t_lapack_dgemm = [ "; for(auto t : ave_t_lapack_dgemm) cout<<t<<", "; cout<<"]\n";
	cout<<" ave_t_cblas_dgemm = [ "; for(auto t : ave_t_cblas_dgemm) cout<<t<<", "; cout<<"]\n";
	cout<<" ave_t_hand_dgemm = [ "; for(auto t : ave_t_hand_dgemm) cout<<t<<", "; cout<<"]\n";
	cout<<" ave_t_cblas_zgemm3m = [ "; for(auto t : ave_t_cblas_zgemm3m) cout<<t<<", "; cout<<"]\n";
	cout<<" ave_t_cblas_zgemm = [ "; for(auto t : ave_t_cblas_zgemm) cout<<t<<", "; cout<<"]\n";
	cout<<" ave_t_hand_zgemm = [ "; for(auto t : ave_t_hand_zgemm) cout<<t<<", "; cout<<"]\n";

	return 0;
}

编译:

icc gemm.cpp -qmkl -lblas -lgsl -O3

运行:

./a.out

做出来的结果如下
(原文此处为结果图,图片在转载中丢失。)

结论是:cblas 比朴素的 blas 或手写函数都要快(这在意料之中)。
但实践中有几个点我要记一下:

  • 编译时如果不开 -O3,cblas 很慢,在 n=10, 20 时不如手写
  • 如果不在代码中运行 1000 次取平均,而是只跑一次进行比较,cblas 在 n=10、20、30 时也不如手写函数。这一点我还不完全理解;一个可能的原因是 MKL 在第一次调用时有一次性的初始化开销,只跑一次时这部分开销会被计入用时。考虑到实践中是密集的矩阵运算,运行 1000 次取平均似乎更接近实际场景。实践中用 cblas 确实也比手写更快:在 PVPC Si28 的计算中,手写函数要 16 s,zgemm3m 只要 12 s,耗时少 25%。所以暂时不纠结这个问题,先用着 cblas。
  • 图中zgemm 似乎比 zgemm3m 还快一点,实践中也得到了印证,在 PVPC Si28 中,用 zgemm 只要 10s。

posted on 2021-11-16 18:22  luyi07  阅读(651)  评论(0编辑  收藏  举报

导航