pytorch 张量索引
索引虽好,但有时却不易理解。
在Numpy中,您可以使用数组索引到数组中。例如,为了在二维数组中选择(1, 2)和(3, 2)处的元素,您可以这样做:
# data is [[0, 1, 2, 3, 4, 5], # [6, 7, 8, 9, 10, 11], # [12 13 14 15 16 17], # [18 19 20 21 22 23], # [24, 25, 26, 27, 28, 29]] data = np.reshape(np.arange(30), [5, 6]) a = [1, 3] b = [2, 2] selected = data[a, b] print(selected)
[ 8 20]
在张量中可以同样这么使用(pytorch中操作和numpy相似):
>>> import numpy as np >>> import torch >>> data = torch.tensor(np.reshape(np.arange(30), [5, 6])) >>> data tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17], [18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29]]) >>> a = [1, 3] >>> b = [2, 2] >>> data[a, b] tensor([ 8, 20])
二维的时候还比较容易理解,当扩展到三维时就复杂一些,举个例子:
x = torch.tensor(torch.arange(7*8*9)).resize(7,8,9) a = range(x.size(0)) b = [5,3,6,1,1,0,2] c = x[a,b,:] print(c) tensor([[ 45, 46, 47, 48, 49, 50, 51, 52, 53], [ 99, 100, 101, 102, 103, 104, 105, 106, 107], [198, 199, 200, 201, 202, 203, 204, 205, 206], [225, 226, 227, 228, 229, 230, 231, 232, 233], [297, 298, 299, 300, 301, 302, 303, 304, 305], [360, 361, 362, 363, 364, 365, 366, 367, 368], [450, 451, 452, 453, 454, 455, 456, 457, 458]]) c.shape torch.Size([7, 9])
参考:https://blog.csdn.net/goodxin_ie/article/details/89672700