torch.cat使用
torch.cat的作用是对张量按行或列进行拼接。在使用的过程中我也碰到了一些需要注意的点。
import torch
x = torch.rand((4,5))
print(x)
tmp = torch.Tensor()
for i in range(4):
if i%2 == 0:
tmp = torch.cat((tmp, x[[i]]), dim=0)
print(tmp)
在上述代码中,如果想要把x的第0,2行和tmp按行拼接起来,则下标索引要用 [ ] 括起来。

否则如果是 x[i] 的写法,则会变成:


浙公网安备 33010602011771号