torch.gather

image

解释:

以下面代码为例:

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(dim=1, index)
print(tensor_1)

(1) output.shape = index.shape # 确定最后输出的output的shape必须与index的相同,这里是13的tensor,那么output必须也是13的tensor,先把壳打起来torch.tensor([[?,?,?]])

(2) 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1),(0,2)]]

(3) 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行,dim=1即替换列。变成[[(0,2),(0,1),(0,0)]]

(4) 按这个索引获取tensor_0相应位置的值,填进去就好了,得到torch.tensor([[5,4,3]])

posted @ 2024-04-09 10:53  SXQ-BLOG  阅读(27)  评论(0)    收藏  举报