tensor修改维度(cat, stack, unsqueeze, repeat_interleave, permute)

在 PyTorch 中,处理张量维度的操作是非常常见的。这里简要介绍如何增加和减少维度,以及 cat, stack, squeeze, 和 unsqueeze 的使用方法。

增加维度

  • unsqueeze: 在指定位置增加一个维度(即,将一维变为二维,二维变为三维等)。例如,有一个形状为 (2, 3) 的张量,你可以通过 unsqueeze(0) 在第一个位置增加一个维度,使其变成形状 (1, 2, 3)

    import torch
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    print(x.unsqueeze(0).shape)  # 输出: torch.Size([1, 2, 3])
    

减少维度

  • squeeze: 移除所有大小为1的维度。如果指定了维度,则仅移除指定位置大小为1的维度。

    x = torch.rand(1, 2, 3, 1)
    print(x.squeeze().shape)  # 如果可能,输出: torch.Size([2, 3])
    

连接张量

  • cat (concatenate): 按照指定维度连接一系列张量。这些张量必须具有相同的形状,除了连接的维度外。

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    print(torch.cat((x, y), dim=0).shape)  # 输出: torch.Size([4, 3])
    
  • stack: 沿新维度连接一系列张量。与 cat 不同,stack 会创建一个新的维度来放置这些张量。

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    print(torch.stack((x, y), dim=0).shape)  # 输出: torch.Size([2, 2, 3])
    

repeat_interleave 是 PyTorch 中用于对张量进行重复操作的一个方法。它允许你以灵活的方式重复张量的元素,不仅限于简单的逐个元素复制,还可以指定不同的维度和重复模式。这对于构建复杂的模型结构(如在注意力机制中生成合适的掩码或扩展张量)非常有用。

复制张量

torch.repeat_interleave(input, repeats, dim=None)
  • input: 输入张量。
  • repeats: 整数或张量。表示每个元素要重复几次。如果是一个张量,则其长度必须与输入张量在指定维度上的长度相匹配。
  • dim: 可选参数,指定在哪一个维度上进行重复。如果不指定,则默认会在展平后的张量上进行操作。

示例

1. 简单重复

假设我们有一个一维张量,并想简单地重复每个元素:

x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)
# 输出: tensor([1, 1, 2, 2, 3, 3])

2. 指定维度重复

对于二维张量,我们可以指定在一个维度上重复:

x = torch.tensor([[1, 2], [3, 4]])
x.repeat_interleave(2, dim=0)
# 输出:
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])

x.repeat_interleave(2, dim=1)
# 输出:
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])

3. 不同元素不同次数重复

更进一步,我们可以为每个元素指定不同的重复次数:

x = torch.tensor([1, 2, 3])
repeats = torch.tensor([2, 3, 1])
torch.repeat_interleave(x, repeats)
# 输出: tensor([1, 1, 2, 2, 2, 3])

这里,第一个元素重复了2次,第二个元素重复了3次,而第三个元素只重复了1次。

在注意力机制中的应用

在实现自注意力机制时,repeat_interleave 可以用来扩展查询、键或值的张量,以便它们可以与不同批次的序列正确对齐。例如,在处理变长序列时,可能需要根据每个序列的实际长度来调整查询、键和值的尺寸,这时 repeat_interleave 就显得特别有用。

总之,repeat_interleave 提供了一种强大且灵活的方法来处理张量的重复操作,无论是简单的元素重复还是更复杂的基于维度的重复需求。

重排列张量

torch.permute 是 PyTorch 中用于重新排列张量维度顺序的一个方法。使用 permute,你可以指定每个维度的新位置,而不改变数据布局或数据本身。这对于处理多维数据(如图像、视频等)以及在深度学习模型中调整张量形状特别有用。

基本用法

torch.permute(dims)
  • dims: 一个整数的元组或列表,表示新的维度顺序。每个整数代表原始张量维度的新位置。

示例

1. 简单的例子

假设有一个形状为 (2, 3, 4) 的张量:

x = torch.randn(2, 3, 4)
print(x.shape)  # 输出: torch.Size([2, 3, 4])

如果我们想交换第一和第二维度的位置,可以这样做:

y = x.permute(1, 0, 2)
print(y.shape)  # 输出: torch.Size([3, 2, 4])

这意味着原来的第一个维度(大小为 2)现在变成了第二个维度,而原来的第二个维度(大小为 3)现在成了第一个维度。

2. 应用到实际场景

考虑一个更具体的例子,比如处理一批图像数据,每张图像有颜色通道(RGB)、高度和宽度三个维度。如果输入张量形状是 (batch_size, channels, height, width),而在某些情况下你可能需要将其转换为 (batch_size, height, width, channels) 形式以适应特定的操作或库(例如某些可视化工具),你可以使用 permute 来实现这一点。

images = torch.randn(10, 3, 64, 64)  # 假设有10张图片,每张图片3个通道,大小为64x64
# 将维度从 (batch_size, channels, height, width) 转换为 (batch_size, height, width, channels)
images_permuted = images.permute(0, 2, 3, 1)
print(images_permuted.shape)  # 输出: torch.Size([10, 64, 64, 3])

注意事项

  • 使用 permute 后得到的新张量与原张量共享内存,这意味着修改其中一个会影响另一个。
  • permute 不会改变张量元素的数量或者它们之间的相对位置;它仅改变维度的顺序。
  • 如果你需要增加或减少维度,应该使用 unsqueezesqueeze,而不是 permute
posted @ 2025-05-23 09:44  玉米面手雷王  阅读(128)  评论(0)    收藏  举报