PyTorch 函数
scatter
scatter 定义:
tgt.scatter(dim, idx, src) # 将 tgt 中的部分数据点用 src 替换
dim 参数指定 idx 应用的位置:
tgt[idx[i, j], j] = src[i, j] # dim = 0
tgt[i, idx[i, j]] = src[i, j] # dim = 1
让我们看一些实例。首先定义数据:
src = tensor([[5, 6, 7, 8],
[9, 10, 11, 12]])
tgt = tensor([[0, 0, 0, 0],
[0, 0, 0, 0]])
接下来使用不同的 idx 进行 scatter 操作:
>>> idx = tensor([[0]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 0, 0],
[0, 0, 0, 0]]
>>> idx = tensor([[0, 1]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 0, 0],
[0, 6, 0, 0]])
>>> idx = tensor([[0, 1, 0, 0]])
>>> tgt.scatter(0, idx, src)
tensor([[5, 0, 7, 8],
[0, 6, 0, 0]])
看出规律了吗?实际上 idx 在这里起到了第 0 维索引的功能。如果我们使用 dim=1 再次实验:
>>> idx = tensor([[0]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 0, 0, 0],
[0, 0, 0, 0]]
>>> idx = tensor([[0, 1]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 6, 0, 0],
[0, 0, 0, 0]]
>>> idx = tensor([[0, 1, 3]])
>>> tgt.scatter(1, idx, src)
tensor([[5, 6, 0, 7],
[0, 0, 0, 0]])
可以看到这次 idx 起到了第 1 维索引的功能。
一般设置 dim=-1,这样可以实现将 tgt 中指定位置的点替换为 src 对应的点。
参考:torch.Tensor.scatter_ | PyTorch documentation
repeat
将 tensor 复制多次。
>>> data = th.tensor([4, 5])
>>> data.repeat(2, 3)
tensor([[4, 5, 4, 5, 4, 5],
[4, 5, 4, 5, 4, 5]])
stack
将两个 tensor 摞在一起。
>>> a = th.tensor([[ 1, 2, 3],
[ 4, 5, 6]]) # (2, 3)
>>> b = th.tensor([[ 7, 8, 9],
[10, 11, 12]]) # (2, 3)
>>> th.stack([a, b], dim=0) # (2, 2, 3)
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
cat
将两个 tensor 拼接在一起。
>>> a = th.tensor([[ 1, 2, 3],
[ 4, 5, 6]]) # (2, 3)
>>> b = th.tensor([[ 7, 8, 9],
[10, 11, 12]]) # (2, 3)
>>> th.cat([a, b], dim=0) # (4, 3)
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
unbind
沿指定维度将张量分解为子张量序列。
>>> x = th.tensor([[1, 2, 3],
[4, 5, 6]])
>>> x.unbind(0)
(tensor([1, 2, 3]), tensor([4, 5, 6])) # 沿第一维拆解成两个 tensor
unbind 是 stack 的逆操作。
linspace
生成等差数列。
>>> x = th.linspace(0, 100, 11)

浙公网安备 33010602011771号