PyTorch `torch.no_grad` vs `torch.inference_mode`
在Pytorch 1.9版本中,更新了torch.inference_mode()方法(官方文档),该方法与torch.no_grad()类似,都不会记录梯度。
然而不同的是torch.inference_mode()相比torch.no_grad()有更好的性能,并且会强制关闭梯度记录。并且不能在中途设置梯度。
下例来自于官方论坛的提问
import torch
with torch.no_grad():
x = torch.randn(1)
y = x + 1
y.requires_grad = True
z = y + 1
print(z.grad_fn)
# >>> <AddBackward0 object at 0x10ca1e110>
with torch.inference_mode():
x = torch.randn(1)
y = x + 1
y.requires_grad = True
# >>> RuntimeError: Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.

浙公网安备 33010602011771号