pytorch 的 scatter 函数中参数 index 是如何指定散布位置
基本逻辑
dim=0表示scatter操作将在行方向上进行。dim指定index参数的值。即 index[i,j]=v 的 v 的值。index张量的形状应该与src张量的形状匹配,或者能够广播到相同的形状。- 对于每一个
index的位置(i, j),src[i, j]中的值会被放置到res[index[i, j], j]中。
举例:若 index[2,3]=1, 代表分 3 步的任务。
首先,根据index的索引[2,3]取得src同索引的值,即 src[2,3],假设src[2,3]=10
其次,根据 dim 指定的值,作为索引的位置,用index的值取替换,如,当dim=0,则替换索引(2,3)的2,这样索引就变为(1,3)
最后,合并上两部的到的结果,把第2步得到的索引(1,3)用于res的索引,把第1步得到的src的值用于res的值,最终得到res[1,3]=10
一句话:若 dim=0, index[i,j]=v 中,取值 scr[i,j]=r,赋值 res[v,j]=r
若 dim = 1, index[i,j]=v ,则有:src[i,j]=r, res[i,v]=r
套用pytorch的文档:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
把self替换为res,就都一样了。
感悟:理解一个函数的功能,一定要找出这个功能背后的计算公式,才算是真明白,否则那就是一知半解,换一个数就不会了。
具体的例子
在 PyTorch 的 scatter 函数中,当你指定 dim=0 时,index 张量确实是指定了 src 中每个元素应该放置在 res 张量的哪一行。不过,列号的选择是由 index 张量的形状和 src 张量的形状共同决定的。
1. 基本逻辑
dim=0表示scatter操作将在行方向上进行。index张量的形状应该与src张量的形状匹配,或者能够广播到相同的形状。- 对于每一个
index的位置(i, j),src[i, j]中的值会被放置到res[index[i, j], j]中。
2. 列号是如何确定的
列号实际上是通过遍历 src 张量的第二个维度(即列的维度)隐含确定的。
假设 src 的形状为 (m, n),index 的形状与 src 相同,那么:
src中的每一列j(从0到n-1)会被映射到res张量的相同列j中。- 行号是由
index张量提供的,对应于src张量的每个元素在res张量中的行位置。
3. 结合例子解释
src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index = torch.tensor([[1, 1, 1]])
res = torch.zeros(2, 3).scatter_(0, index, src)
1. 初始情况
src张量的形状是(2, 3),意味着它有 2 行 3 列。index张量的形状是(1, 3),这意味着index为每一列指定了src中对应元素应该放置到res张量的哪个行中。
2. 逐列分析
对于第 j 列:
index[0, j]决定了src[0, j]和src[1, j]要放在res的哪个行。- 因为
dim=0,scatter_只在行维度(即第 0 维)上进行操作,列维度保持不变。
列 0 (j=0):
index[0, 0] = 1src[0, 0] = 1.将被放置在res[1, 0]src[1, 0] = 4.也将被放置在res[1, 0],最终覆盖1.
列 1 (j=1):
index[0, 1] = 1src[0, 1] = 2.将被放置在res[1, 1]src[1, 1] = 5.也将被放置在res[1, 1],最终覆盖2.
列 2 (j=2):
index[0, 2] = 1src[0, 2] = 3.将被放置在res[1, 2]src[1, 2] = 6.也将被放置在res[1, 2],最终覆盖3.
4. 总结
- 行号:由
index指定,表示src中的每个元素应该放置在res的哪一行。 - 列号:隐式决定,
src张量的每一列直接对应res张量的相同列。即src[i, j]的值会被放到res[index[i, j], j]。
所以在你的例子中,index 指定了所有元素应当放置在 res 张量的第 1 行,而列号则是隐式地从 src 中继承下来,与 src 的列号相同。

浙公网安备 33010602011771号