torch.where(condition, x, y)

torch.where 是一个非常有用的函数,它用于根据给定的条件对输入的两个张量进行选择操作。以下是 torch.where 的基本语法:

torch.where(condition, x, y)

condition:布尔条件张量,通常是一个和 x、y 相同形状的张量。

x:当 condition 中的元素为 True 时,从 x 中选择相应的值。

y:当 condition 中的元素为 False 时,从 y 中选择相应的值。

  1. 基本用法

假设我们有两个张量 a 和 b,我们想根据 a 中的值是否大于 0 来选择来自 a 或 b 的元素。

import torch

a = torch.tensor([1, -2, 3, -4])
b = torch.tensor([10, 20, 30, 40])

result = torch.where(a > 0, a, b)
print(result)

输出:

tensor([ 1, 20, 3, 40])

解释:

对于 a > 0 的位置(即 1 和 3),选择 a 中对应的值。

对于 a <= 0 的位置(即 -2 和 -4),选择 b 中对应的值。
  1. 用于张量替换

假设我们有一个张量 data,其中的 0 值代表缺失数据,我们希望用 -1 来替换所有 0 值。

data = torch.tensor([0, 1, 0, 3, 4])

result = torch.where(data == 0, torch.tensor(-1), data)
print(result)

输出:

tensor([-1, 1, -1, 3, 4])

解释:

对于 data == 0 的位置,替换为 -1。

对于其他位置,保留 data 中的原值。
posted @ 2025-04-16 20:12  无左无右  阅读(209)  评论(0)    收藏  举报