PyTorch：将保存的模型（.pth）转换为 TorchScript（.pt），并导出为 ONNX 格式

pytorch(.pth)模型转化为 torchscript(.pt), 导出为onnx格式

1. 将 .pth 模型转换为 TorchScript（.pt）模型

import torch
import torchvision
from models import fcn
 
model = torchvision.models.vgg16()
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only one.
state_dict = torch.load("./checkpoint-epoch100.pth", map_location="cpu")
# Training checkpoints are often wrapped as {'state_dict': ..., ...};
# unwrap that case, otherwise treat the loaded object as the state dict itself.
if isinstance(state_dict, dict) and "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
# strict=False skips keys that do not match VGG16 instead of raising, but the
# corresponding parameters keep their random initialisation -- report them so
# a checkpoint/architecture mismatch is not silent.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"missing keys: {incompatible.missing_keys}")
    print(f"unexpected keys: {incompatible.unexpected_keys}")
model.eval()

# Example input used only for tracing: torch.jit.trace records the operations
# executed for this input and saves them as a TorchScript module.
x = torch.rand(1, 3, 128, 128)
ts = torch.jit.trace(model, x)
ts.save('fcn_vgg16.net')

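注意：很多人在转换时会报错（Missing key(s) / Unexpected key(s) in state_dict），原因是 checkpoint 中的参数名与模型结构不一致。给 model.load_state_dict(state_dict, False)（即 strict=False）传入 False 只是跳过这些不匹配的键、不再报错，并不能解决不匹配：未匹配的参数仍保持随机初始化。应检查该函数返回的 missing_keys 和 unexpected_keys，确认权重确实加载成功。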

2. .pth模型转化为.onnx模型

如需使用 OpenCV（cv2.dnn 模块）加载模型，则需先将 .pth 模型转换为 .onnx 格式。
a. 先安装 onnx，命令：pip install onnx
b. 运行以下脚本导出 .onnx 模型（注意：必须先把 .pth 中的权重加载进模型，否则导出的是随机初始化的权重）：

import io
import torch
import torch.onnx
import torchvision
from models import fcn
 
# Pick the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
def test():
    """Export VGG16 with weights from ``./checkpoint-epoch100.pth`` to ``fcn.onnx``.

    The checkpoint is loaded on the CPU, its parameters are copied into a
    freshly built ``torchvision.models.vgg16()``, and the model is exported
    to ONNX by tracing it with a single dummy NCHW input. The output file is
    written to the current working directory.
    """
    model = torchvision.models.vgg16()

    pthfile = r'./checkpoint-epoch100.pth'
    loaded_model = torch.load(pthfile, map_location='cpu')
    # Training checkpoints are often wrapped as {'state_dict': ..., ...};
    # unwrap that case, otherwise treat the loaded object as the state dict.
    if isinstance(loaded_model, dict) and 'state_dict' in loaded_model:
        loaded_model = loaded_model['state_dict']

    # Copy the trained weights into the model. Without this step the exported
    # graph would contain VGG16's random initial weights.
    # strict=False skips keys that do not match VGG16 instead of raising, so
    # report any mismatch: those parameters stay randomly initialised.
    incompatible = model.load_state_dict(loaded_model, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"missing keys: {incompatible.missing_keys}")
        print(f"unexpected keys: {incompatible.unexpected_keys}")
    # Switch to eval mode so training-only behaviour (e.g. dropout) is not exported.
    model.eval()

    # Dummy input in NCHW layout; unless dynamic_axes is passed to export,
    # the exported model's input shape is fixed to this size.
    # NOTE(review): 244 may be a typo for VGG16's usual 224 -- confirm.
    dummy_input1 = torch.randn(1, 3, 244, 244)
    input_names = ["actual_input_1"]
    output_names = ["output1"]
    # Models with several inputs pass a tuple instead, e.g.
    # torch.onnx.export(model, (x1, x2, x3), "C3AE.onnx", ...)
    torch.onnx.export(
        model,
        dummy_input1,
        "fcn.onnx",
        verbose=True,
        input_names=input_names,
        output_names=output_names,
    )


if __name__ == "__main__":
    test()

====================================================

import torch
import torch.nn as nn
from torch.autograd import Function
import onnx
import torch.onnx


class TinyNet(nn.Module):
    """Minimal module whose forward pass returns the element-wise absolute value."""

    def __init__(self):
        super().__init__()
        # Keep a reference to the functional op; forward() calls it via the instance.
        self.abs = torch.abs

    def forward(self, x):
        """Return ``abs(x)`` element-wise."""
        return self.abs(x)

model = TinyNet()
# Example input used only to trace the graph; named so it does not shadow the
# builtin input(). torch.tensor with an explicit dtype replaces the legacy
# torch.FloatTensor constructor and yields the same float32 tensor.
dummy_input = torch.tensor([[-1, -2, 3], [-4, -5, 6]], dtype=torch.float32)
input_names = ["input_0"]
output_names = ["output0"]
torch.onnx.export(
    model,
    (dummy_input,),
    'tinynet.onnx',
    opset_version=19,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
# Load the exported file once, validate it structurally, then print it.
onnx_model = onnx.load('tinynet.onnx')
onnx.checker.check_model(onnx_model)
print(onnx_model)

https://harmonyhu.com/2021/06/17/pytorch/

posted @ 2024-12-30 18:31  michaelchengjl  阅读(1020)  评论(0)    收藏  举报