numpy和pytorch中的索引机制

“索引”指使用方括号来引用ndarray中的值(数组或是一个数),本文详细地介绍在Numpy和PyTorch中关于索引的使用

1. 基础索引

基础索引是指通过整数、或者slice object(即形状如[start: stop: step]的对象)完成的索引。也是日常使用中最简单的索引形式。

根据原始的张量array和索引矩阵idx又可以分为如下三种情况:

  • array为一维向量,通过idx索引到单元素;

  • array为多维张量,通过idx索引到单元素;

  • array为多维张量,通过idx索引到多维张量。

例如:

import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
# Out [1]
# array([[[0, 1, 2, 3],
#         [4, 5, 6, 7],
#         [8, 9, 10, 11]],
# 
#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

print(arr[0, 2])
# Out [2] 索引到多维输出
# [ 8  9 10 11]

print(arr[0, 0, -1])
# Out [3] 索引到单个元素
# 3

从上述例子中看到,使用整数对多维数组进行索引比较直观,易于理解。

值得注意的是:array[0, 2]array[0][2]尽管得到的结果一致,但是索引过程不同。a[0, 2]直接索引得到结果,而a[0][2]则首先生成一个临时数组,随后得到结果。使用函数id(),也可以看出二者在内存中的位置并不相同。

上面介绍了通过整数数字进行索引,还可以通过slice object(索引对象)对ndarray进行切片和跨步等操作。

x = np.arange(10)
print(x[2: 5])
# Out [1]
# array([2, 3, 4])

print(x[:-7])
# Out [2]
# array(0, 1, 2)

print(x[1: 7: 2])
# Out [3]
# array([1, 3, 5])

y = np.arange(35).reshape(5, 7)
print(y[1:5:2, ::3]
# Out [4]
# array([[7, 10, 13],
#        [21, 24, 27])

最后需要注意,使用基础索引得到结果,不会对数据进行拷贝,而是会生成原始数据的视图(view,类似C++中的reference)。如果需要进行拷贝,建议显式使用copy()

2. 高级索引

2.1. 索引数组

不同于前面讲到的使用整数和slice object,Numpy和PyTorch支持使用ndarray(以及任何可以转换为数据的对象,比如列表list)来索引ndarray。对于索引数组,我们从易到难,通过例子来进行说明。

例:

x = np.arange(10, 1, -1)
print(x[np.array([3, 3, 1, 8]))
# Out [1]
# array([7, 7, 9, 2])

用一维数组索引一维数组,索引数组就好像是index的作用一样。同时也允许负数。注意,返回结果的shape和索引数组一样(而不是和原数组)。

例:

print( x[np.array([[1, 1], [2, 3]])] )
# Out [1]
# array([[9, 9],
#        [8, 7])

甚至我们也可以用一个2D的ndarray作为索引数组去索引一个1D的ndarray,当然返回的结果也是一个2D的ndarray。

 

2.2. 索引多维数组

本文的关键来了,在多维数组中(也叫做张量,真正的N-D Array),我们有不止一个dimension。在每个dimension中,我们都可以设置一个索引数组。

情况1:

设矩阵y是一个5x7的矩阵,我们先来看第一个例子:

y = np.arange(35).reshape(5, 7)
print(y[np.array(0, 2, 4), np.array(0, 1, 2)]
# Out [1]
# array( [ 0, 15, 30] )

情况1:1)被索引的是个多维数组;2)每个dimension上都有一个索引数组,如上一例中两个dimension;3)每个索引数组的shape匹配。如上一个例子,两个索引数组都是3维数组。

那么,会首先把索引数组array(0, 2, 4)array(0, 1, 2)组成3个坐标(0, 0), (2, 1), (4, 2)。然后用这些坐标,依次去ndarray中拿值。


情况2:两个索引数组的shape匹配不上,怎么办?

Numpy会首先使用广播机制,看一下添加了广播后,两个值能不能匹配上,如下面例子:

print(y[np.array([0, 2, 4]), 1])
# Output [1]
# array([ 1, 25, 29])

两个索引数组分别是array([0, 2, 4])1。通过广播机制,广播成(0, 1)(2, 1)(4, 1),三个坐标,然后依次往里面拿取值。


情况3:被索引数组有3-D,可我们只有2个索引数组?

结果是,没有被索引的维度,保持原样,只处理被索引的数组,参考下面的例子:

z = np.arange(24).reshape(2, 2, 2, 3)
print(z)
# Out [1]
# array([[[[ 0,  1,  2],
#          [ 3,  4,  5]],
#
#         [[ 6,  7,  8],
#          [ 9, 10, 11]]],
#
#
#        [[[12, 13, 14],
#          [15, 16, 17]],
#
#         [[18, 19, 20],
#          [21, 22, 23]]]])
m = z[np.array([1, 1, 0]), :, np.array([0, 1, 0]), :])
print(m.shape)
# Out [2]
# (3, 2, 3)

解析:上面这个例子中,原始张量为(2, 2, 2, 3)的形状,在第0维、第2维的时候两个索引数组组成了3个坐标,这个是第一个的3。剩下的2,3,原样保留。


情况4,索引数组也可以是多维。

例:

z = np.arange(12).reshape(3, 4)
print(z)
# Out [1]
#array([[ 0,  1,  2,  3],
#       [ 4,  5,  6,  7],
#       [ 8,  9, 10, 11]])
print(z[np.array([[2, 1], [1, 0]]), :])
# Out [2]
#[[[ 8  9 10 11]
#  [ 4  5  6  7]]
#
# [[ 4  5  6  7]
#  [ 0  1  2  3]]]

我们用一个2x2的indexing去索引一个3x4的矩阵,最终用2x2替代了3,得到了2x2x4的张量。

 

3. 布尔索引

布尔索引是第三种重要的索引方式,布尔索引数组的形状必须与要索引的数组尺寸相同,如:

y = np.arange(35).reshape(5, 7)
b = y>20
print(y[b])
# Out [1]
# array([21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])

与索引数组的情况不同,布尔索引数组得到的,是一个一维数组(想想看:这是为什么)。

这是由于,布尔数组的输出是一个不定长度的量,你不知道到底是1个元素满足条件,还是10个元素满足条件,因此,最好的方式是将他们展平为一维向量进行输出。

与索引数组情况类似,如果y的维数多于b,那么结果将是多维的:

print(b[:, 5])
# Out [1]
# [False False False  True  True]
print(y[b[:, 5]])
# Out [2]
# [[21 22 23 24 25 26 27]
#  [28 29 30 31 32 33 34]]

这里y的shape本身是5x7b[:, 5]本身是一个5维向量,那么就是挑出第4行和第5行,保留原张量里面的第二维dimension=7,最终得到一个2x7的矩阵。

需要强调一点的是,可以使用一个多维布尔索引数组去索引另一个多维数组。如使用一个2x3的二维布尔数组,从三维2x3x5的数组中去索引,最终能够得到一个Mx5的结果。

总结一下

  1. 整数索引和slice object索引最简单,见名知意;
  2. 索引数组索引,牢记替换原则,第x维,被替换为索引数组的shape,牢记坐标匹配原则,多个索引数组先匹配成为坐标,再取值。
  3. 布尔数组索引,牢记shape匹配原则,只有相同shape的才能索引,牢记一维输出原则,最后输出的一维向量。

 

参考资料

[1] Numpy中文网 - 索引,https://numpy.org.cn/user/basics/indexing.html

posted @ 2022-04-16 15:32  tracer9  阅读(567)  评论(0)    收藏  举报