TensorRT 推理 ONNX(C++ 实现)

v1(第一版):构建/加载 TRT 引擎,用全 1 测试数据做一次推理

#include "FormTensorRT.h"
#include <QToolButton>
#include <QMessageBox>
#include <QDebug>

#include <iostream>
#include <fstream>
#include <vector>
#include <string>

#pragma execution_character_set("utf-8")

// TensorRT 命名空间
using namespace nvinfer1;
using namespace nvonnxparser;

// Sets up the Designer-generated UI and wires the predict button to the
// inference demo slot. ("peridict" is the button's name in the .ui file.)
FormTensorRT::FormTensorRT(QWidget* parent)
    : QWidget(parent)
{
    ui.setupUi(this);
    connect(ui.toolButton_peridict, &QToolButton::clicked, this, &FormTensorRT::Test6);
}

// Serialize a built TensorRT engine blob to disk so later runs can skip the
// slow ONNX build step.
// @param filename   destination path (overwritten if it exists)
// @param engineData serialized engine from buildSerializedNetwork(); may be null
// @return true only if the full blob was written successfully
bool FormTensorRT::saveEngine(const char* filename, nvinfer1::IHostMemory* engineData)
{
    // Guard against a null/empty blob instead of dereferencing it.
    if (!engineData || engineData->size() == 0) return false;

    std::ofstream file(filename, std::ios::binary);
    if (!file) return false;

    file.write(static_cast<const char*>(engineData->data()),
               static_cast<std::streamsize>(engineData->size()));
    // Previously returned true unconditionally; report actual write status so a
    // failed/partial write doesn't leave a corrupt cache that loads next run.
    return file.good();
}

// Read a previously serialized TensorRT engine from disk into `data`.
// @param filename path of the cached engine file
// @param data     receives the raw engine bytes (resized to the file size)
// @return true only if a non-empty file was read completely; false lets the
//         caller fall back to rebuilding the engine from ONNX
bool FormTensorRT::loadEngine(const char* filename, std::vector<char>& data)
{
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file) return false;

    // tellg() returns -1 on failure; an empty file is also useless as an engine.
    std::streampos end = file.tellg();
    if (end <= 0) return false;

    data.resize(static_cast<size_t>(end));
    file.seekg(0, std::ios::beg);
    file.read(data.data(), static_cast<std::streamsize>(data.size()));
    // Previously returned true unconditionally; a short read must not be
    // reported as success or deserialization gets garbage.
    return file.good();
}




// Empty: no resources owned beyond what Qt's parent-child system cleans up.
FormTensorRT::~FormTensorRT()
{}

// Demo entry point: loads a cached TensorRT engine (or builds one from the
// ONNX model on first run), runs a single inference with an all-ones dummy
// input, and shows the first 10 output values in a message box.
//
// NOTE(review): on the early-return error paths below, `runtime` (and in the
// build branch `builder`/`network`/`parser`/`config`) are not released —
// tolerable for a demo, but leaks if the button is clicked repeatedly.
void FormTensorRT::Test6()
{
    //ShowMsg("开始推理(首次会编译引擎,稍慢)");  // (disabled) "starting inference (first run builds the engine, slower)"

    // 1. Path configuration. str_temp must outlive ONNX_PATH: the const char*
    // points into the std::string's internal buffer.
    QString onnx_path2 = QApplication::applicationDirPath() + "/config/best_lrc.onnx";
    //const char* ONNX_PATH = "best_lrc.onnx";
    std::string str_temp = onnx_path2.toStdString();
    const char* ONNX_PATH = str_temp.c_str();
    const char* ENGINE_PATH = "model.trt";

    // ===================== 1. Create the runtime =====================
    IRuntime* runtime = createInferRuntime(mLogger);
    if (!runtime) {
        ShowMsg("创建 runtime 失败");
        return;
    }

    ICudaEngine* engine = nullptr;
    IExecutionContext* context = nullptr;

    // ===================== 2. Load or build the engine =====================
    // Prefer the serialized engine cached on disk; otherwise parse the ONNX
    // model, build a new engine, and cache it for subsequent runs.
    std::vector<char> engineData;
    if (loadEngine(ENGINE_PATH, engineData)) {
        qDebug() << "加载本地TRT引擎...";
        engine = runtime->deserializeCudaEngine(engineData.data(), engineData.size());
    }
    else {
        qDebug() << "从ONNX构建TRT引擎...";

        IBuilder* builder = createInferBuilder(mLogger);
        if (!builder) { ShowMsg("builder 创建失败"); return; }

        // Explicit-batch network definition, required for ONNX models.
        uint32_t flags = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        INetworkDefinition* network = builder->createNetworkV2(flags);
        if (!network) { ShowMsg("network 创建失败"); return; }

        IParser* parser = createParser(*network, mLogger);
        if (!parser) { ShowMsg("parser 创建失败"); return; }

        // Parse the ONNX file; messages at kINFO and above go through mLogger.
        if (!parser->parseFromFile(ONNX_PATH, (int)ILogger::Severity::kINFO)) {
            ShowMsg("ONNX 解析失败");
            return;
        }

        IBuilderConfig* config = builder->createBuilderConfig();
        if (!config) { ShowMsg("config 创建失败"); return; }

        // Cap the builder's workspace memory pool at 1 GiB.
        config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1U << 30);

        IHostMemory* serializedEngine = builder->buildSerializedNetwork(*network, *config);
        if (!serializedEngine) { ShowMsg("引擎构建失败"); return; }

        // Cache the serialized engine, then deserialize for immediate use.
        saveEngine(ENGINE_PATH, serializedEngine);
        engine = runtime->deserializeCudaEngine(serializedEngine->data(), serializedEngine->size());

        // ===================== Modern TensorRT releases objects with plain delete =====================
        // (the destroy() methods were removed; see the TensorRT migration notes)
        delete serializedEngine;
        delete config;
        delete parser;      // fixed earlier: parser must be released too
        delete network;
        delete builder;
    }

    if (!engine) { ShowMsg("引擎创建失败"); return; }

    context = engine->createExecutionContext();
    if (!context) { ShowMsg("上下文创建失败"); return; }

    // ===================== 3. Query model I/O info (name-based tensor API) =====================
    // assumes I/O tensor 0 is the input and tensor 1 the output — TODO confirm for this model
    const char* inputName = engine->getIOTensorName(0);
    const char* outputName = engine->getIOTensorName(1);
    Dims inputDims = engine->getTensorShape(inputName);
    Dims outputDims = engine->getTensorShape(outputName);

    qDebug() << "\n===== 模型信息 =====";
    qDebug() << "输入名:" << inputName;
    qDebug() << "输出名:" << outputName;

    // Total element count of a tensor shape.
    // NOTE(review): a dynamic dimension (-1) would make this negative — the
    // code assumes fully static shapes.
    auto getSize = [](Dims dims) {
        int size = 1;
        for (int i = 0; i < dims.nbDims; i++) size *= dims.d[i];
        return size;
    };

    int inputSize = getSize(inputDims);
    int outputSize = getSize(outputDims);

    // ===================== 4. Allocate host and device buffers =====================
    float* hostInput = new float[inputSize]();
    float* hostOutput = new float[outputSize]();
    void* deviceInput = nullptr;
    void* deviceOutput = nullptr;

    // NOTE(review): cudaMalloc/cudaMemcpy return codes are ignored throughout.
    cudaMalloc(&deviceInput, inputSize * sizeof(float));
    cudaMalloc(&deviceOutput, outputSize * sizeof(float));

    // Dummy input: all ones (smoke test only, not real image data).
    for (int i = 0; i < inputSize; i++)
        hostInput[i] = 1.0f;

    // ===================== 5. Copy input to the GPU =====================
    cudaMemcpy(deviceInput, hostInput, inputSize * sizeof(float), cudaMemcpyHostToDevice);

    // ===================== 6. Run inference (name-based tensor address API) =====================
    context->setTensorAddress(inputName, deviceInput);
    context->setTensorAddress(outputName, deviceOutput);

    // NOTE(review): executeV2 takes a bindings array; passing nullptr relies on
    // the addresses set above — verify the installed TensorRT version accepts this.
    bool success = context->executeV2(nullptr);
    if (!success) {
        ShowMsg("推理执行失败");
    }

    // ===================== 7. Copy the result back to the CPU =====================
    cudaMemcpy(hostOutput, deviceOutput, outputSize * sizeof(float), cudaMemcpyDeviceToHost);

    // ===================== 8. Show the first few output values =====================
    QString res = "推理结果前10个:\n";
    for (int i = 0; i < std::min(10, outputSize); i++) {
        res += QString::number(hostOutput[i]) + " ";
    }
    ShowMsg(res);

    // ===================== 9. Release resources =====================
    cudaFree(deviceInput);
    cudaFree(deviceOutput);
    delete[] hostInput;
    delete[] hostOutput;

    // TensorRT objects are released with delete (destroy() was removed).
    delete context;
    delete engine;
    delete runtime;

    ShowMsg("推理完成!");
}

// Shows `msg` in a modal information box titled "提示" ("Notice").
void FormTensorRT::ShowMsg(QString msg)
{
    QMessageBox::information(this, "提示", msg);
}

v2(第二版):在 v1 基础上增加图像读取、预处理、YOLO 后处理(NMS)与结果绘制:

#include "FormTensorRT.h"
#include <QToolButton>
#include <QMessageBox>
#include <QDebug>
#include <QImage>
#include <QPixmap>
#include <QPainter>
#include <QPen>
#include <QFile>

#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <algorithm>
#include <opencv2/opencv.hpp>

#pragma execution_character_set("utf-8")

// TensorRT 命名空间
using namespace nvinfer1;
using namespace nvonnxparser;
using namespace cv;

// YOLO postprocessing parameters (tune these for your model)
const float CONF_THRESHOLD = 0.25f;    // minimum confidence to keep a candidate box
const float NMS_THRESHOLD = 0.45f;     // IoU threshold for non-maximum suppression
const int INPUT_W = 640;                // network input width in pixels
const int INPUT_H = 640;                // network input height in pixels

// One detection: corner coordinates in original-image pixels, confidence
// score, and predicted class index.
struct DetResult {
    float x1, y1, x2, y2;
    float score;
    int class_id;
};

// Sets up the Designer-generated UI and wires the predict button to the
// detection demo slot. ("peridict" is the button's name in the .ui file.)
FormTensorRT::FormTensorRT(QWidget* parent)
    : QWidget(parent)
{
    ui.setupUi(this);
    connect(ui.toolButton_peridict, &QToolButton::clicked, this, &FormTensorRT::Test6);
}

// Serialize a built TensorRT engine blob to disk so later runs can skip the
// slow ONNX build step.
// @param filename   destination path (overwritten if it exists)
// @param engineData serialized engine from buildSerializedNetwork(); may be null
// @return true only if the full blob was written successfully
bool FormTensorRT::saveEngine(const char* filename, nvinfer1::IHostMemory* engineData)
{
    // Guard against a null/empty blob instead of dereferencing it.
    if (!engineData || engineData->size() == 0) return false;

    std::ofstream file(filename, std::ios::binary);
    if (!file) return false;

    file.write(static_cast<const char*>(engineData->data()),
               static_cast<std::streamsize>(engineData->size()));
    // Previously returned true unconditionally; report actual write status so a
    // failed/partial write doesn't leave a corrupt cache that loads next run.
    return file.good();
}

// Read a previously serialized TensorRT engine from disk into `data`.
// @param filename path of the cached engine file
// @param data     receives the raw engine bytes (resized to the file size)
// @return true only if a non-empty file was read completely; false lets the
//         caller fall back to rebuilding the engine from ONNX
bool FormTensorRT::loadEngine(const char* filename, std::vector<char>& data)
{
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file) return false;

    // tellg() returns -1 on failure; an empty file is also useless as an engine.
    std::streampos end = file.tellg();
    if (end <= 0) return false;

    data.resize(static_cast<size_t>(end));
    file.seekg(0, std::ios::beg);
    file.read(data.data(), static_cast<std::streamsize>(data.size()));
    // Previously returned true unconditionally; a short read must not be
    // reported as success or deserialization gets garbage.
    return file.good();
}

// Empty: no resources owned beyond what Qt's parent-child system cleans up.
FormTensorRT::~FormTensorRT()
{}

// Image preprocessing: BGR->RGB, direct resize to the network input size
// (NOTE(review): despite the original "letterbox" label, aspect ratio is NOT
// preserved — this is a plain stretch-resize), normalize to [0,1], and repack
// from interleaved HWC into planar CHW as the network expects.
// `input_data` must hold at least 3 * input_h * input_w floats.
void preprocess(Mat& img, float* input_data, int input_w, int input_h) {
    Mat rgb_img, resized_img;
    cvtColor(img, rgb_img, COLOR_BGR2RGB);                     // OpenCV loads BGR; model wants RGB
    resize(rgb_img, resized_img, Size(input_w, input_h));
    resized_img.convertTo(resized_img, CV_32F, 1.0 / 255.0);   // uint8 -> float32 in [0,1]

    // HWC -> CHW: channel-major copy into the flat input buffer.
    int idx = 0;
    for (int c = 0; c < 3; c++) {
        for (int h = 0; h < input_h; h++) {
            for (int w = 0; w < input_w; w++) {
                input_data[idx++] = resized_img.at<Vec3f>(h, w)[c];
            }
        }
    }
}

// YOLO 输出后处理:解码 + NMS
std::vector<DetResult> postprocess(float* output, int output_size, int img_w, int img_h) {
    std::vector<DetResult> results;
    std::vector<cv::Rect> bboxes;
    std::vector<float> scores;

    int num_boxes = output_size / 6;
    float scale_w = (float)img_w / INPUT_W;
    float scale_h = (float)img_h / INPUT_H;

    for (int i = 0; i < num_boxes; i++) {
        float* ptr = output + i * 6;
        float x = ptr[0];
        float y = ptr[1];
        float w = ptr[2];
        float h = ptr[3];
        float score = ptr[4];
        int cls = (int)ptr[5];

        if (score < CONF_THRESHOLD) continue;

        float x1 = (x - w / 2) * scale_w;
        float y1 = (y - h / 2) * scale_h;
        float x2 = (x + w / 2) * scale_w;
        float y2 = (y + h / 2) * scale_h;

        results.push_back({ x1, y1, x2, y2, score, cls });
        bboxes.emplace_back(x1, y1, x2 - x1, y2 - y1);
        scores.push_back(score);
    }

    std::vector<int> indices;
    cv::dnn::NMSBoxes(bboxes, scores, CONF_THRESHOLD, NMS_THRESHOLD, indices);

    std::vector<DetResult> final_res;
    for (int i : indices) {
        final_res.push_back(results[i]);
    }
    return final_res;
}

// Overlay one green rectangle plus a "score:x.xx" label per detection.
void drawResults(Mat& img, std::vector<DetResult>& results) {
    const Scalar green(0, 255, 0);
    for (const DetResult& det : results) {
        Point tl(det.x1, det.y1);
        Point br(det.x2, det.y2);
        rectangle(img, tl, br, green, 2);

        // Label sits just above the box's top-left corner.
        std::string label = QString("score:%1").arg(det.score, 0, 'f', 2).toStdString();
        putText(img, label, Point(det.x1, det.y1 - 5),
            FONT_HERSHEY_SIMPLEX, 0.5, green, 2);
    }
}

// Full detection demo: load an image, build or load the TensorRT engine from
// the ONNX model, preprocess -> infer -> postprocess (decode + NMS), draw the
// boxes, save result.jpg, and report the scores in a message box.
//
// Fixes vs. the previous version:
//  - the image-read error message named "aa.jpg" although lrc_img.jpg is read;
//  - early-return error paths no longer leak runtime/builder/network/parser/
//    config/engine objects;
//  - the score-report loop index is size_t (no signed/unsigned mismatch).
void FormTensorRT::Test6()
{
    // 1. Path configuration. str_temp must outlive ONNX_PATH: the const char*
    // points into the std::string's internal buffer.
    QString img_path = QApplication::applicationDirPath() + "/config/lrc_img.jpg";
    QString onnx_path2 = QApplication::applicationDirPath() + "/config/best_lrc.onnx";
    std::string str_temp = onnx_path2.toStdString();
    const char* ONNX_PATH = str_temp.c_str();
    const char* ENGINE_PATH = "model.trt";

    // Load the source image (BGR).
    Mat src_img = imread(img_path.toStdString());
    if (src_img.empty()) {
        ShowMsg("读取 lrc_img.jpg 失败!");
        return;
    }
    int img_w = src_img.cols;
    int img_h = src_img.rows;

    // ===================== 1. Create the runtime =====================
    IRuntime* runtime = createInferRuntime(mLogger);
    if (!runtime) {
        ShowMsg("创建 runtime 失败");
        return;
    }

    ICudaEngine* engine = nullptr;
    IExecutionContext* context = nullptr;

    // ===================== 2. Load or build the engine =====================
    // Prefer the serialized engine cached on disk; otherwise parse the ONNX
    // model, build a new engine, and cache it for subsequent runs.
    std::vector<char> engineData;
    if (loadEngine(ENGINE_PATH, engineData)) {
        qDebug() << "加载本地TRT引擎...";
        engine = runtime->deserializeCudaEngine(engineData.data(), engineData.size());
    }
    else {
        qDebug() << "从ONNX构建TRT引擎...";

        IBuilder* builder = createInferBuilder(mLogger);
        if (!builder) { ShowMsg("builder 创建失败"); delete runtime; return; }

        // Explicit-batch network definition, required for ONNX models.
        uint32_t flags = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        INetworkDefinition* network = builder->createNetworkV2(flags);
        if (!network) { ShowMsg("network 创建失败"); delete builder; delete runtime; return; }

        IParser* parser = createParser(*network, mLogger);
        if (!parser) { ShowMsg("parser 创建失败"); delete network; delete builder; delete runtime; return; }

        // Parse the ONNX file; messages at kINFO and above go through mLogger.
        if (!parser->parseFromFile(ONNX_PATH, (int)ILogger::Severity::kINFO)) {
            ShowMsg("ONNX 解析失败");
            delete parser; delete network; delete builder; delete runtime;
            return;
        }

        IBuilderConfig* config = builder->createBuilderConfig();
        if (!config) {
            ShowMsg("config 创建失败");
            delete parser; delete network; delete builder; delete runtime;
            return;
        }

        // Cap the builder's workspace memory pool at 1 GiB.
        config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1U << 30);

        IHostMemory* serializedEngine = builder->buildSerializedNetwork(*network, *config);
        if (!serializedEngine) {
            ShowMsg("引擎构建失败");
            delete config; delete parser; delete network; delete builder; delete runtime;
            return;
        }

        // Cache the serialized engine, then deserialize for immediate use.
        saveEngine(ENGINE_PATH, serializedEngine);
        engine = runtime->deserializeCudaEngine(serializedEngine->data(), serializedEngine->size());

        // Modern TensorRT releases objects with plain delete (destroy() removed).
        delete serializedEngine;
        delete config;
        delete parser;
        delete network;
        delete builder;
    }

    if (!engine) { ShowMsg("引擎创建失败"); delete runtime; return; }

    context = engine->createExecutionContext();
    if (!context) { ShowMsg("上下文创建失败"); delete engine; delete runtime; return; }

    // ===================== 3. Query model I/O info =====================
    // Assumes I/O tensor 0 is the input and tensor 1 the output — TODO confirm.
    const char* inputName = engine->getIOTensorName(0);
    const char* outputName = engine->getIOTensorName(1);
    Dims inputDims = engine->getTensorShape(inputName);
    Dims outputDims = engine->getTensorShape(outputName);

    qDebug() << "\n===== 模型信息 =====";
    qDebug() << "输入名:" << inputName;
    qDebug() << "输出名:" << outputName;

    // Element count of a (static) shape; dynamic dims (-1) are not supported here.
    auto getSize = [](Dims dims) {
        int size = 1;
        for (int i = 0; i < dims.nbDims; i++) size *= dims.d[i];
        return size;
    };

    int inputSize = getSize(inputDims);
    int outputSize = getSize(outputDims);

    // ===================== 4. Allocate host and device buffers =====================
    float* hostInput = new float[inputSize]();
    float* hostOutput = new float[outputSize]();
    void* deviceInput = nullptr;
    void* deviceOutput = nullptr;

    cudaMalloc(&deviceInput, inputSize * sizeof(float));
    cudaMalloc(&deviceOutput, outputSize * sizeof(float));

    // Real image preprocessing (replaces the earlier all-ones dummy input).
    preprocess(src_img, hostInput, INPUT_W, INPUT_H);

    // ===================== 5. Copy input to the GPU =====================
    cudaMemcpy(deviceInput, hostInput, inputSize * sizeof(float), cudaMemcpyHostToDevice);

    // ===================== 6. Run inference =====================
    context->setTensorAddress(inputName, deviceInput);
    context->setTensorAddress(outputName, deviceOutput);

    // NOTE(review): executeV2(nullptr) relies on the tensor addresses set above;
    // verify the installed TensorRT version accepts this calling convention.
    bool success = context->executeV2(nullptr);
    if (!success) {
        ShowMsg("推理执行失败");
    }

    // ===================== 7. Copy results back to the CPU =====================
    cudaMemcpy(hostOutput, deviceOutput, outputSize * sizeof(float), cudaMemcpyDeviceToHost);

    // ===================== 8. Postprocess + draw =====================
    auto det_results = postprocess(hostOutput, outputSize, img_w, img_h);
    drawResults(src_img, det_results);

    // Save the annotated image next to the executable.
    imwrite("result.jpg", src_img);

    // Report detection count and per-target scores.
    QString res = QString("检测到 %1 个目标\n结果已保存为 result.jpg\n分数列表:\n").arg(det_results.size());
    for (size_t i = 0; i < det_results.size(); i++) {
        res += QString("目标%1:%2\n").arg(i + 1).arg(det_results[i].score, 0, 'f', 2);
    }
    ShowMsg(res);

    // ===================== 9. Release resources =====================
    cudaFree(deviceInput);
    cudaFree(deviceOutput);
    delete[] hostInput;
    delete[] hostOutput;

    delete context;
    delete engine;
    delete runtime;
}

// Shows `msg` in a modal information box titled "提示" ("Notice").
void FormTensorRT::ShowMsg(QString msg)
{
    QMessageBox::information(this, "提示", msg);
}

 

posted @ 2026-04-28 16:25  txwtech  阅读(0)  评论(0)    收藏  举报