## Still not entirely clear (my next post explains how to understand gather)

gather is a fairly complex operation. For a 2-D tensor, each element of the output is defined as follows:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]]  # dim=1
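
The two formulas can be spelled out as plain Python loops. The sketch below is a reference implementation on nested lists, written only to illustrate the indexing rule. `gather_2d` is a hypothetical helper name, not part of PyTorch, and it does not reproduce PyTorch's shape checks.

```python
def gather_2d(inp, dim, index):
    """Reference gather for 2-D nested lists (illustrative helper, not a PyTorch API)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                out_row.append(inp[k][j])   # out[i][j] = input[index[i][j]][j]
            else:
                out_row.append(inp[i][k])   # out[i][j] = input[i][index[i][j]]
        out.append(out_row)
    return out


# A 3x4 input holding 0..11, laid out row by row
b = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]]

print(gather_2d(b, 0, [[0, 1, 2, 0]]))      # [[0, 5, 10, 3]]
print(gather_2d(b, 1, [[0], [1], [2]]))     # [[0], [5], [10]]
```

In both calls the result has the same shape as the index: with dim=0 the index value selects the row, and with dim=1 it selects the column.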


## gather on 2-D tensors

### Note the value of index here: a row of shape (1, 4)

import torch as t

index = t.LongTensor([[0,1,2,3]])
print("index = \n", index)      # index is 2-D
print("index shape: ",index.shape)  # index has shape (1, 4)


index =
tensor([[0, 1, 2, 3]])
index shape:  torch.Size([1, 4])


### Note the value of index here: a column of shape (4, 1)

index = t.LongTensor([[0,1,2,3]]).t()  # transposed; index is still 2-D
print("index = \n", index)
print("index shape: ",index.shape)    # index has shape (4, 1)


index =
tensor([[0],
        [1],
        [2],
        [3]])
index shape:  torch.Size([4, 1])


### A few more examples

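The output of b.gather() always has the same shape as index. In the examples below, the index used with dim=0 is a row of shape (1, n), so the result is a row; the index used with dim=1 is a column of shape (n, 1), so the result is a column. The errors shown come from the PyTorch build used for this post, which required index to have the same size as b in every dimension except dim. Newer PyTorch releases document a looser rule (index.size(d) <= input.size(d) for every d != dim), so some of the failing calls may succeed there.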

• b is a $3\times4$ tensor
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]])

>>> index
tensor([[0, 1, 2]])

>>> b.gather(0,index)     # fails: index has 3 columns, b has 4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

>>> index2 = t.LongTensor([[0,1,2]]).t()

>>> b.gather(1,index2)  # succeeds: index2 has 3 rows, like b
tensor([[ 0],
        [ 5],
        [10]])

>>> index3 = t.LongTensor([[0,1,2,3]]).t()

>>> b.gather(1,index3)  # fails: index3 has 4 rows, b has 3
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

• b is a $6\times6$ tensor
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])

>>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index)     # fails: index has 7 columns, b has 6
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

>>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index)    # succeeds: picks the main diagonal as a row
tensor([[ 0,  7, 14, 21, 28, 35]])
>>> b.gather(1,index)    # fails: index has 1 row, b has 6
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

>>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()
>>> b.gather(1,index2)     # succeeds: picks the main diagonal as a column
tensor([[ 0],
        [ 7],
        [14],
        [21],
        [28],
        [35]])

>>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)     # fails: index3 has 5 rows, b has 6
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

>>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)     # fails: index4 has 5 columns, b has 6
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
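
The successes and failures above all follow from one shape rule, which can be checked without PyTorch. The sketch below assumes the strict rule quoted in the error message ("same size apart from dimension dim"); `shapes_match_for_gather` is a hypothetical helper, not a PyTorch function. Newer PyTorch releases document the looser `index.size(d) <= input.size(d)` requirement, so this check is stricter than what they enforce.

```python
def shapes_match_for_gather(input_shape, index_shape, dim):
    """Return True if index_shape equals input_shape on every dimension except `dim`.

    Hypothetical helper mirroring the wording of the RuntimeError in this post;
    it only compares shapes and never looks at the index values.
    """
    if len(input_shape) != len(index_shape):
        return False
    return all(
        size == idx_size
        for d, (size, idx_size) in enumerate(zip(input_shape, index_shape))
        if d != dim
    )


# (input shape, index shape, dim) for the calls tried in this section
cases = [
    ((3, 4), (1, 3), 0),   # failed
    ((3, 4), (3, 1), 1),   # succeeded
    ((3, 4), (4, 1), 1),   # failed
    ((6, 6), (1, 7), 0),   # failed
    ((6, 6), (1, 6), 0),   # succeeded
    ((6, 6), (1, 6), 1),   # failed
    ((6, 6), (6, 1), 1),   # succeeded
    ((6, 6), (5, 1), 1),   # failed
    ((6, 6), (1, 5), 0),   # failed
]

results = [shapes_match_for_gather(*case) for case in cases]
for case, ok in zip(cases, results):
    print(case, "ok" if ok else "shape mismatch")
```

Only the three calls marked as succeeded pass the check, which matches the transcripts above.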



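### The inverse of gather is scatter_

gather takes data out of input at the positions given by index, while scatter_ puts the gathered data back at those positions. Note that scatter_ is an in-place operation.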

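out = input.gather(dim, index)
# --> approximate inverse (positions not listed in index are not restored)
restored = t.zeros_like(input)
restored.scatter_(dim, index, out)   # scatter_ also needs the source values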


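# Put the elements of the two diagonals back (a and index are reconstructed from the output below, an assumption)
a = t.arange(0,16).view(4,4)
index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()
b = a.gather(1, index)
c = t.zeros(4,4)
c.scatter_(1, index, b.float())


tensor([[ 0.,  0.,  0.,  3.],
        [ 0.,  5.,  6.,  0.],
        [ 0.,  9., 10.,  0.],
        [12.,  0.,  0., 15.]])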
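
The round trip can also be traced in plain Python. This is a minimal sketch for dim=1 only; `gather_dim1` and `scatter_dim1` are hypothetical helpers, and the 4x4 input and the index are assumptions chosen to match the output shown above.

```python
def gather_dim1(inp, index):
    """out[i][j] = inp[i][index[i][j]] (illustrative helper, not a PyTorch API)."""
    return [[inp[i][k] for k in row] for i, row in enumerate(index)]


def scatter_dim1(target, index, src):
    """target[i][index[i][j]] = src[i][j]; modifies target in place, like scatter_."""
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            target[i][k] = src[i][j]
    return target


a = [[4 * r + c for c in range(4)] for r in range(4)]   # 0..15 as a 4x4 grid
index = [[0, 3], [1, 2], [2, 1], [3, 0]]                # main and anti-diagonal columns

picked = gather_dim1(a, index)
restored = scatter_dim1([[0] * 4 for _ in range(4)], index, picked)

print(picked)    # [[0, 3], [5, 6], [10, 9], [15, 12]]
for row in restored:
    print(row)
```

The restored grid holds the diagonal values at their original positions and zeros everywhere else, which is why scatter_ is only an approximate inverse of gather.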


posted on 2018-08-09 19:36 by 星辰之衍