PyTorch `torch.no_grad` vs `torch.inference_mode`

在Pytorch 1.9版本中,更新了torch.inference_mode()方法(官方文档),该方法与torch.no_grad()类似,都不会记录梯度。

然而不同的是torch.inference_mode()相比torch.no_grad()有更好的性能,并且会强制关闭梯度记录。并且不能在中途设置梯度。

下例来自于官方论坛的提问

import torch

with torch.no_grad():
    x = torch.randn(1)
    y = x + 1

y.requires_grad = True
z = y + 1
print(z.grad_fn)

# >>> <AddBackward0 object at 0x10ca1e110>

with torch.inference_mode():
    x = torch.randn(1)
    y = x + 1

y.requires_grad = True

# >>> RuntimeError: Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.
posted @ 2024-04-17 17:06  絵守辛玥  阅读(2269)  评论(0)    收藏  举报