深度学习(模型保存)
这里存四种格式:
1. 只保存模型参数的pth文件。
2. 能在python环境下读取的的模型结构和参数pt文件。
3. 能在c++环境下读取的模型结构和参数pt文件。
4. 能在pytorch环境外被其他框架读取的模型结构和参数onnx文件。
import torch import torch.nn as nn class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.linear1 = nn.Linear(10, 5) self.linear2 = nn.Linear(5, 10) def forward(self, x): x = self.linear1(x) x = self.linear2(x) return x model = SimpleModel() model.eval() # 1. 保存模型的 state_dict(仅保存模型参数) torch.save(model.state_dict(), 'model_state_dict.pth') # 2. 保存整个模型(包括模型结构和参数),依赖于 Python 环境加载 torch.save(model, 'model.pt') # 3. 使用 TorchScript 保存模型(可以在 C++ 中加载),如libtorch,netron中能显示结构 script_model = torch.jit.trace(model, torch.randn(1, 10)) script_model.save('script_model.pt') # 4. 使用 ONNX 导出模型(可以在其他深度学习框架中使用) torch.onnx.export(model, torch.randn(1, 10), 'model.onnx', opset_version=11, input_names=['input'], output_names=['output'])