pytorch

torch.cat(tensors, dim=0, *, out=None) → Tensor
Concatenates the sequence of tensors given in tensors along the given dimension. Every tensor must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
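To make the "inverse operation" remark concrete, here is a minimal sketch. The tensor t and its shape are made up for illustration and are not part of the original note. It splits a tensor along dim 0 and then stitches the pieces back together with torch.cat.

```python
import torch

# Hypothetical example tensor; the shape is chosen only for illustration.
t = torch.arange(12).reshape(6, 2)

# torch.split cuts t into pieces of 2 rows each along dim 0.
parts = torch.split(t, 2, dim=0)
restored = torch.cat(parts, dim=0)

# torch.chunk asks for a number of pieces instead of a piece size.
chunks = torch.chunk(t, 3, dim=0)
restored_from_chunks = torch.cat(chunks, dim=0)

print(torch.equal(t, restored))              # True
print(torch.equal(t, restored_from_chunks))  # True
```

The round trip only works if the same dim is used for the split and for the cat.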

tensors (sequence of Tensors) – any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

cat concatenates a sequence of tensors. Note that the input argument is a sequence of tensors, which can be a list, a tuple, and so on.
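The shape rule and the list/tuple point can be checked with a short sketch. The tensors a, b and empty are made-up examples, not from the original note.

```python
import torch

# Hypothetical tensors: they differ only in dim 0, so they can be joined along dim 0.
a = torch.zeros(2, 3)
b = torch.ones(4, 3)

from_list = torch.cat([a, b], dim=0)   # a list works
from_tuple = torch.cat((a, b), dim=0)  # so does a tuple
print(from_list.shape)                 # torch.Size([6, 3])

# Along dim 1 the sizes in dim 0 (2 vs 4) do not match, so cat raises an error.
try:
    torch.cat([a, b], dim=1)
    mismatch_failed = False
except RuntimeError:
    mismatch_failed = True
print(mismatch_failed)                 # True

# A 1-D empty tensor of size (0,) is the documented exception to the shape rule.
empty = torch.empty(0)
with_empty = torch.cat([empty, a], dim=0)
print(with_empty.shape)                # torch.Size([2, 3])
```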

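>>> import torch
>>> x = torch.randint(0, 10, (1, 4, 2, 3))  # assumed setup, not from the original post; values are random, so yours will differ
>>> x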
tensor([[[[6, 4, 4],
          [5, 7, 1]],

         [[9, 7, 2],
          [5, 9, 2]],

         [[5, 7, 0],
          [4, 9, 0]],

         [[7, 3, 6],
          [7, 6, 2]]]])

>>> x.shape
torch.Size([1, 4, 2, 3])
>>> y = torch.cat([x] * 2)
>>> y.shape
torch.Size([2, 4, 2, 3])
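The example above uses the default dim=0. As a hedged follow-up sketch, the same kind of tensor can be joined along other dimensions. The x here is a stand-in with the same shape as above and arbitrary values. torch.stack is shown only for contrast, since it adds a new dimension instead of growing an existing one.

```python
import torch

# Stand-in for x above: same shape, arbitrary values.
x = torch.randint(0, 10, (1, 4, 2, 3))

along_dim1 = torch.cat([x, x], dim=1)   # grows the second dimension
along_last = torch.cat([x, x], dim=-1)  # negative dims count from the end
stacked = torch.stack([x, x])           # new leading dimension instead

print(along_dim1.shape)  # torch.Size([1, 8, 2, 3])
print(along_last.shape)  # torch.Size([1, 4, 2, 6])
print(stacked.shape)     # torch.Size([2, 1, 4, 2, 3])
```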
posted @ 2025-04-11 10:45  风冷无霜