pytorch各种乘法,mm, matmul, dot, @, *, mul, multiply

  1. torch.mm
    线代的矩阵乘法,要求输入都是矩阵

  2. torch.matmul
    注意:torch.mm和torch.matmul不等价
    根据输入不同执行不同的操作:

  • 输入都是二维矩阵,矩阵乘法,等同于torch.mm
  • 输入都是一维向量,计算向量内积,等同于torch.dot
  • 第一个参数是向量,第二个是矩阵,则将第一个参数变成(1,n)的矩阵,再执行矩阵乘法
  • 第一个参数是矩阵,第二个是向量,执行矩阵向量乘法,等同于torch.mv
  • 两个都是高维张量,自己看文档去
  1. torch.dot
    向量点积(内积),输入必须都是一维的。向量点积计算公式:
    \(\bold a=(a_1, a_2, a_3)\)
    \(\bold b=(b_1, b_2, b_3)\)
    \(\bold a \cdot \bold b=a_1b_1+a_2b_2+a_3b_3\)

因此向量内积是个标量

  1. torch.mul
    按元素相乘,element-wise的乘法,也叫哈达玛积

  2. torch.multiply
    torch.mul的别称

  3. *
    torch.mul的简写

  4. @
    torch.matmul的简写(注意不是torch.mat的简写)

  5. torch.outer
    向量外积,输入向量维度分别为n和m,则输出(n, m)

posted @ 2025-01-04 23:24  王冰冰  阅读(1358)  评论(4)    收藏  举报