model
-
所有的网络都要继承nn.Module,这个类中封装好了很多函数,例如train(),to, cuda, state_dict, load_state_dict等;
-
train
model.train()
并不是真的执行训练, 而是PyTorch 中 nn.Module 的一个方法,用于将模型设置为训练模式。调用这个方法后,模型中的所有模块(如 Dropout 和 BatchNorm)都会被设置为训练模式。这对于某些特定层是必要的,因为它们在训练和评估时的行为不同
所有的网络都要继承nn.Module,这个类中封装好了很多函数,例如train(),to, cuda, state_dict, load_state_dict等;
train
model.train()
并不是真的执行训练, 而是PyTorch 中 nn.Module 的一个方法,用于将模型设置为训练模式。调用这个方法后,模型中的所有模块(如 Dropout 和 BatchNorm)都会被设置为训练模式。这对于某些特定层是必要的,因为它们在训练和评估时的行为不同