6.6.2 模型训练
f isinstance(net, nn.Module):
net.eval() # 设置为评估模式
if not device:
device = next(iter(net.parameters())).device
对于这段代码,作用如下
-
net.eval()设置为评估模式之后,某些层(如 Dropout 和 BatchNorm)的行为会与训练模式不同。例如:- Dropout 层在评估模式下不会随机丢弃神经元。
- BatchNorm 层在评估模式下会使用训练时计算的均值和方差,而不是当前批次的统计量。
net.train()的行为就完全与上面的相反 -
device = next(iter(net.parameters())).device是在获取当前网络所在的设备(CPU或者GPU)。从后面的代码来看,我们最开始是在GPU上面建立的网络,所以这里device会是GPU。注意,net.parameters()返回的是一个容器(不是一个列表,所以net.parameters()[0]是错误的),iter(net.parameters())是将这个容器转变为可迭代对象,next(iter(net.parameters()))是在获取第一个参数

浙公网安备 33010602011771号