gather算子大不同

技术背景

在MindSpore和PyTorch框架中都有关于gather算子的实现。其实gather算子就是根据张量的索引,在指定维度下提取元素。但是gather算子在两个框架中的实现又有所不同,本文用几个示例来展开介绍一下。

MindSpore示例

在MindSpore中的gather实现,可以支持多维度的index:

In [1]: import mindspore as ms

In [2]: arr = ms.numpy.ones((1,27), dtype=ms.float32)

In [3]: idx = ms.numpy.zeros((27,26), dtype=ms.int32)

In [4]: res = ms.ops.gather(arr, idx, -1)

In [5]: res.shape
Out[5]: (1, 27, 26)

这里index的维度是(27,26),而我们去索引的只有-1这个维度。其实在MindSpore中应该是通过内部实现,把index展平之后进行索引在做一个reshape,而这些内容都不在用户层去操作。

PyTorch示例

在PyTorch中,index维度必须跟输入的维度数量一致,否则就会发生RuntimeError,例如:

In [1]: import torch as tc

In [2]: arr = tc.ones((1,27),dtype=tc.float32)

In [3]: idx = tc.zeros((27,26),dtype=tc.int64)

In [4]: res = tc.gather(arr, -1, idx)
--------------------------------------------------------------------------
RuntimeError                             Traceback (most recent call last)
Cell In[4], line 1
----> 1 res = tc.gather(arr, -1, idx)

RuntimeError: Size does not match at dimension 0 expected index [27, 26] to be smaller than self [1, 27] apart from dimension 1

In [5]: res = tc.gather(arr,-1,idx.reshape((1,-1))).reshape(idx.shape)

In [6]: res.shape
Out[6]: torch.Size([27, 26])

在这个示例中我们可以看到,在PyTorch里面不能支持跟输入的维度数量不同的索引张量,要进行手动的展开,最后再手动的reshape回去。

总结概要

本文通过2个实际的案例,演示了一下gather算子在MindSpore框架下PyTorch框架下的异同点。两者的输入都是tensor-axis-index,一个是输入顺序上略有区别,另一个是对于输入的张量索引维度的要求。在PyTorch中,如果我们要实现类似于MindSpore中的gather功能,需要手动对输入索引的维度操作一下。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/gather-ops.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

posted @ 2025-06-18 15:06  DECHIN  阅读(123)  评论(0)    收藏  举报