torch.einsum 的计算过程
概论
a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)
上面的 einsum 如何计算的?
简单说,把 b 广播为 a 的形状,然后做矩阵乘法,即逐位相乘运算,注意,不是点积,是逐位的相乘运算。
注:这里符合背景需求,背景是,a 是深度学习的某个张量,b是a的权重,要求 a 的每一个元素都要乘以权重 b ,来得到实际有效的值。
然后,再把矩阵乘积的结果逐位相加后,得到最后结果,同时也去掉了维度c。
运算过程
具体运算细节如下:
为了详细解释 c = torch.einsum('...chw,c->...hw', a, b) 的计算过程,我们可以逐步分析每个部分的运算,并通过一个具体的例子说明结果的产生过程。
1. 张量 a 和 b 的形状与内容
a是一个形状为(3, 2, 2)的张量,假设其值为:a = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]], [[0.9, 1.0], [1.1, 1.2]]])b是一个形状为(3,)的张量,假设其值为:b = torch.tensor([2.0, 3.0, 4.0])
2. einsum 表达式 '...chw,c->...hw' 解析
-
...chw:...匹配任意数量的前导维度,在本例中没有前导维度。c对应的是第一个维度(形状为3)。h对应第二个维度(形状为2)。w对应第三个维度(形状为2)。
-
c->...hw:c对应b中的元素,作为缩并维度,它会在计算中被消除(通过乘法与求和操作)。...hw表示最终保留高度(h)和宽度(w)维度。
3. 具体计算过程
步骤 1:广播 b 以匹配 a 的形状
张量 b(形状 (3,))会被广播到形状 (3, 2, 2),与张量 a 对应的维度对齐:
b_broadcasted = torch.tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
步骤 2:逐元素相乘
a 和广播后的 b 逐元素相乘,结果为:
result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
[0.6, 0.8]],
[[1.5, 1.8],
[2.1, 2.4]],
[[3.6, 4.0],
[4.4, 4.8]]])
步骤 3:沿着 c 维度求和
对上面的结果,沿着 c 维度(也就是第一个维度)求和:
- 对
(0.2, 1.5, 3.6)这三个数在c维度上相加,得到5.3。 - 对
(0.4, 1.8, 4.0)这三个数在c维度上相加,得到6.2。 - 以此类推,求得最终结果
c:
c = torch.tensor([[5.3, 6.2],
[7.1, 8.0]])
4. 总结结果
最终的结果 c 是一个形状为 (2, 2) 的张量。每个元素是通过 a 中与 b 对应的元素逐元素相乘后,再沿 c 维度(通道维度)求和得到的。
因此,使用 torch.einsum('...chw,c->...hw', a, b) 后,得到的张量 c 对应的是去掉 c 维度后的加权和,仅保留高度 h 和宽度 w 维度的结果。
广播运算
问:b 广播为 b_broadcasted ,代码如何写
答
在 PyTorch 中,可以使用 unsqueeze 和 expand 方法来将张量 b 广播为与 a 形状匹配的 b_broadcasted。具体代码如下:
import torch
# 原始张量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)
# 扩展 b 的维度,使其能够广播到与 a 形状匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)
print(b_broadcasted)
步骤解释:
unsqueeze(-1):将b的形状从(3,)变为(3, 1),即在最后一维添加一个大小为1的新维度。unsqueeze(-1):再一次将形状从(3, 1)变为(3, 1, 1),在最后再添加一个大小为1的维度。expand_as(a):将b广播到与a相同的形状(3, 2, 2)。
结果示例:
假设 b 的值为 [2.0, 3.0, 4.0],那么 b_broadcasted 将变为:
tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
这就是张量 b 在计算过程中如何被广播为 b_broadcasted 的代码实现。

浙公网安备 33010602011771号