torch.cat使用
torch.cat的作用是对张量按行或列进行拼接。在使用的过程中我也碰到了一些需要注意的点。
import torch x = torch.rand((4,5)) print(x) tmp = torch.Tensor() for i in range(4): if i%2 == 0: tmp = torch.cat((tmp, x[[i]]), dim=0) print(tmp)
在上述代码中,如果想要把x的第0,2行和tmp按行拼接起来,则下标索引要用 [ ] 括起来。
否则如果是 x[i] 的写法,则会变成: