torch.stack()与torch.cat()

torch.stack():http://www.45fan.com/article.php?aid=1D8JGDik5G49DE1X

torch.stack()个人理解:属于先变形再cat的操作,所以在哪个维度上stack,要先把原数据变成相应维度上的值。
例如:x = [1, 2], y = [3, 4], torch.stack([x, y], dim=1)
x要先变成2*1的形状,即[[1], [2]], y也一样,[[3], [4]],然后再在第一个维度上叠加,变成[[1, 3], [2, 4]]

torch.cat()与torch.stack区别:

# 沿着dim连接seq中的tensor, 所有的tensor必须有相同的size或为empty, 其相反的操作为 torch.split() 和torch.chunk()
torch.cat(seq,dim=0,out=None) 
torch.stack(seq, dim=0, out=None) #同上
 
#注: .cat 和 .stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以理解为叠加
>>> a=torch.Tensor([1,2,3])
>>> torch.stack((a,a)).size()
torch.size(2,3)
>>> torch.cat((a,a)).size()
torch.size(6)
posted @ 2022-08-31 17:47  SXQ-BLOG  阅读(147)  评论(0)    收藏  举报