torch.einsum 的用法实例
torch 处理 tensor 张量的广播,使用 einsum 函数,摘录一段使用代码,并分析用法
# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape
# Out[6]:
torch.Size([2, 5, 5])
这段代码利用 einsum 函数来进行张量运算,并将每个通道的图像加权转换为灰度图像。einsum 的使用使得代码简洁且易于理解。:
代码解读
-
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)img_t是一个图像张量,其形状假设为(3, 5, 5),表示3个通道的5x5图像(如RGB图像)。weights是一个权重张量,其形状为(3,),对应于每个通道的权重。'...chw'表示任意形状的张量(...),其中c是通道维度,h是高度维度,w是宽度维度。因为没有前导维度,img_t的形状具体为(3, 5, 5)。'c->...hw'表示将通道维度c与weights的权重值相乘,并将其求和后,留下高度和宽度维度(h, w)。最终结果是一个没有通道维度的二维张量,其形状为(5, 5),即灰度图像。
总结:
img_gray_weighted_fancy是一个5x5的灰度图像,原来的RGB图像被weights加权后生成。 -
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)batch_t是一个批量图像张量,假设形状为(2, 3, 5, 5),表示2个样本、每个样本有3个通道(如RGB),每个通道是5x5的图像。- 依然使用
weights作为权重张量,形状为(3,),对应每个通道的权重。 '...chw'这里的...表示前导的维度,即批量维度,这个索引表达式允许处理不同大小的张量。'c->...hw'指示要将通道维度c加权求和,结果保留批量、高度和宽度维度。对于批量图像,每个样本都会生成一个灰度图像。
总结:
batch_gray_weighted_fancy的形状为(2, 5, 5),表示2个样本对应的5x5灰度图像。 -
batch_gray_weighted_fancy.shape- 打印
batch_gray_weighted_fancy的形状,得到torch.Size([2, 5, 5])。
总结:最终输出的形状表明,处理后的批量图像变为2个5x5的灰度图像,每个样本的RGB图像都通过通道加权转换为了灰度图像。
- 打印
总览
einsum表达式'...chw,c->...hw'的含义是:对张量中指定的通道维度c进行加权求和,去除通道维度,保留批量、行和列的维度。- 这段代码通过
einsum函数简化了将彩色图像(多通道)转换为灰度图像的过程,避免了显式的广播和求和操作,极大地提升了代码的简洁性和可读性。
提问
上述代码"...chw,c->...hw"中,为什么两次使用 ... ,若变为"chw,c->...hw"可以吗?
回答:
在 einsum 表达式中,... 是一个通配符,表示任意数量的前导维度。这里的 ... 可以匹配0个或多个维度。表达式中的 ...chw,c->...hw 使用 ... 是为了处理任意数量的前导维度,而不仅仅是一个具体的维度。
为什么使用 ... 两次?
-
'...chw,c->...hw'中的...chw:- 这里的
...可以匹配batch_t的前导维度。例如,如果batch_t的形状是(2, 3, 5, 5),那么...匹配的就是2。 chw表示通道、高度、宽度,其中c是通道维度,h是高度维度,w是宽度维度。
- 这里的
-
'...chw,c->...hw'中的...hw:- 表示结果张量中保留前导维度
...和高度、宽度hw,但去除了通道维度c。
- 表示结果张量中保留前导维度
这意味着 ... 通配符可以让这个表达式更通用,无论张量前面有多少个维度,它都可以适用。比如:
- 对于一个形状为
(3, 5, 5)的图像张量,...chw会匹配(3, 5, 5),...hw会匹配(5, 5)。 - 对于一个形状为
(2, 3, 5, 5)的批量图像张量,...chw会匹配(2, 3, 5, 5),...hw会匹配(2, 5, 5)。
如果改为 'chw,c->...hw' 可以吗?
不可以。原因如下:
-
没有通配符
...来匹配前导维度:- 如果你使用
'chw,c->...hw',这里chw固定表示只有3个维度:通道、高度、宽度。 - 这个表达式不能匹配有前导维度的情况(如批量维度),即它只能处理形状精确为
(3, h, w)的张量。
- 如果你使用
-
无法处理批量张量:
- 如果
batch_t的形状是(2, 3, 5, 5),chw不能正确处理这个形状,因为它缺少了批量维度(2),结果会导致错误。 ...通配符的使用允许einsum表达式能够自动适应任意数量的前导维度(例如批量维度),使得代码更加通用。
- 如果
总结
使用 ... 通配符两次的原因是为了让这个 einsum 表达式适用于不同的张量形状,无论是否有前导维度(如批量维度)。这使得代码更加灵活和通用。如果你去掉其中的一个 ...,表达式的适用性将变得更窄,无法处理具有额外维度的张量。
问:...chw,c->...hw 变成 ...chw,c->hw,会有什么后果
答:
如果将 einsum 表达式从 ...chw,c->...hw 改为 ...chw,c->hw,会导致运算中前导维度(如批量维度)被丢弃或移除,从而改变了结果张量的形状和语义。下面详细解释其后果。
原始表达式:...chw,c->...hw
...chw:匹配任意数量的前导维度(...),以及通道、高度、宽度。c->...hw:消除了通道维度c,保留了前导维度...和高度、宽度hw。
对于形状为 (2, 3, 5, 5) 的 batch_t 张量(其中 2 是批量维度),运算后得到形状为 (2, 5, 5) 的张量。前导维度(批量维度)被保留。
修改后的表达式:...chw,c->hw
...chw:同样匹配任意数量的前导维度以及通道、高度、宽度。c->hw:这里的hw只保留了高度和宽度维度,完全忽略了前导维度...。
后果:
-
前导维度丢失:
- 前导维度
...(例如批量维度)被丢弃。 - 如果输入张量有前导维度(如批量维度
2),这些维度在结果中将不再存在。
- 前导维度
-
结果形状:
- 假设
batch_t的形状是(2, 3, 5, 5):...chw中的...匹配2(批量维度)。- 由于结果只保留
hw(高度和宽度),最终输出的形状将是(5, 5),即没有了批量维度。
- 这意味着,无论输入张量的前导维度是什么,最终结果都是将每个批次的图像折叠成一个单独的高度-宽度图像,导致批量信息丢失。
- 假设
-
语义混乱:
- 在深度学习处理中,保留批量维度通常是非常重要的。批量维度的丢失意味着无法将结果与原始输入数据一一对应。
- 如果你期望处理一组图像并保留每张图像的结果,但由于误操作丢失了批量维度,那么在后续步骤中处理这些结果会非常混乱。
结论
将 einsum 表达式从 ...chw,c->...hw 改为 ...chw,c->hw 会导致丢失前导维度(如批量维度)。如果前导维度被丢弃,输出将不再保留批次的结构信息,这在许多情况下可能是不可取的。通常你需要保留前导维度,除非你的具体应用场景明确不需要它们。

浙公网安备 33010602011771号