numpy和pytorch中的索引机制
“索引”指使用方括号来引用ndarray中的值(数组或是一个数),本文详细地介绍在Numpy和PyTorch中关于索引的使用
1. 基础索引
基础索引是指通过整数、或者slice object(即形状如[start: stop: step]的对象)完成的索引。也是日常使用中最简单的索引形式。
根据原始的张量array和索引矩阵idx又可以分为如下三种情况:
-
array为一维向量,通过idx索引到单元素;
-
array为多维张量,通过idx索引到单元素;
-
array为多维张量,通过idx索引到多维张量。
例如:
import numpy as np
arr = np.arange(24).reshape(2, 3, 4)
# Out [1]
# array([[[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10, 11]],
#
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
print(arr[0, 2])
# Out [2] 索引到多维输出
# [ 8 9 10 11]
print(arr[0, 0, -1])
# Out [3] 索引到单个元素
# 3
从上述例子中看到,使用整数对多维数组进行索引比较直观,易于理解。
值得注意的是:array[0, 2]和array[0][2]尽管得到的结果一致,但是索引过程不同。a[0, 2]直接索引得到结果,而a[0][2]则首先生成一个临时数组,随后得到结果。使用函数id(),也可以看出二者在内存中的位置并不相同。
上面介绍了通过整数数字进行索引,还可以通过slice object(索引对象)对ndarray进行切片和跨步等操作。
x = np.arange(10)
print(x[2: 5])
# Out [1]
# array([2, 3, 4])
print(x[:-7])
# Out [2]
# array(0, 1, 2)
print(x[1: 7: 2])
# Out [3]
# array([1, 3, 5])
y = np.arange(35).reshape(5, 7)
print(y[1:5:2, ::3]
# Out [4]
# array([[7, 10, 13],
# [21, 24, 27])
最后需要注意,使用基础索引得到结果,不会对数据进行拷贝,而是会生成原始数据的视图(view,类似C++中的reference)。如果需要进行拷贝,建议显式使用copy()。
2. 高级索引
2.1. 索引数组
不同于前面讲到的使用整数和slice object,Numpy和PyTorch支持使用ndarray(以及任何可以转换为数据的对象,比如列表list)来索引ndarray。对于索引数组,我们从易到难,通过例子来进行说明。
例:
x = np.arange(10, 1, -1)
print(x[np.array([3, 3, 1, 8]))
# Out [1]
# array([7, 7, 9, 2])
用一维数组索引一维数组,索引数组就好像是index的作用一样。同时也允许负数。注意,返回结果的shape和索引数组一样(而不是和原数组)。
例:
print( x[np.array([[1, 1], [2, 3]])] )
# Out [1]
# array([[9, 9],
# [8, 7])
甚至我们也可以用一个2D的ndarray作为索引数组去索引一个1D的ndarray,当然返回的结果也是一个2D的ndarray。
2.2. 索引多维数组
本文的关键来了,在多维数组中(也叫做张量,真正的N-D Array),我们有不止一个dimension。在每个dimension中,我们都可以设置一个索引数组。
情况1:
设矩阵y是一个5x7的矩阵,我们先来看第一个例子:
y = np.arange(35).reshape(5, 7)
print(y[np.array(0, 2, 4), np.array(0, 1, 2)]
# Out [1]
# array( [ 0, 15, 30] )
情况1:1)被索引的是个多维数组;2)每个dimension上都有一个索引数组,如上一例中两个dimension;3)每个索引数组的shape匹配。如上一个例子,两个索引数组都是3维数组。
那么,会首先把索引数组array(0, 2, 4)和array(0, 1, 2)组成3个坐标:(0, 0), (2, 1), (4, 2)。然后用这些坐标,依次去ndarray中拿值。
情况2:两个索引数组的shape匹配不上,怎么办?
Numpy会首先使用广播机制,看一下添加了广播后,两个值能不能匹配上,如下面例子:
print(y[np.array([0, 2, 4]), 1])
# Output [1]
# array([ 1, 25, 29])
两个索引数组分别是array([0, 2, 4])和1。通过广播机制,广播成(0, 1),(2, 1),(4, 1),三个坐标,然后依次往里面拿取值。
情况3:被索引数组有3-D,可我们只有2个索引数组?
结果是,没有被索引的维度,保持原样,只处理被索引的数组,参考下面的例子:
z = np.arange(24).reshape(2, 2, 2, 3)
print(z)
# Out [1]
# array([[[[ 0, 1, 2],
# [ 3, 4, 5]],
#
# [[ 6, 7, 8],
# [ 9, 10, 11]]],
#
#
# [[[12, 13, 14],
# [15, 16, 17]],
#
# [[18, 19, 20],
# [21, 22, 23]]]])
m = z[np.array([1, 1, 0]), :, np.array([0, 1, 0]), :])
print(m.shape)
# Out [2]
# (3, 2, 3)
解析:上面这个例子中,原始张量为(2, 2, 2, 3)的形状,在第0维、第2维的时候两个索引数组组成了3个坐标,这个是第一个的3。剩下的2,3,原样保留。
情况4,索引数组也可以是多维。
例:
z = np.arange(12).reshape(3, 4)
print(z)
# Out [1]
#array([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
print(z[np.array([[2, 1], [1, 0]]), :])
# Out [2]
#[[[ 8 9 10 11]
# [ 4 5 6 7]]
#
# [[ 4 5 6 7]
# [ 0 1 2 3]]]
我们用一个2x2的indexing去索引一个3x4的矩阵,最终用2x2替代了3,得到了2x2x4的张量。
3. 布尔索引
布尔索引是第三种重要的索引方式,布尔索引数组的形状必须与要索引的数组尺寸相同,如:
y = np.arange(35).reshape(5, 7)
b = y>20
print(y[b])
# Out [1]
# array([21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])
与索引数组的情况不同,布尔索引数组得到的,是一个一维数组(想想看:这是为什么)。
这是由于,布尔数组的输出是一个不定长度的量,你不知道到底是1个元素满足条件,还是10个元素满足条件,因此,最好的方式是将他们展平为一维向量进行输出。
与索引数组情况类似,如果y的维数多于b,那么结果将是多维的:
print(b[:, 5])
# Out [1]
# [False False False True True]
print(y[b[:, 5]])
# Out [2]
# [[21 22 23 24 25 26 27]
# [28 29 30 31 32 33 34]]
这里y的shape本身是5x7,b[:, 5]本身是一个5维向量,那么就是挑出第4行和第5行,保留原张量里面的第二维dimension=7,最终得到一个2x7的矩阵。
需要强调一点的是,可以使用一个多维布尔索引数组去索引另一个多维数组。如使用一个2x3的二维布尔数组,从三维2x3x5的数组中去索引,最终能够得到一个Mx5的结果。
总结一下:
- 整数索引和slice object索引最简单,见名知意;
- 索引数组索引,牢记替换原则,第x维,被替换为索引数组的shape,牢记坐标匹配原则,多个索引数组先匹配成为坐标,再取值。
- 布尔数组索引,牢记shape匹配原则,只有相同shape的才能索引,牢记一维输出原则,最后输出的一维向量。
参考资料
[1] Numpy中文网 - 索引,https://numpy.org.cn/user/basics/indexing.html

浙公网安备 33010602011771号