torch.mul() 和 torch.mm() 的区别

torch.mul(a, b) 是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵,和a*b效果相同
torch.mm(a, b) 是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

import torch

a = torch.tensor([[1,1],
                  [2,2]])
b = torch.tensor([[1,1],
                  [0,2]])
result1 = torch.mm(a,b)#矩阵相乘
result2 = torch.mul(a,b)#对应位相乘
result3 = a * b#对应位相乘
print("result1:" , result1)
print("result2:" , result2)
print("result3:" , result3)
result1: tensor([[1, 3],
        [2, 6]])
result2: tensor([[1, 1],
        [0, 4]])
result3: tensor([[1, 1],
        [0, 4]])
posted @ 2020-07-30 11:52  小Aer  阅读(455)  评论(0编辑  收藏  举报