高性能计算-cublas-gemm接口解析

1. 介绍

2. 接口

  • cuBLAS中用于运算矩阵乘法的函数有4个,分别是 cublasSgemm(单精度实数)、cublasDgemm(双精度实数)、cublasCgemm(单精度复数)、cublasZgemm(双精度复数),定义在 cublas_v2.h 和 cublas_api.h 中。

3. 参数介绍

  • 四个函数形式相似,均输入了14个参数。该函数实际上是用于计算 C = α A B +β C 的,其中 A m×k 、B k×n 、C m×n 为矩阵,α 、β 为标量。
  • 以 cublasSgemm_v2 为例:
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2(cublasHandle_t handle,
                                                     cublasOperation_t transa,
                                                     cublasOperation_t transb,
                                                     int m,
                                                     int n,
                                                     int k,
                                                     const float* alpha,
                                                     const float* A,
                                                     int lda,
                                                     const float* B,
                                                     int ldb,
                                                     const float* beta,
                                                     float* C,
                                                     int ldc);
  • 以下根据 cublas 的数据处理及计算流程,解析参数的含义及填写规则
流程 参数 解释
输入数据 A 输入第一个矩阵一维数组
B 输入第一个矩阵一维数组
数据处理 1. 设置主序解析参数 2. 设置转置参数 lda A 矩阵的主序维度,cublas 将按此参数按列保存数据,解析数组;如果A数组是依据C语言按行保存,那么 lda 应为A的列数,cublas 按此参数保存后的矩阵 A' 为 A 的转置矩阵。
ldb B 矩阵的主序维度,cutblas 将按此参数按列保存数据,解析数组。同理,得到B的转置矩阵 B'。
transa 是否再对 A' 转置;无论是否转置,这里将数据处理的结果标记为A''
transb 是否再对 B' 转置;无论是否转置,这里将数据处理的结果标记为B''
设置数据处理后参与计算的矩阵参数 m A'' 的行数
n B'' 的列数
k A'' 的列数
运算(对数据处理的结果进行数学运算) - A'' * B'';如果A' B' 进行了转置,A'' * B'' = A * B = C1'
计算结果 C C1' 按列保存的一维数组;
ldc C1' 的行数(即主序维度)
计算结果解析 - 根据ldc,如果A' B'使用了转置参,对C1' 一维数组按行保存的结果是 C的转置矩阵 C1'',需要对 C1'' 转置得到最终计算结果C。
  • 关于是否使用转置参数的参数填写示例:
  • 不使用转置:
    cublasSgemm_v2(handle, n, m, k, CUBLAS_OP_N, CUBLAS_OP_N ,alpha, B, n, A, k, beta, C, n)
  • cublas 计算公式:

\[C^{T} = B^{T} * A^{T} \]

  • 使用转置:
    cublasSgemm_v2(handle, m, n, k, CUBLAS_OP_T, CUBLAS_OP_T ,alpha, A, k, B, n, beta, C, m)
  • cublas 计算公式:

\[C = A * B \]

  • 注意:输出数组 C 解析出来的结果是最终结果的转置,需要再次转置得到最终结果。
posted @ 2025-08-27 16:17  安洛8  阅读(70)  评论(0)    收藏  举报