C++测试Pytorch训练的模型

背景

大多数的网络模型都是基于Python设计的,因为Python有集成好的框架和包随意使用。但是由于软件的底层设计大多数都是基于C++的,所以这里记录一下将一个简单的基于Python的分割模型转化C++程序里。

1. 先将.pth模型转化成.onnx模型

因为C++里面没办法直接读取.pth模型,所以要先转化成.onnx模型。

废话不多说,直接上代码。

import torch
from Unet import Unet


def pth2onnx(input, pth_path, onnx_path):
    model = Unet()  # 导入自己的网络模型
    model.load_state_dict(torch.load(pth_path))  # 初始化权重
    model.eval()

    torch.onnx.export(model, input, onnx_path, verbose=True)


if __name__ == '__main__':
    pth_path = r'./best_model.pth'  # 训练的pth路径
    onnx_path = r'./best_model.onnx'  # 保存onnx的路径
    model_input = torch.randn(1, 1, 512, 512)  # 模型输入[B,C,H,W]
    pth2onnx(input=model_input, pth_path=pth_path, onnx_path=onnx_path)

(可选)2. 测试.onnx模型转换是否正确

如果第3步模型测试不正确,可以通过这一步检查模型转换是否正确。

import cv2
import onnxruntime
import numpy as np

onnx_path = './best_model.onnx'  # 上一步生成的onnx模型
image_path = './data/test/1.bmp'  # 测试图像

image = cv2.imread(image_path)  # 读取图像
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)  # resize成相应尺寸
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # 灰度处理,使用的是单通道图像预测

# 处理成模型需要的格式,[B,C,H,W]
input = image.reshape(1, 1, image.shape[0], image.shape[1]).astype(np.float32)

session = onnxruntime.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input})
print(outputs[0].shape)  # (1, 1, 512, 512)

pred = np.array(outputs[0])[0][0]
pred[pred > 0] = 255
pred[pred <= 0] = 0
cv2.imwrite("pred.bmp", pred)

3. C++测试

这里使用的是Qt测试的。

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
using namespace std;

int main()
{
    int h = 512; int w = 512;
    String modelFile = "F:/project/Unet_model/best_model.onnx";
    String imageFile = "F:/project/Unet_model/data/test/1.bmp";

    Mat img = imread(imageFile); // 读取测试图片
    cvtColor(img, img, cv::COLOR_BGR2GRAY);  // 灰度化
    resize(img, img, Size(h, w));

    Mat inputBolb = blobFromImage(img);  // 转换输入图像的格式[B,C,H,W]

    dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数
    net.setInput(inputBolb);
    Mat output = net.forward();  // 输出4D mat

    int B = inputBolb.size[0];
    int C = inputBolb.size[1];
    int H = inputBolb.size[2];
    int W = inputBolb.size[3];

    Mat predMat = Mat::zeros(h, w, CV_32F);

    for(int i = 0; i < B; i++){
        for(int j = 0; j < C; j++){
            for(int m = 0; m < H; m++){
                for(int n =0; n < W; n++){

                    float pred = output.ptr<float>(i,j,m)[n];

                    if(pred > 0){
                        predMat.at<float>(m,n) = 255;
                    }
                    else{
                        predMat.at<float>(m,n) = 0;
                    }
                }
            }
        }
    }
    cv::imwrite("F:/QtProject/pred.bmp", predMat);
}
posted @ 2022-03-23 12:00  小吕同学吖  阅读(400)  评论(0)    收藏  举报