pytorch多维张量相乘和广播机制示例
示例:
import torch
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
[0.6000, 0.6000, 0.9000, 0.9000],
[0.1000, 0.1000, 0.2000, 0.2000]],
[[0.1000, 0.2000, 0.5000, 0.3000],
[0.6000, 0.6000, 0.9000, 0.9000],
[0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)
wh = torch.tensor([[[200.],
[400.],
[200.],
[400.]],
[[200.],
[400.],
[200.],
[400.]]]).to(torch.float32)
print(box.shape) # (2, 3 ,4)
print(wh.shape) # (2, 4, 1)
result = box @ wh
print(result.shape) # (2, 3, 1)
print(result)
# tensor([[[320.],
# [900.],
# [180.]],
# [[320.],
# [900.],
# [180.]]])
下面这个示例用到了广播机制:
import torch
box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000],
[0.6000, 0.6000, 0.9000, 0.9000],
[0.1000, 0.1000, 0.2000, 0.2000]],
[[0.1000, 0.2000, 0.5000, 0.3000],
[0.6000, 0.6000, 0.9000, 0.9000],
[0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32)
wh = torch.tensor([[[200.],
[400.],
[200.],
[400.]]]).to(torch.float32)
print(box.shape) # (2, 3 ,4)
print(wh.shape) # (1, 4, 1) 注意这个wh的第0维度的大小是1
result = box @ wh # 这里在第0维度会使用广播机制
print(result.shape) # (2, 3, 1)
print(result)
# tensor([[[320.],
# [900.],
# [180.]],
# [[320.],
# [900.],
# [180.]]])

浙公网安备 33010602011771号