torch.gather()
《动手学Pytorch》:
代码:
y_hat =torch.tensor([[0.1,0.3,0.6],[0.3.0.2.0.5]])
y = torch.LongTensor([0,2])
y_hat.gather(1,y.view(-1,1))
输出:
tensor([[0.1000,
[0.5000]]])
理解:
gather(dim,index)
其中,dim指定索引维度,在上例中dim=1,即按行索引;
index指定索引位置,在上例中index = torch.LongTensor([0,2]).view(-1,1) ,即第0行取第0个元素,第一行取第2个元素

浙公网安备 33010602011771号