Pytorch: repeat, repeat_interleave, tile的用法

https://zhuanlan.zhihu.com/p/474153365

torch.repeat
使张量沿着某个维度进行复制, 并且不仅可以复制张量,也可以拓展张量的维度:

import torch

x = torch.randn(2, 4)


# 1. 沿着某个维度复制
x.repeat(1, 1).size()  # torch.Size([2, 4])

x.repeat(2, 1).size()  # torch.Size([4, 4])

x.repeat(1, 2).size()  # torch.Size([2, 8])


# 2. 不仅可以复制维度, 还可以拓展维度
x.repeat(1, 1, 1).size()  # torch.Size([1, 2, 4])

x.repeat(2, 1, 1).size()  # torch.Size([2, 2, 4])

x.repeat(1, 1, 1, 1).size()  # torch.Size([1, 1, 2, 4])


# 3. repeat中传入的参数不可以少于x的维度
x.repeat(1)  # 报错
torch.repeat_interleave
torch.repeat_interleave的行为与numpy.repeat类似,但是和torch.repeat不同,这边还是以代码为例:

import torch
x = torch.randn(2, 2)

print(x)
>>> tensor([[ 0.4332,  0.1172],
            [ 0.8808, -1.7127]])

print(x.repeat(2, 1))
>>> tensor([[ 0.4332,  0.1172],
            [ 0.8808, -1.7127],
            [ 0.4332,  0.1172],
            [ 0.8808, -1.7127]])

print(x.repeat_interleave(2, dim=0))
>>> tensor([[ 0.4332,  0.1172],
            [ 0.4332,  0.1172],
            [ 0.8808, -1.7127],
            [ 0.8808, -1.7127]])

print(x.repeat_interleave(2, dim=1))
>>> tensor([[ 0.4332,  0.4332,  0.1172,  0.1172],
            [ 0.8808,  0.8808, -1.7127, -1.7127]])

# 如果不传dim参数, 则默认复制后拉平
print(x.repeat_interleave(2))
>>> tensor([ 0.4332,  0.4332,  0.1172,  0.1172,  0.8808,  0.8808, -1.7127, -1.7127])
从这个代码可以看出来torch.repeat更像是把tensor作为一个整体进行复制, 而torch.repeat_interleave更是针对tensor里的每个元素进行复制,并且torch.repeat_interleave可以通过传入一个一维的torch.Tensor来指定每个元素复制的次数

import torch
x = torch.tensor([[1, 2], [3, 4]])

result = torch.repeat_interleave(x, torch.tensor([1, 3]), dim=0)
print(result)
>>> tensor([[1, 2],
            [3, 4],
            [3, 4],
            [3, 4]])
torch.tile
torch.tile函数也是元素复制的一个函数, 但是在传参上和torch.repeat不同,但是也是以input为一个整体进行复制, torch.tile如果只传入一个参数的话, 默认是沿着行进行复制

import torch
x = torch.tensor([[1, 2], [3, 4]])

# 只传入一个参数
print(x.tile((2, )))
>>> tensor([[1, 2, 1, 2],
            [3, 4, 3, 4]])

print(x.repeat(1, 2))
>>> tensor([[1, 2, 1, 2],
            [3, 4, 3, 4]])
torch.tile传入一个元组的话, 表示(行复制次数, 列复制次数)

import torch
x = torch.tensor([[1, 2], [3, 4]])

print(x.tile((2, 2)))
>>> tensor([[1, 2, 1, 2],
            [3, 4, 3, 4],
            [1, 2, 1, 2],
            [3, 4, 3, 4]])

print(x.repeat(2, 2))
>>> tensor([[1, 2, 1, 2],
            [3, 4, 3, 4],
            [1, 2, 1, 2],
            [3, 4, 3, 4]])
当传入的参数少于需要复制的元素的维度时, 如果一个tensor的形状为(2, 2, 2),传入tile中的参数为(2, 2)时, 会默认表示为(1, 2, 2)

import torch
x = torch.randn(2, 2, 2)
print(x)
>>> tensor([[[ 0.8517,  0.8721],
             [-1.1591, -0.2000]],

            [[ 0.3888, -0.8365],
             [-1.6383, -0.1539]]])

print(x.tile((2, 2)))
>>> tensor([[[ 0.8517,  0.8721,  0.8517,  0.8721],
             [-1.1591, -0.2000, -1.1591, -0.2000],
             [ 0.8517,  0.8721,  0.8517,  0.8721],
             [-1.1591, -0.2000, -1.1591, -0.2000]],

            [[ 0.3888, -0.8365,  0.3888, -0.8365],
             [-1.6383, -0.1539, -1.6383, -0.1539],
             [ 0.3888, -0.8365,  0.3888, -0.8365],
             [-1.6383, -0.1539, -1.6383, -0.1539]]])
当传入的参数多于需要复制的元素维度时,会拓展维度

import torch
x = torch.randn(2, 2)
print(x)
>>> tensor([[ 1.1165, -0.5559],
            [-0.6341,  0.5215]])

print(x.tile((2, 2, 2)))
>>> tensor([[[ 1.1165, -0.5559,  1.1165, -0.5559],
             [-0.6341,  0.5215, -0.6341,  0.5215],
             [ 1.1165, -0.5559,  1.1165, -0.5559],
             [-0.6341,  0.5215, -0.6341,  0.5215]],

            [[ 1.1165, -0.5559,  1.1165, -0.5559],
             [-0.6341,  0.5215, -0.6341,  0.5215],
             [ 1.1165, -0.5559,  1.1165, -0.5559],
             [-0.6341,  0.5215, -0.6341,  0.5215]]])


使用tile和reshape代替repeat_interleave
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)

y = torch.repeat_interleave(x, repeats=3, dim=0)

print(y)
>>> tensor([[1, 2, 3],
            [1, 2, 3],
            [1, 2, 3],
            [4, 5, 6],
            [4, 5, 6],
            [4, 5, 6]])

# 直接使用tile, 无法得到类似的结果
z = torch.tile(x, (3, ))
print(z)
>>> tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
            [4, 5, 6, 4, 5, 6, 4, 5, 6]])

z = torch.tile(x, (3, 1))
print(z)
>>> tensor([[1, 2, 3],
            [4, 5, 6],
            [1, 2, 3],
            [4, 5, 6],
            [1, 2, 3],
            [4, 5, 6]])

# 需要使用 tile + reshape 才可以得到类似的结果
z = torch.tile(x, (3, ))
print(z.shape)  # (2, 9)
print(z.reshape(6, 3))  # 得到了和y一样的输出
>>> tensor([[1, 2, 3],
            [1, 2, 3],
            [1, 2, 3],
            [4, 5, 6],
            [4, 5, 6],
            [4, 5, 6]])
posted @ 2022-08-18 18:18  SXQ-BLOG  阅读(937)  评论(0)    收藏  举报