PyTorch
torch.cat(tensors, dim=0, *, out=None) → Tensor
Concatenates the given sequence of tensors (the tensors argument) along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).
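A minimal sketch of the shape rule, assuming a recent PyTorch release. Sizes may differ only in the concatenating dimension, a mismatch in any other dimension raises a RuntimeError, and a 1-D empty tensor of size (0,) is accepted and contributes nothing. The names a, b and c are illustrative only.

```python
import torch

a = torch.ones(2, 3)
b = torch.zeros(4, 3)   # differs from a only in dim 0
c = torch.ones(2, 5)    # differs from a only in dim 1

# Sizes may differ in the concatenating dimension only.
rows = torch.cat([a, b], dim=0)  # (2 + 4, 3) -> (6, 3)
cols = torch.cat([a, c], dim=1)  # (2, 3 + 5) -> (2, 8)
print(rows.shape, cols.shape)

# a is (2, 3) and c is (2, 5): concatenating along dim 0 would need
# dim 1 to match, so PyTorch raises a RuntimeError.
try:
    torch.cat([a, c], dim=0)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)

# A 1-D empty tensor with size (0,) is allowed and simply skipped.
with_empty = torch.cat([a, torch.empty(0)], dim=0)
print(with_empty.shape)
```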
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
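A small sketch of that inverse relationship: splitting a tensor with torch.split() or torch.chunk() along some dimension and then concatenating the pieces along the same dimension gives back the original tensor. The tensor t is arbitrary example data.

```python
import torch

t = torch.arange(12).reshape(3, 4)

# split: pieces of (at most) 2 columns along dim 1 -> widths 2 and 2
parts = torch.split(t, 2, dim=1)
# chunk: 2 pieces along dim 0 -> heights 2 and 1 (3 rows do not divide evenly)
chunks = torch.chunk(t, 2, dim=0)

# Concatenating along the same dimension undoes the split/chunk.
restored_split = torch.cat(parts, dim=1)
restored_chunk = torch.cat(chunks, dim=0)
print(torch.equal(restored_split, t), torch.equal(restored_chunk, t))  # True True
```

Note that split and chunk return tuples, which torch.cat accepts directly.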
tensors (sequence of Tensors) – any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
cat concatenates a sequence of tensors. Note that its input argument is a sequence of tensors, which can be a list, a tuple, and so on.
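Since the argument only needs to be a sequence of tensors, passing a list or a tuple gives the same result. A sketch with two illustrative 1-D tensors p and q:

```python
import torch

p = torch.tensor([1, 2])
q = torch.tensor([3, 4, 5])

from_list = torch.cat([p, q])   # list of tensors
from_tuple = torch.cat((p, q))  # tuple of tensors
print(from_list)                            # tensor([1, 2, 3, 4, 5])
print(torch.equal(from_list, from_tuple))   # True
```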
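>>> import torch
>>> x = torch.tensor([[[[6, 4, 4], [5, 7, 1]],
...                    [[9, 7, 2], [5, 9, 2]],
...                    [[5, 7, 0], [4, 9, 0]],
...                    [[7, 3, 6], [7, 6, 2]]]])
>>> x
tensor([[[[6, 4, 4],
          [5, 7, 1]],

         [[9, 7, 2],
          [5, 9, 2]],

         [[5, 7, 0],
          [4, 9, 0]],

         [[7, 3, 6],
          [7, 6, 2]]]])
>>> x.shape
torch.Size([1, 4, 2, 3])
>>> y = torch.cat([x] * 2)
>>> y.shape
torch.Size([2, 4, 2, 3])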
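In the example above dim is omitted, so it defaults to 0 and the two copies of x are joined along the first dimension (size 1 + 1 = 2). A sketch of the same call along other dimensions, plus the keyword-only out argument; here x is random integer data with the same shape (1, 4, 2, 3), an assumption made only to keep the snippet self-contained.

```python
import torch

x = torch.randint(0, 10, (1, 4, 2, 3))

y0 = torch.cat([x] * 2)               # dim defaults to 0 -> (2, 4, 2, 3)
y1 = torch.cat([x] * 2, dim=1)        # -> (1, 8, 2, 3)
y_last = torch.cat([x] * 2, dim=-1)   # negative dims count from the end -> (1, 4, 2, 6)

# out is keyword-only; a buffer of the right shape and dtype is filled in place.
buf = torch.empty(1, 4, 4, 3, dtype=x.dtype)
torch.cat([x, x], dim=2, out=buf)
print(y0.shape, y1.shape, y_last.shape, buf.shape)
```

Only the size of the chosen dimension grows; every other dimension must already match across the inputs.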
