深度学习(模型保存)

这里存四种格式:

1. 只保存模型参数的pth文件。

2. 能在python环境下读取的的模型结构和参数pt文件。

3. 能在c++环境下读取的模型结构和参数pt文件。

4. 能在pytorch环境外被其他框架读取的模型结构和参数onnx文件。

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = SimpleModel()
model.eval()

# 1. 保存模型的 state_dict(仅保存模型参数)
torch.save(model.state_dict(), 'model_state_dict.pth')

# 2. 保存整个模型(包括模型结构和参数),依赖于 Python 环境加载
torch.save(model, 'model.pt')

# 3. 使用 TorchScript 保存模型(可以在 C++ 中加载),如libtorch,netron中能显示结构
script_model = torch.jit.trace(model, torch.randn(1, 10))
script_model.save('script_model.pt')

# 4. 使用 ONNX 导出模型(可以在其他深度学习框架中使用)
torch.onnx.export(model, torch.randn(1, 10), 'model.onnx', opset_version=11, input_names=['input'], output_names=['output'])
posted @ 2025-05-02 20:02  Dsp Tian  阅读(77)  评论(0)    收藏  举报