AVX256加速矩阵乘法

最近打 PKU 的 HPCGAME 留下的代码，速度不是很快。

// Edge length (in elements) of the outer tile that gemm steps through in j, i and k.
const int BLOCK_SIZE = 1024;
// Edge length (in elements) of the inner tile; gemm passes it to do_block as M, N and K.
const int BLOCK_SIZE2 = 256;

/*
 * 16x2 micro-kernel using 256-bit AVX (the author's machine has no AVX-512).
 *
 * Updates a tile of C that is 16 doubles tall and 2 columns wide:
 *     C[r]     += sum over x of A[16*x + r] * B[2*x]
 *     C[n + r] += sum over x of A[16*x + r] * B[2*x + 1]      for r = 0..15
 *
 * n  distance in elements between the two columns of the C tile
 *    (the caller passes the matrix leading dimension).
 * K  number of accumulation steps.
 * A  packed panel, 16 consecutive doubles per step. It is read with aligned
 *    loads, so it must be 32-byte aligned.
 * B  packed panel, 2 consecutive doubles per step (one per C column).
 * C  tile to update in place. No alignment is required: it is read and
 *    written with unaligned loads/stores, because neither the base pointer
 *    nor C + n is guaranteed to sit on a 32-byte boundary.
 */
inline static void block_avx256_16x2(
	int n, int K, // n: stride between C columns, K: accumulation depth
	double* A, double* B, double* C)
{
	double* c_col1 = C + n;

	// Accumulators: four 4-wide vectors per column of the C tile.
	__m256d acc00 = _mm256_loadu_pd(C);
	__m256d acc01 = _mm256_loadu_pd(C + 4);
	__m256d acc02 = _mm256_loadu_pd(C + 8);
	__m256d acc03 = _mm256_loadu_pd(C + 12);

	__m256d acc10 = _mm256_loadu_pd(c_col1);
	__m256d acc11 = _mm256_loadu_pd(c_col1 + 4);
	__m256d acc12 = _mm256_loadu_pd(c_col1 + 8);
	__m256d acc13 = _mm256_loadu_pd(c_col1 + 12);

	for (int x = 0; x < K; ++x)
	{
		// 16 rows of A for this step (packed buffer, aligned loads).
		__m256d a0 = _mm256_load_pd(A);
		__m256d a1 = _mm256_load_pd(A + 4);
		__m256d a2 = _mm256_load_pd(A + 8);
		__m256d a3 = _mm256_load_pd(A + 12);
		A += 16;

		// One B scalar per C column, replicated across all four lanes.
		__m256d b0 = _mm256_broadcast_sd(B++);
		__m256d b1 = _mm256_broadcast_sd(B++);

		acc00 = _mm256_add_pd(_mm256_mul_pd(a0, b0), acc00);
		acc01 = _mm256_add_pd(_mm256_mul_pd(a1, b0), acc01);
		acc02 = _mm256_add_pd(_mm256_mul_pd(a2, b0), acc02);
		acc03 = _mm256_add_pd(_mm256_mul_pd(a3, b0), acc03);

		acc10 = _mm256_add_pd(_mm256_mul_pd(a0, b1), acc10);
		acc11 = _mm256_add_pd(_mm256_mul_pd(a1, b1), acc11);
		acc12 = _mm256_add_pd(_mm256_mul_pd(a2, b1), acc12);
		acc13 = _mm256_add_pd(_mm256_mul_pd(a3, b1), acc13);
	}

	_mm256_storeu_pd(C, acc00);
	_mm256_storeu_pd(C + 4, acc01);
	_mm256_storeu_pd(C + 8, acc02);
	_mm256_storeu_pd(C + 12, acc03);

	_mm256_storeu_pd(c_col1, acc10);
	_mm256_storeu_pd(c_col1 + 4, acc11);
	_mm256_storeu_pd(c_col1 + 8, acc12);
	_mm256_storeu_pd(c_col1 + 12, acc13);
}

/*
 * Interleaves two runs of K doubles into b_dest.
 *
 * The first run starts at b_src, the second at b_src + lda. After the call
 * b_dest holds run0[0], run1[0], run0[1], run1[1], ... (2*K doubles), which is
 * the "two values per step" layout that block_avx256_16x2 reads for B.
 */
static inline void copy_b(int lda, const int K, double* b_src, double* b_dest) {
	const double *run0 = b_src;
	const double *run1 = b_src + lda;

	for (int k = 0; k < K; ++k) {
		b_dest[2 * k] = run0[k];
		b_dest[2 * k + 1] = run1[k];
	}
}

/*
 * Gathers K groups of 16 consecutive doubles into one contiguous buffer.
 *
 * Group k is read from a_src + k*lda .. a_src + k*lda + 15 and written to
 * a_dest + 16*k .. a_dest + 16*k + 15, giving the "16 values per step" layout
 * that block_avx256_16x2 reads for A. a_dest must have room for 16*K doubles.
 */
static inline void copy_a(int lda, const int K, double* a_src, double* a_dest) {
	const double *group = a_src;
	double *out = a_dest;

	for (int k = 0; k < K; ++k) {
		for (int r = 0; r < 16; ++r) {
			out[r] = group[r];
		}
		out += 16;
		group += lda;
	}
}

/*
 * Scalar reference update C += A * B restricted to rows [i_begin, i_end) and
 * columns [j_begin, j_end), with accumulation depth K. All three matrices are
 * indexed as X[row + col * lda]. Used by do_block for the parts the 16x2
 * vector kernel cannot cover and as a fallback when packing buffers cannot be
 * allocated.
 */
static inline void do_block_scalar(int lda, int i_begin, int i_end, int j_begin, int j_end, int K,
	const double* A, const double* B, double* C)
{
	for (int j = j_begin; j < j_end; ++j) {
		const double* b_col = B + (size_t)j * lda;
		double* c_col = C + (size_t)j * lda;
		for (int i = i_begin; i < i_end; ++i) {
			double acc = c_col[i];
			for (int p = 0; p < K; ++p) {
				acc += A[i + (size_t)p * lda] * b_col[p];
			}
			c_col[i] = acc;
		}
	}
}

/*
 * Computes C += A * B for one M x N tile of C with accumulation depth K.
 *
 * lda      leading dimension shared by A, B and C (element (r, c) lives at
 *          X[r + c * lda]).
 * M, N, K  tile extents; non-positive extents make the call a no-op.
 * A, B, C  pointers to the top-left element of the respective tiles.
 *
 * The part of the tile that is a whole number of 16-row by 2-column
 * micro-tiles is handled by block_avx256_16x2 on packed copies of A and B.
 * Leftover rows (M not a multiple of 16) and a leftover column (N odd) are
 * handled by the scalar helper, so the full M x N tile is always updated.
 * If a packing buffer cannot be allocated the whole tile falls back to the
 * scalar path instead of dereferencing a null pointer.
 */
static inline void do_block(int lda, int M, int N, int K, double* A, double* B, double* C)
{
	if (M <= 0 || N <= 0 || K <= 0) {
		return;
	}

	// Largest extents that are whole multiples of the 16x2 micro-tile.
	const int M_tiled = M - M % 16;
	const int N_tiled = N - N % 2;

	if (M_tiled > 0 && N_tiled > 0) {
		// size_t arithmetic so the byte counts cannot overflow int.
		double* A_block = (double*)_mm_malloc((size_t)M_tiled * K * sizeof(double), 64);
		double* B_block = (double*)_mm_malloc((size_t)K * N_tiled * sizeof(double), 64);

		if (A_block && B_block) {
			for (int j = 0; j < N_tiled; j += 2) {
				double* b_ptr = &B_block[(size_t)j * K];
				copy_b(lda, K, B + (size_t)j * lda, b_ptr); // pack two columns of B
				for (int i = 0; i < M_tiled; i += 16) {
					// i is a multiple of 16, so each packed A panel starts on a
					// 128-byte boundary inside the 64-byte-aligned buffer.
					double* a_ptr = &A_block[(size_t)i * K];
					if (j == 0) {
						// Pack each 16-row panel of A once; later j reuse it.
						copy_a(lda, K, A + i, a_ptr);
					}
					block_avx256_16x2(lda, K, a_ptr, b_ptr, C + i + (size_t)j * lda);
				}
			}
		} else {
			do_block_scalar(lda, 0, M_tiled, 0, N_tiled, K, A, B, C);
		}

		if (A_block) {
			_mm_free(A_block);
		}
		if (B_block) {
			_mm_free(B_block);
		}
	}

	// Rows below the last full 16-row panel, for the vectorised columns.
	if (M_tiled < M) {
		do_block_scalar(lda, M_tiled, M, 0, N_tiled, K, A, B, C);
	}
	// Trailing odd column, for all rows.
	if (N_tiled < N) {
		do_block_scalar(lda, 0, M, N_tiled, N, K, A, B, C);
	}
}

/*
 * Computes C += A * B for square lda x lda matrices of doubles, where element
 * (r, c) of each matrix lives at X[r + c * lda].
 *
 * The iteration space is cut into BLOCK_SIZE tiles and each of those into
 * BLOCK_SIZE2 tiles; every inner tile is handed to do_block. Tile extents are
 * clamped to lda, so no tile reaches past the end of the matrices when lda is
 * not a multiple of BLOCK_SIZE / BLOCK_SIZE2. Element offsets are computed in
 * size_t so that col * lda cannot overflow int for large matrices.
 *
 * The OpenMP loop is over column tiles j; every call made for a given j
 * writes only columns [j, j_end) of C.
 *
 * NOTE(review): whether extents that are not a multiple of the 16x2
 * micro-tile are fully computed depends on do_block — confirm when using lda
 * values that are not multiples of 16.
 */
void gemm(int lda, double* A, double* B, double* C)
{
#pragma omp parallel for
	for (int j = 0; j < lda; j += BLOCK_SIZE) {    // j-i-k order: faster memory access
		const int j_end = (lda - j < BLOCK_SIZE) ? lda : j + BLOCK_SIZE;
		for (int i = 0; i < lda; i += BLOCK_SIZE) {
			const int i_end = (lda - i < BLOCK_SIZE) ? lda : i + BLOCK_SIZE;
			for (int k = 0; k < lda; k += BLOCK_SIZE) {
				const int k_end = (lda - k < BLOCK_SIZE) ? lda : k + BLOCK_SIZE;
				// small tiles inside the large tile
				for (int jj = j; jj < j_end; jj += BLOCK_SIZE2) {
					const int n_cols = (j_end - jj < BLOCK_SIZE2) ? j_end - jj : BLOCK_SIZE2;
					for (int ii = i; ii < i_end; ii += BLOCK_SIZE2) {
						const int m_rows = (i_end - ii < BLOCK_SIZE2) ? i_end - ii : BLOCK_SIZE2;
						for (int kk = k; kk < k_end; kk += BLOCK_SIZE2) {
							const int k_depth = (k_end - kk < BLOCK_SIZE2) ? k_end - kk : BLOCK_SIZE2;
							do_block(lda, m_rows, n_cols, k_depth,
								A + ii + (size_t)kk * lda,
								B + kk + (size_t)jj * lda,
								C + ii + (size_t)jj * lda);
						}
					}
				}
			}
		}
	}
}
posted @ 2024-02-01 00:51  Icys  阅读(25)  评论(0编辑  收藏  举报