乘法
https://blog.csdn.net/leo_95/article/details/89946318
matmul---可以进行张量乘法, 输入可以是高维.
mat-------矩阵对应位相乘,两个tensor输入维度相同,(m×n)和(m×n),返回(m×n)
mm--------只能进行矩阵乘法,也就是输入的两个tensor维度只能是(n×m) 和(m×p) ,返回(n×p)
bmm------是两个三维张量相乘, 两个tensor维度是(b×n×m)和(b×m×p), 第一维b代表batch_size,返回(b×n×p)
import torch
a=torch.randn(4,3)
b=torch.randn(3,5)
x_mm=torch.mm(a,b)
x_matmul=torch.matmul(a,b)
print("x_mm",x_mm.shape,x_mm)
print("x_matmul",x_matmul.shape,x_matmul)
'''
x_mm torch.Size([4, 5]) tensor([[ 3.6220, -1.5059, 2.4699, 0.8015, 0.3079],
[-0.0563, 2.5913, 0.2289, 0.9281, -0.4027],
[-2.0079, 3.9938, -3.0274, 0.6345, -0.7448],
[ 0.3485, -4.1003, 0.9263, -1.3284, 0.6894]])
x_matmul torch.Size([4, 5]) tensor([[ 3.6220, -1.5059, 2.4699, 0.8015, 0.3079],
[-0.0563, 2.5913, 0.2289, 0.9281, -0.4027],
[-2.0079, 3.9938, -3.0274, 0.6345, -0.7448],
[ 0.3485, -4.1003, 0.9263, -1.3284, 0.6894]])
'''
a = torch.rand(4,3,28,64)
b = torch.rand(4,3,64,32)
x=torch.matmul(a,b).shape # [4, 3, 28, 32],只计算最后两维的乘积
print("x",x)
a = torch.rand(4,3,28,64)
b = torch.rand(4,1,64,32)
y=torch.matmul(a,b).shape # [4, 3, 28, 32] 有broadcasting 操作
print("y",y)
a = torch.rand(91,1)
b = torch.rand(8285,1,3)
z=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("z",z)
a = torch.rand(8285,91,1)
b = torch.rand(8285,1,3)
w=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("w",w)
'''
a = torch.rand(20,91,1)
b = torch.rand(8285,1,3)
w=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("w",w)
RuntimeError: The size of tensor a (20) must match the size of tensor b (8285) at non-singleton dimension 0
'''
posted on 2019-09-29 15:20 happygril3 阅读(293) 评论(0) 收藏 举报
浙公网安备 33010602011771号