pytorch 张量索引

索引虽好,但有时却不易理解。

在Numpy中,您可以使用数组索引到数组中。例如,为了在二维数组中选择(1, 2)和(3, 2)处的元素,您可以这样做:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)
[ 8 20]

  

在张量中可以同样这么使用(pytorch中操作和numpy相似):

>>> import numpy as np
>>> import torch
>>> data = torch.tensor(np.reshape(np.arange(30), [5, 6]))
>>> data
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29]])
>>> a = [1, 3]
>>> b = [2, 2]
>>> data[a, b]
tensor([ 8, 20])

  

二维的时候还比较容易理解,当扩展到三维时就复杂一些,举个例子:

x = torch.tensor(torch.arange(7*8*9)).resize(7,8,9)
a = range(x.size(0))
b = [5,3,6,1,1,0,2]
c = x[a,b,:]
print(c)
tensor([[ 45,  46,  47,  48,  49,  50,  51,  52,  53],
        [ 99, 100, 101, 102, 103, 104, 105, 106, 107],
        [198, 199, 200, 201, 202, 203, 204, 205, 206],
        [225, 226, 227, 228, 229, 230, 231, 232, 233],
        [297, 298, 299, 300, 301, 302, 303, 304, 305],
        [360, 361, 362, 363, 364, 365, 366, 367, 368],
        [450, 451, 452, 453, 454, 455, 456, 457, 458]])

c.shape
torch.Size([7, 9])

  

参考:https://blog.csdn.net/goodxin_ie/article/details/89672700  

posted on 2021-11-20 15:48  朴素贝叶斯  阅读(197)  评论(0编辑  收藏  举报

导航