10.2.2 平均汇聚
torch.repeat_interleave 用于按指定规则重复张量的元素,支持按维度扩展或自定义每个元素的重复次数。以下是详细说明和示例:
作用
- 功能:沿特定维度重复张量的元素,支持两种模式:
- 统一重复次数:所有元素重复相同次数。
- 自定义重复次数:每个元素按单独指定的次数重复。
- 与
torch.repeat的区别:repeat复制整个张量(如[1,2] → [1,2,1,2])。repeat_interleave逐个元素连续重复(如[1,2] → [1,1,2,2])。
语法
torch.repeat_interleave(
input, # 输入张量
repeats, # 重复次数(整数或张量)
dim=None, # 操作的维度(默认为展平后处理)
output_size=None # 预分配输出大小(可选)
)
参数详解
repeats:- 若为整数:所有元素重复该次数。
- 若为张量:形状需与
input在dim维度一致,每个元素对应重复次数。
dim:- 指定操作的维度。若为
None,先展平输入张量再处理。
- 指定操作的维度。若为
output_size:- 预定义输出大小,用于性能优化(通常无需手动指定)。
示例
示例 1:一维张量,统一重复次数
x = torch.tensor([1, 2, 3])
y = torch.repeat_interleave(x, repeats=2)
print(y) # 输出: tensor([1, 1, 2, 2, 3, 3])
示例 2:二维张量,沿指定维度重复
x = torch.tensor([[1, 2], [3, 4]])
# 沿 dim=0 重复(每行重复2次)
y = torch.repeat_interleave(x, repeats=2, dim=0)
print(y)
# 输出:
# tensor([[1, 2],
# [1, 2],
# [3, 4],
# [3, 4]])
示例 3:自定义每个元素的重复次数
x = torch.tensor([1, 2, 3])
repeats = torch.tensor([2, 3, 1]) # 每个元素重复2、3、1次
y = torch.repeat_interleave(x, repeats=repeats)
print(y) # 输出: tensor([1, 1, 2, 2, 2, 3])
示例 4:沿不同维度自定义重复次数
x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([1, 2]) # 沿 dim=1 的每个元素重复1、2次
y = torch.repeat_interleave(x, repeats=repeats, dim=1)
print(y)
# 输出:
# tensor([[1, 2, 2],
# [3, 4, 4]])
形状计算规则
- 统一重复次数:
- 若输入形状为
(D1, D2, ..., Dn),沿dim=k重复R次,则输出形状为(D1, ..., Dk*R, ..., Dn)。
- 若输入形状为
- 自定义重复次数:
- 沿
dim=k的维度大小变为sum(repeats),其他维度不变。
- 沿
注意事项
- 广播限制:
repeats张量必须与输入在dim维度长度一致。 - 性能优化:若已知输出大小,可用
output_size避免额外内存分配。 - 默认维度:
dim=None时,输入会被展平为 1D 再处理。
通过灵活设置 repeats 和 dim,可以实现复杂的数据扩展需求。

浙公网安备 33010602011771号