在 PyTorch 中,model.eval() 和 model.train() 是用于切换模型训练和评估模式的两个重要方法,它们主要影响某些特定层(如 Dropout、BatchNorm 等)的行为。以下是它们的具体使用场景和区别:
- 作用:将模型设置为训练模式。在训练模式下,Dropout、BatchNorm 等层会正常工作(例如随机丢弃神经元、计算批次统计量)。
- 使用时机:在训练阶段调用,通常在每个训练 epoch 开始前使用。
- 示例:
- 作用:将模型设置为评估模式。在评估模式下,Dropout 会停止丢弃神经元,BatchNorm 会使用训练阶段统计的全局均值和方差,确保推理结果稳定。
- 使用时机:在验证、测试或模型推理阶段调用,通常搭配
torch.no_grad() 使用以节省计算资源。
- 示例:
| 模式 | Dropout/BatchNorm 行为 | 梯度计算 |
model.train() |
启用(如随机丢弃神经元、更新批次统计量) |
启用(默认) |
model.eval() |
禁用(保持所有神经元,使用全局统计量) |
通常禁用(配合torch.no_grad()) |
- 仅调用
model.eval() 不会自动关闭梯度计算:需要搭配 with torch.no_grad(): 来禁用梯度计算。
- 仅对需要训练 / 评估的模型调用:如果你的代码中有多个模型,确保只对当前使用的模型切换模式。
- RNN 类模型可能不受影响:Dropout 和 BatchNorm 在 RNN 中应用较少,因此这些模型可能对模式切换不敏感。
- 训练时:用
model.train() 开启训练模式。
- 验证 / 测试 / 推理时:用
model.eval() 和 torch.no_grad() 组合开启评估模式。
正确切换模式是保证模型性能和结果一致性的关键!