PyTorch 函数

scatter

scatter 定义:

tgt.scatter(dim, idx, src)  # 将 tgt 中的部分数据点用 src 替换

dim 参数指定 idx 应用的位置:

tgt[idx[i, j], j] = src[i, j]  # dim = 0
tgt[i, idx[i, j]] = src[i, j]  # dim = 1

让我们看一些实例。首先定义数据:

src = tensor([[5,  6,  7,  8],
              [9, 10, 11, 12]])

tgt = tensor([[0,  0,  0,  0],
              [0,  0,  0,  0]])

接下来使用不同的 idx 进行 scatter 操作:

>>> idx = tensor([[0]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 0, 0],
        [0, 0, 0, 0]]
>>> idx = tensor([[0, 1]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 0, 0],
        [0, 6, 0, 0]])
>>> idx = tensor([[0, 1, 0, 0]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 7, 8],
        [0, 6, 0, 0]])

看出规律了吗?实际上 idx 在这里起到了第 0 维索引的功能。如果我们使用 dim=1 再次实验:

>>> idx = tensor([[0]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 0, 0, 0],
        [0, 0, 0, 0]]
>>> idx = tensor([[0, 1]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 6, 0, 0],
        [0, 0, 0, 0]]
>>> idx = tensor([[0, 1, 3]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 6, 0, 7],
        [0, 0, 0, 0]])

可以看到这次 idx 起到了第 1 维索引的功能。

一般设置 dim=-1,这样可以实现将 tgt 中指定位置的点替换为 src 对应的点。

参考:torch.Tensor.scatter_ | PyTorch documentation

repeat

将 tensor 复制多次。

>>> data = th.tensor([4, 5])
>>> data.repeat(2, 3)
tensor([[4, 5, 4, 5, 4, 5],
        [4, 5, 4, 5, 4, 5]])

stack

将两个 tensor 摞在一起。

>>> a = th.tensor([[ 1,  2,  3],
                   [ 4,  5,  6]])  # (2, 3)

>>> b = th.tensor([[ 7,  8,  9],
                   [10, 11, 12]])  # (2, 3)

>>> th.stack([a, b], dim=0)        # (2, 2, 3)
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])

cat

将两个 tensor 拼接在一起。

>>> a = th.tensor([[ 1,  2,  3],
                   [ 4,  5,  6]])  # (2, 3)

>>> b = th.tensor([[ 7,  8,  9],
                   [10, 11, 12]])  # (2, 3)

>>> th.cat([a, b], dim=0)          # (4, 3)
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

unbind

沿指定维度将张量分解为子张量序列。

>>> x = th.tensor([[1, 2, 3],
                   [4, 5, 6]])
>>> x.unbind(0)
(tensor([1, 2, 3]), tensor([4, 5, 6]))  # 沿第一维拆解成两个 tensor

unbind 是 stack 的逆操作。

linspace

生成等差数列。

>>> x = th.linspace(0, 100, 11)

image

posted @ 2025-05-17 18:42  Undefined443  阅读(15)  评论(0)    收藏  举报