with torch.no_grad()
讲述with torch.no_grad()
前,先讲述requires_grad
参数。
1. requires_grad
在 \(torch\) 中,\(tensor\) 有一个 requires_grad
参数,如果设置为 \(True\),则反向传播时,该 \(tensor\) 就会自动求导。
\(tensor\) 的requires_grad
的属性默认为 \(False\),若一个节点(叶子变量:自己创建的 \(tensor\))requires_grad
被设置为\(True\),那么所有依赖它的节点requires_grad
都为 \(True\)(即使其他相依赖的 \(tensor\) 的 requires_grad = False
)。
import torch
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = False)
z = torch.randn(10, 5, requires_grad = False)
w = x + y + z
print(w.requires_grad)
True
2. with torch.no_grad()
即使一个 \(tensor\) (命名为 \(x\))的requires_grad = True
,由 \(x\) 得到的新 \(tensor\)(命名为 \(w\))requires_grad
也为 \(False\),且grad_fn
也为 \(None\),即不会对 \(w\) 求导。
import torch
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
w = x + y + z
print(w.requires_grad)
print(w.grad_fn)
print(w.requires_grad)
False
None
False