torch.gather

解释:
以下面代码为例:
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(dim=1, index)
print(tensor_1)
(1) output.shape = index.shape # 确定最后输出的output的shape必须与index的相同,这里是13的tensor,那么output必须也是13的tensor,先把壳打起来torch.tensor([[?,?,?]])
(2) 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1),(0,2)]]
(3) 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行,dim=1即替换列。变成[[(0,2),(0,1),(0,0)]]
(4) 按这个索引获取tensor_0相应位置的值,填进去就好了,得到torch.tensor([[5,4,3]])
本文来自博客园,作者:SXQ-BLOG,转载请注明原文链接:https://www.cnblogs.com/sxq-blog/p/18123415

浙公网安备 33010602011771号