2025.3.23
广播机制(Broadcasting) 是 PyTorch、NumPy 等科学计算库中的一种重要机制,用于在不同形状的张量或数组之间执行逐元素操作时,自动调整其形状以匹配。通过广播机制,可以避免显式地复制数据或扩展维度,从而提高代码的简洁性和计算效率123。
广播机制的核心规则
广播机制遵循以下规则:
- 维度对齐:从最后一个维度(即最右边的维度)开始,逐一对齐两个张量的维度。如果维度数不同,则在维度较少的张量的前面补充长度为1的维度1313。
- 维度匹配:对于每个维度,两个张量的尺寸必须满足以下条件之一:
- 扩展维度:如果某个维度的尺寸为1,则通过复制数据来扩展该维度,使其与另一个张量的尺寸匹配1313。
广播机制的具体步骤
- 补充维度:如果两个张量的维度数不同,则在维度较少的张量的前面补充长度为1的维度,使其维度数与较大的张量一致1313。
- 扩展维度:对于每个维度,如果一个张量的尺寸为1,则通过复制数据来扩展该维度,使其与另一个张量的尺寸匹配1313。
- 执行操作:在形状匹配后,对两个张量执行逐元素操作1313。
广播机制的示例
示例 1:标量与张量的广播
import torch
a = torch.tensor([1, 2, 3]) # 形状为 (3,)
b = 10 # 标量,形状为 ()
c = a + b # 广播机制将 b 扩展为 (3,)
print(c) # 输出: tensor([11, 12, 13])
在这个例子中,标量 b
被广播为形状 (3,)
,然后与 a
逐元素相加1323。
示例 2:不同形状的张量广播
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状为 (2, 3)
b = torch.tensor([10, 20, 30]) # 形状为 (3,)
c = a + b # 广播机制将 b 扩展为 (2, 3)
print(c) # 输出: tensor([[11, 22, 33], [14, 25, 36]])
在这个例子中,b
被广播为形状 (2, 3)
,然后与 a
逐元素相加1323。
示例 3:多维张量的广播
import torch
a = torch.tensor([[1], [2], [3]]) # 形状为 (3, 1)
b = torch.tensor([10, 20, 30]) # 形状为 (3,)
c = a + b # 广播机制将 a 扩展为 (3, 3), b 扩展为 (3, 3)
print(c) # 输出: tensor([[11, 21, 31], [12, 22, 32], [13, 23, 33]])
在这个例子中,a
被广播为形状 (3, 3)
,b
也被广播为形状 (3, 3)
,然后逐元素相加1323。
不支持广播的情况
如果两个张量的形状在任何维度上既不相等,也不为1,则无法进行广播,PyTorch 会抛出 RuntimeError
。例如:
import torch
a = torch.tensor([[1, 2], [3, 4]]) # 形状为 (2, 2)
b = torch.tensor([1, 2, 3]) # 形状为 (3,)
c = a + b # 报错:RuntimeError
因为 a
和 b
在第二个维度上的尺寸不匹配(2 != 3),且没有一方的尺寸为11323。
广播机制的总结
广播机制是一种强大的工具,能够简化代码并提高计算效率。通过自动扩展张量的形状,广播机制使得在不同形状的张量之间执行逐元素操作变得简单而高效1313。然而,使用时需要注意张量的形状是否满足广播规则,以避免错误。