AVX512-Accelerated Matrix Multiplication

This is code I recently used for PKU's HPCGAME. Only 20 zmm registers are used here; tuning the block sizes should squeeze out some more speed.

The code only handles square matrices, and sizes that are not a power of two require modifying the code yourself. The underlying idea is simple and should be clear from reading the code.

#include <immintrin.h>	// AVX-512 intrinsics, _mm_malloc/_mm_free
#include <string.h>	// memcpy

const int BLOCK_SIZE = 1024;
const int BLOCK_SIZE2 = 256;

inline static void block_avx512_32x4(	// AVX256 did not perform well, so brute-forcing it with AVX512 (the assembly suggests ~12 registers are still unused, so there is room left to optimize)
	int n, int K, // n: square-matrix size (leading dimension of C), K: depth of this block
	double* A, double* B, double* C)
{
	__m512d c0000_0700,c0800_1500, c1600_2300, c2400_3100,
		c0001_0701, c0801_1501, c1601_2301, c2401_3101,
		c0002_0702, c0802_1502, c1602_2302, c2402_3102,
		c0003_0703, c0803_1503, c1603_2303, c2403_3103;

	__m512d a0x_7x, a8x_15x, a16x_23x, a24x_31x,
		bx0, bx1, bx2, bx3;

	double* c0001_0701_ptr = C + n;
	double* c0002_0702_ptr = C + n * 2;
	double* c0003_0703_ptr = C + n * 3;

	// Accumulators start from the current contents of C; the aligned loads assume each C tile is 64-byte aligned (true when C is 64-byte aligned and n is a multiple of 8)
	c0000_0700 = _mm512_load_pd(C);
	c0800_1500 = _mm512_load_pd(C + 8);
	c1600_2300 = _mm512_load_pd(C + 16);
	c2400_3100 = _mm512_load_pd(C + 24);

	c0001_0701 = _mm512_load_pd(c0001_0701_ptr);
	c0801_1501 = _mm512_load_pd(c0001_0701_ptr + 8);
	c1601_2301 = _mm512_load_pd(c0001_0701_ptr + 16);
	c2401_3101 = _mm512_load_pd(c0001_0701_ptr + 24);

	c0002_0702 = _mm512_load_pd(c0002_0702_ptr);
	c0802_1502 = _mm512_load_pd(c0002_0702_ptr + 8);
	c1602_2302 = _mm512_load_pd(c0002_0702_ptr + 16);
	c2402_3102 = _mm512_load_pd(c0002_0702_ptr + 24);

	c0003_0703 = _mm512_load_pd(c0003_0703_ptr);
	c0803_1503 = _mm512_load_pd(c0003_0703_ptr + 8);
	c1603_2303 = _mm512_load_pd(c0003_0703_ptr + 16);
	c2403_3103 = _mm512_load_pd(c0003_0703_ptr + 24);

	// Each k step performs a rank-1 update: a 32-element packed column of A times four broadcast scalars of B
	for (int x = 0; x < K; ++x)
	{
		a0x_7x = _mm512_load_pd(A);
		a8x_15x = _mm512_load_pd(A + 8);
		a16x_23x = _mm512_load_pd(A + 16);
		a24x_31x = _mm512_load_pd(A + 24);
		A += 32;

		bx0 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
		bx1 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
		bx2 = _mm512_broadcastsd_pd(_mm_load_sd(B++));
		bx3 = _mm512_broadcastsd_pd(_mm_load_sd(B++));


		c0000_0700 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx0), c0000_0700);
		c0800_1500 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx0), c0800_1500);
		c1600_2300 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx0), c1600_2300);
		c2400_3100 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx0), c2400_3100);

		c0001_0701 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx1), c0001_0701);
		c0801_1501 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx1), c0801_1501);
		c1601_2301 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx1), c1601_2301);
		c2401_3101 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx1), c2401_3101);

		c0002_0702 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx2), c0002_0702);
		c0802_1502 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx2), c0802_1502);
		c1602_2302 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx2), c1602_2302);
		c2402_3102 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx2), c2402_3102);

		c0003_0703 = _mm512_add_pd(_mm512_mul_pd(a0x_7x, bx3), c0003_0703);
		c0803_1503 = _mm512_add_pd(_mm512_mul_pd(a8x_15x, bx3), c0803_1503);
		c1603_2303 = _mm512_add_pd(_mm512_mul_pd(a16x_23x, bx3), c1603_2303);
		c2403_3103 = _mm512_add_pd(_mm512_mul_pd(a24x_31x, bx3), c2403_3103);
	}
	_mm512_storeu_pd(C, c0000_0700);
	_mm512_storeu_pd(C + 8, c0800_1500);
	_mm512_storeu_pd(C + 16, c1600_2300);
	_mm512_storeu_pd(C + 24, c2400_3100);

	_mm512_storeu_pd(c0001_0701_ptr, c0001_0701);
	_mm512_storeu_pd(c0001_0701_ptr + 8, c0801_1501);
	_mm512_storeu_pd(c0001_0701_ptr + 16, c1601_2301);
	_mm512_storeu_pd(c0001_0701_ptr + 24, c2401_3101);

	_mm512_storeu_pd(c0002_0702_ptr, c0002_0702);
	_mm512_storeu_pd(c0002_0702_ptr + 8, c0802_1502);
	_mm512_storeu_pd(c0002_0702_ptr + 16, c1602_2302);
	_mm512_storeu_pd(c0002_0702_ptr + 24, c2402_3102);

	_mm512_storeu_pd(c0003_0703_ptr, c0003_0703);
	_mm512_storeu_pd(c0003_0703_ptr + 8, c0803_1503);
	_mm512_storeu_pd(c0003_0703_ptr + 16, c1603_2303);
	_mm512_storeu_pd(c0003_0703_ptr + 24, c2403_3103);
}
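
Since there is still optimization headroom, one likely next step (my assumption, not something benchmarked in this post) is fusing each multiply/add pair into a single FMA instruction with _mm512_fmadd_pd. A minimal sketch of what the bx0 group of updates inside the k loop would look like:

	// Hypothetical FMA variant of the bx0 updates (sketch only, identical result up to rounding):
	c0000_0700 = _mm512_fmadd_pd(a0x_7x,   bx0, c0000_0700);
	c0800_1500 = _mm512_fmadd_pd(a8x_15x,  bx0, c0800_1500);
	c1600_2300 = _mm512_fmadd_pd(a16x_23x, bx0, c1600_2300);
	c2400_3100 = _mm512_fmadd_pd(a24x_31x, bx0, c2400_3100);

In practice the compiler may already contract these mul/add pairs into FMAs under its default floating-point contraction settings, so the gain would need to be checked against the generated assembly.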

// Pack a K x 4 panel of B (columns j..j+3, column-major with leading dimension lda) into an interleaved buffer: b_dest[4*k + c] = B[k + c*lda]
static inline void copy_avx512_b(int lda, const int K, double* b_src, double* b_dest) {
	double* b_ptr0, * b_ptr1, * b_ptr2, * b_ptr3;
	b_ptr0 = b_src;
	b_ptr1 = b_ptr0 + lda;
	b_ptr2 = b_ptr1 + lda;
	b_ptr3 = b_ptr2 + lda;

	for (int i = 0; i < K; ++i)
	{
		*b_dest++ = *b_ptr0++;
		*b_dest++ = *b_ptr1++;
		*b_dest++ = *b_ptr2++;
		*b_dest++ = *b_ptr3++;
	}
}

// Pack a 32 x K panel of A: each of the K columns contributes 32 contiguous doubles
static inline void copy_avx512_a(int lda, const int K, double* a_src, double* a_dest) {
	for (int i = 0; i < K; ++i)
	{
		memcpy(a_dest, a_src, 32 * 8);
		a_dest += 32;
		a_src += lda;
	}
}
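
To make the packed layouts concrete, here is a scalar reference of what block_avx512_32x4 computes from the two packed buffers (my own illustration, not part of the original code): A_packed holds 32 contiguous rows per k step, B_packed holds the four column entries for each k interleaved, and C is column-major with leading dimension n.

// Scalar equivalent of block_avx512_32x4, for reference only (assumes the packing functions above)
static void block_scalar_32x4(int n, int K, const double* A_packed, const double* B_packed, double* C)
{
	for (int k = 0; k < K; ++k)
		for (int j = 0; j < 4; ++j)
			for (int i = 0; i < 32; ++i)
				C[i + j * n] += A_packed[32 * k + i] * B_packed[4 * k + j];
}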

// Multiply an M x K block of A with a K x N block of B and accumulate into C, packing both blocks first
static inline void do_block_avx512(int lda, int M, int N, int K, double* A, double* B, double* C)
{
	double* A_block, * B_block;
	A_block = (double*)_mm_malloc(M * K * sizeof(double), 64);
	B_block = (double*)_mm_malloc(K * N * sizeof(double), 64);

	double* a_ptr, * b_ptr, * c;

	const int Nmax = N - 3;
	int Mmax = M - 31;	// M - 31 so the last full 32-row tile (starting at M - 32) is still processed

	int i = 0, j = 0;

	for (j = 0; j < Nmax; j += 4)
	{
		b_ptr = &B_block[j * K];
		copy_avx512_b(lda, K, B + j * lda, b_ptr); // pack the B panel
		for (i = 0; i < Mmax; i += 32) {
			a_ptr = &A_block[i * K];
			if (j == 0) copy_avx512_a(lda, K, A + i, a_ptr); // pack the A panel (only once, on the first pass over j)
			c = C + i + j * lda;
			block_avx512_32x4(lda, K, a_ptr, b_ptr, c);
		}
	}
	_mm_free(A_block);
	_mm_free(B_block);
}

void gemm_avx512(int lda, double* A, double* B, double* C)
{
#pragma omp parallel for
	for (int j = 0; j < lda; j += BLOCK_SIZE) {    // j-i-k loop order gives a better memory access pattern
		for (int i = 0; i < lda; i += BLOCK_SIZE) {
			for (int k = 0; k < lda; k += BLOCK_SIZE) {
				// smaller blocks nested inside each large block
				for (int jj = j; jj < j + BLOCK_SIZE; jj += BLOCK_SIZE2)
					for (int ii = i; ii < i + BLOCK_SIZE; ii += BLOCK_SIZE2)
						for (int kk = k; kk < k + BLOCK_SIZE; kk += BLOCK_SIZE2)
							do_block_avx512(lda, BLOCK_SIZE2, BLOCK_SIZE2, BLOCK_SIZE2, A + ii + kk * lda, B + kk + jj * lda, C + ii + jj * lda);
			}
		}
	}
}
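
For completeness, a minimal driver sketch (my own addition; matrix size, initialization, and compile flags such as g++ -O3 -mavx512f -fopenmp are assumptions). Matrices are stored column-major, the dimension must be a multiple of BLOCK_SIZE for the blocking loops above to cover everything, and the 64-byte alignment matches the aligned loads in the kernel. Note that gemm_avx512 accumulates into C, so C has to be zeroed first.

#include <stdio.h>
#include <stdlib.h>

int main()
{
	const int n = 2048;	// assumed size: a power of two and a multiple of BLOCK_SIZE
	double* A = (double*)_mm_malloc((size_t)n * n * sizeof(double), 64);
	double* B = (double*)_mm_malloc((size_t)n * n * sizeof(double), 64);
	double* C = (double*)_mm_malloc((size_t)n * n * sizeof(double), 64);

	for (long i = 0; i < (long)n * n; ++i) {
		A[i] = rand() / (double)RAND_MAX;
		B[i] = rand() / (double)RAND_MAX;
		C[i] = 0.0;
	}

	gemm_avx512(n, A, B, C);	// C += A * B, column-major

	printf("C[0] = %f\n", C[0]);
	_mm_free(A);
	_mm_free(B);
	_mm_free(C);
	return 0;
}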
posted @ 2024-02-01 00:47  Icys