with torch.no_grad()

讲述with torch.no_grad()前,先讲述requires_grad参数。

1. requires_grad

\(torch\) 中,\(tensor\) 有一个 requires_grad参数,如果设置为 \(True\),则反向传播时,该 \(tensor\) 就会自动求导。

\(tensor\)requires_grad的属性默认为 \(False\),若一个节点(叶子变量:自己创建的 \(tensor\)requires_grad被设置为\(True\),那么所有依赖它的节点requires_grad都为 \(True\)(即使其他相依赖的 \(tensor\)requires_grad = False)。

import torch

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = False)
z = torch.randn(10, 5, requires_grad = False)
w = x + y + z
print(w.requires_grad)
True

2. with torch.no_grad()

即使一个 \(tensor\) (命名为 \(x\))的requires_grad = True,由 \(x\) 得到的新 \(tensor\)(命名为 \(w\)requires_grad也为 \(False\),且grad_fn也为 \(None\),即不会对 \(w\) 求导。

import torch

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
    w = x + y + z
    print(w.requires_grad)
    print(w.grad_fn)
print(w.requires_grad)
False
None
False


posted @ 2023-03-06 23:46  做梦当财神  阅读(318)  评论(0)    收藏  举报