torch.cat使用

torch.cat的作用是对张量按行或列进行拼接。在使用的过程中我也碰到了一些需要注意的点。

 

import torch

x = torch.rand((4,5))
print(x)
tmp = torch.Tensor()

for i in range(4):
    if i%2 == 0:
        tmp = torch.cat((tmp, x[[i]]), dim=0)
print(tmp)

在上述代码中,如果想要把x的第0,2行和tmp按行拼接起来,则下标索引要用 [ ] 括起来。

 

 

否则如果是 x[i] 的写法,则会变成:

 

posted @ 2020-10-13 11:02  Kayden_Cheung  阅读(552)  评论(0编辑  收藏  举报
//目录