博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

model.eval()和with torch.no_grad()的区别

Posted on 2021-03-03 20:56  秘密小鱼  阅读(1015)  评论(0)    收藏  举报

这二者的功能是不同的:

  • model.eval():
    告诉网络的所有层,你在eval模式,也就是说,像batchNorm和dropout这样的层会工作在eval模式而非training模式(如dropout层在eval模式会被关闭)。

  • with torch.no_grad()
    当我们计算梯度时,我们需要缓存input values,中间features(可以理解为中间神经元的输出)的值,因为他们可能会在后面计算梯度时需要。
    譬如:b = w1 * a关于变量w1和a的梯度分别是a和w1。因此在前馈的过程中我们应当缓存这些值以在之后的backward自动求梯度过程中使用。
    然而,当我们在inference(或测试)时,我们无需计算梯度,那么前馈的时候我们不必保存这些值。事实上,在inference时,我们根本无需构造计算图,否则会导致不必要的存储开销。
    Pytorch提供了一个上下文管理器,也就是torch.no_grad来满足这种目的,被with torch.no_gard()管理的环境中进行的计算,不会生成计算图,不会存储为计算梯度而缓存的数值。

总结: 如果你的网络中包含batchNorm或者dropout这样在training,eval时表现不同的层,应当使用model.eval()。在inference时用with torch.no_grad()会节省存储空间。

另外需要注意的是,即便不使用with torch.no_grad(),在测试只要你不调用loss.backward()就不会计算梯度,with torch.no_grad()的作用只是节省存储空间。

明白的铁铁把明白打在公屏上

参考文章:https://blog.paperspace.com/pytorch-101-understanding-graphs-and-automatic-differentiation/