torch.where(condition, x, y)
torch.where 是一个非常有用的函数,它用于根据给定的条件对输入的两个张量进行选择操作。以下是 torch.where 的基本语法:
torch.where(condition, x, y)
condition:布尔条件张量,通常是一个和 x、y 相同形状的张量。
x:当 condition 中的元素为 True 时,从 x 中选择相应的值。
y:当 condition 中的元素为 False 时,从 y 中选择相应的值。
- 基本用法
假设我们有两个张量 a 和 b,我们想根据 a 中的值是否大于 0 来选择来自 a 或 b 的元素。
import torch
a = torch.tensor([1, -2, 3, -4])
b = torch.tensor([10, 20, 30, 40])
result = torch.where(a > 0, a, b)
print(result)
输出:
tensor([ 1, 20, 3, 40])
解释:
对于 a > 0 的位置(即 1 和 3),选择 a 中对应的值。
对于 a <= 0 的位置(即 -2 和 -4),选择 b 中对应的值。
- 用于张量替换
假设我们有一个张量 data,其中的 0 值代表缺失数据,我们希望用 -1 来替换所有 0 值。
data = torch.tensor([0, 1, 0, 3, 4])
result = torch.where(data == 0, torch.tensor(-1), data)
print(result)
输出:
tensor([-1, 1, -1, 3, 4])
解释:
对于 data == 0 的位置,替换为 -1。
对于其他位置,保留 data 中的原值。
好记性不如烂键盘---点滴、积累、进步!

浙公网安备 33010602011771号