PyTorch 模型训练的流程
PyTorch 模型训练的流程
1. 模型构建
模型构建主要包括以下部分:
-
创建结构化数据
-
定义网络模型结构:主要通过继承
torch.nn.Module()类实现 -
定义损失函数
-
定义优化器:
torch.optim模块 -
训练过程
2. 创建结构化数据
主要通过 torch.utils.data.TensorDataset() 和 torch.utils.data.Dataset() 类 和 torch.utils.data.DataLoader() 类实现
参见博客:PyTorch torch.utils.data 模块 结构化数据
3. 定义模型 torch.nn.Module() 类
通过继承 torch.nn.Module() 类实现自定义模型,也可以自定义模块、层、激活函数、损失函数等
参见博客:PyTorch torch.nn.Module 类 构建模型
4. 定义损失函数
参见博客:PyTorch 损失函数
5. 优化器 torch.optim 模块
一般通过 torch.optim 模块下的优化器来实现学习参数的训练
参见博客:Pytorch torch.optim 模块 优化器
-
设置学习参数调整策略:
torch.optim.lr_scheduler模块 -
只对模型的部分参数进行训练,实现冻结层和微调(Fine-Tuning)的效果
6. 训练过程
-
模型评估时取消梯度
torch.no_grad() -
训练模式
model.train()和评估模式model.eval()-
model.train()方法启动训练模型(training),Dropout 和 Batch Normalization 层生效 -
model.eval()方法启动评估模型(evaluation),Dropout 和 Batch Normalization 层不会生效
-

浙公网安备 33010602011771号