torch.gather()

《动手学Pytorch》:

  代码:

  y_hat =torch.tensor([[0.1,0.3,0.6],[0.3.0.2.0.5]])

       y = torch.LongTensor([0,2])

  y_hat.gather(1,y.view(-1,1))

  输出:

  tensor([[0.1000,

                    [0.5000]]])

 理解:

  gather(dim,index)

  其中,dim指定索引维度,在上例中dim=1,即按行索引;

                  index指定索引位置,在上例中index = torch.LongTensor([0,2]).view(-1,1) ,即第0行取第0个元素,第一行取第2个元素

posted @ 2020-08-15 15:46  此间一看客  阅读(720)  评论(0)    收藏  举报