6.6.2 模型训练

f isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device

对于这段代码,作用如下

  • net.eval()设置为评估模式之后,某些层(如 Dropout 和 BatchNorm)的行为会与训练模式不同。例如:

    • Dropout 层在评估模式下不会随机丢弃神经元。
    • BatchNorm 层在评估模式下会使用训练时计算的均值和方差,而不是当前批次的统计量。

    net.train()的行为就完全与上面的相反

  • device = next(iter(net.parameters())).device是在获取当前网络所在的设备(CPU或者GPU)。从后面的代码来看,我们最开始是在GPU上面建立的网络,所以这里device会是GPU。注意,net.parameters()返回的是一个容器(不是一个列表,所以net.parameters()[0]是错误的),iter(net.parameters())是将这个容器转变为可迭代对象,next(iter(net.parameters()))是在获取第一个参数

posted @ 2025-02-20 12:26  最爱丁珰  阅读(24)  评论(0)    收藏  举报