乘法

https://blog.csdn.net/leo_95/article/details/89946318

matmul---可以进行张量乘法, 输入可以是高维.

mat-------矩阵对应位相乘,两个tensor输入维度相同,(m×n)和(m×n),返回(m×n

mm--------只能进行矩阵乘法,也就是输入的两个tensor维度只能是(n×m) 和(m×p) ,返回(n×p

bmm------是两个三维张量相乘, 两个tensor维度是(b×n×m)和(b×m×p), 第一维b代表batch_size,返回(b×n×p)

import torch
a=torch.randn(4,3)
b=torch.randn(3,5)

x_mm=torch.mm(a,b)
x_matmul=torch.matmul(a,b)
print("x_mm",x_mm.shape,x_mm)
print("x_matmul",x_matmul.shape,x_matmul)

'''
x_mm torch.Size([4, 5]) tensor([[ 3.6220, -1.5059,  2.4699,  0.8015,  0.3079],
        [-0.0563,  2.5913,  0.2289,  0.9281, -0.4027],
        [-2.0079,  3.9938, -3.0274,  0.6345, -0.7448],
        [ 0.3485, -4.1003,  0.9263, -1.3284,  0.6894]])
x_matmul torch.Size([4, 5]) tensor([[ 3.6220, -1.5059,  2.4699,  0.8015,  0.3079],
        [-0.0563,  2.5913,  0.2289,  0.9281, -0.4027],
        [-2.0079,  3.9938, -3.0274,  0.6345, -0.7448],
        [ 0.3485, -4.1003,  0.9263, -1.3284,  0.6894]])



'''

a = torch.rand(4,3,28,64)
b = torch.rand(4,3,64,32)
x=torch.matmul(a,b).shape # [4, 3, 28, 32],只计算最后两维的乘积
print("x",x)

a = torch.rand(4,3,28,64)
b = torch.rand(4,1,64,32)
y=torch.matmul(a,b).shape # [4, 3, 28, 32] 有broadcasting 操作
print("y",y)

a = torch.rand(91,1)
b = torch.rand(8285,1,3)
z=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("z",z)

a = torch.rand(8285,91,1)
b = torch.rand(8285,1,3)
w=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("w",w)


'''
a = torch.rand(20,91,1)
b = torch.rand(8285,1,3)
w=torch.matmul(a,b).shape # [8285, 91, 3],有broadcasting 操作
print("w",w)
RuntimeError: The size of tensor a (20) must match the size of tensor b (8285) at non-singleton dimension 0
'''

  

 

 

 

 

 

      

posted on 2019-09-29 15:20  happygril3  阅读(293)  评论(0)    收藏  举报

导航