Java 集成 YOLO11 ONNX 模型推理

先使用 Python 脚本将 YOLO11 训练得到的权重文件(best.pt)导出为 ONNX 模型。注意:OpenCV 版本需 >= 4.8.0。

# Export the trained weights to different formats for deployment
from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/best.pt")

# 1. Export to ONNX (cross-platform deployment; this is the file the Java detector loads)
model.export(
    format="onnx",
    simplify=True,       # simplify the graph
    opset=12,            # opset 12 for wide runtime compatibility (newer opsets allow more optimizations)
    dynamic=False,       # fixed input shape (easier to optimize)
    half=False,          # keep FP32: the Java code feeds a float32 input tensor and only parses float[][][] output;
                         # an FP16 export (half=True, which takes effect on GPU export) would have float16 I/O and fail there
    batch=1,             # fixed batch size
    imgsz=640            # fixed input size; must match MODEL_WIDTH/MODEL_HEIGHT in the Java detector
)

# 2. Export to TensorRT (GPU acceleration)
#model.export(format="engine", half=True, workspace=4)  # FP16, 4 GB GPU workspace

# 3. Export to OpenVINO (Intel CPU optimization)
#model.export(format="openvino")

# 4. Export to CoreML (iOS/Mac)
#model.export(format="coreml")

 

在 Java 中编写工具类,调用 ONNX 模型进行推理:

import ai.onnxruntime.*;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.nio.FloatBuffer;
import java.util.*;

/**
 * YOLO ONNX object detector (final version, no {@code cvtColor}).
 *
 * <p>Loads a YOLO model exported to ONNX, runs it on an OpenCV {@link Mat}, keeps the boxes of
 * the configured class ({@code yolo.medicineClassId}), applies NMS, draws the best box on the
 * input image and returns the cropped region around it.
 *
 * <p>Assumes the export settings of the accompanying script: fixed 640x640 input, batch 1, FP32
 * tensors and a raw (non-NMS) output of shape {@code [1, 4 + numClasses, numPredictions]} where
 * rows 0-3 are cx, cy, w, h and the remaining rows are presumably per-class scores (Ultralytics
 * layout; a single-class model has exactly one score row).
 */
@Component
public class YoloOnnxDetector {
    private static final Logger logger = LoggerFactory.getLogger(YoloOnnxDetector.class);

    // Model input size; must equal the imgsz used when exporting the ONNX model.
    private static final int MODEL_WIDTH = 640;
    private static final int MODEL_HEIGHT = 640;

    // Rows 0..3 of the raw output hold cx, cy, w, h; class scores start at this row.
    private static final int FIRST_SCORE_ROW = 4;

    // How many boxes / predictions are printed by the debug diagnostics.
    private static final int DEBUG_TOP_N = 10;

    // Thickness (px) of the rectangle drawn on the input image.
    private static final int BORDER_THICKNESS = 3;

    /** Minimum class score for a prediction to be kept (exclusive). */
    @Value("${yolo.confidence_level}")
    private float confidenceThreshold;

    /** IoU above which a lower-scoring box is suppressed during NMS. */
    @Value("${yolo.iou_threshold}")
    private float iouThreshold;

    // ONNX Runtime handles, created in loadModel() and freed in release().
    private OrtEnvironment environment;
    private OrtSession session;

    /** Path handed to createSession(), i.e. the .onnx model file despite the property name. */
    @Value("${yolo.bestOnnxDir}")
    private String modelPath;

    /** Class index treated as "medicine"; 0 for a single-class model. */
    @Value("${yolo.medicineClassId:0}")
    private int medicineClassId;

    private boolean isLoaded = false;

    /** Loads the model after dependency injection; failures are logged, not rethrown. */
    @PostConstruct
    public void init() {
        try {
            loadModel();
            // loadModel() swallows its own exceptions, so check the flag instead of assuming success.
            if (isLoaded) {
                logger.info("YOLO ONNX模型初始化成功");
            } else {
                logger.error("YOLO ONNX模型初始化失败: 模型未加载");
            }
        } catch (Exception e) {
            logger.error("YOLO ONNX模型初始化失败: {}", e.getMessage(), e);
        }
    }

    /** Frees the ONNX Runtime resources when the bean is destroyed. */
    @PreDestroy
    public void destroy() {
        release();
    }

    /** Creates the ONNX Runtime session for {@link #modelPath} and sets {@link #isLoaded}. */
    private void loadModel() {
        try {
            logger.info("加载ONNX模型: {}", modelPath);
            environment = OrtEnvironment.getEnvironment();
            // SessionOptions wraps a native handle that is only needed while the session is created.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
                session = environment.createSession(modelPath, options);
            }
            isLoaded = true;

            printModelInfo();
            logger.info("模型加载成功");
        } catch (Exception e) {
            logger.error("模型加载失败", e);
            isLoaded = false;
        }
    }

    /** Logs the name and shape of every model input and output node. */
    private void printModelInfo() {
        try {
            logger.info("=== 模型信息 ===");

            for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
                TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
                logger.info("输入节点: {}, 形状: {}", entry.getKey(), Arrays.toString(tensorInfo.getShape()));
            }

            for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
                TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
                logger.info("输出节点: {}, 形状: {}", entry.getKey(), Arrays.toString(tensorInfo.getShape()));
            }
        } catch (Exception e) {
            logger.error("获取模型信息失败", e);
        }
    }

    /**
     * Detects the configured class in {@code inputImage}, draws the best box on it and crops it.
     *
     * <p>Side effects: the detection rectangle is drawn directly onto {@code inputImage} (white on
     * 1-channel images, green on 3-channel images). The returned Mat is a sub-matrix view that
     * shares pixel memory with {@code inputImage}, so it also contains the drawn border.
     *
     * @param inputImage 1-channel or 3-channel (assumed BGR) image; other channel counts fail
     *     preprocessing
     * @return the cropped region around the highest-scoring box, or {@code null} when the model is
     *     not loaded, the input is empty, nothing is detected, or an error occurs
     */
    public Mat detectAndCrop(Mat inputImage) {
        if (!isLoaded) {
            logger.error("模型未加载");
            return null;
        }

        if (inputImage == null || inputImage.empty()) {
            logger.error("输入图像为空");
            return null;
        }

        try {
            logger.info("=== 开始检测 ===");
            logger.info("输入图像尺寸: {}x{}, 通道数: {}",
                    inputImage.cols(), inputImage.rows(), inputImage.channels());

            // 1. Preprocess without cvtColor
            logger.info("开始预处理...");
            float[] preprocessedData = manualPreprocess(inputImage);
            if (preprocessedData == null) {
                logger.error("预处理失败");
                return null;
            }

            // 2-3. Run inference and copy the boxes into Java objects. The Result owns native
            // output tensors, so it is closed right after parsing instead of leaking per call.
            logger.info("开始推理...");
            List<Detection> allDetections;
            try (OrtSession.Result results = inference(preprocessedData)) {
                allDetections = parseDetections(results);
            }
            logger.info("解析完成,原始检测框总数: {}", allDetections.size());
            logFirstDetections(allDetections);

            // 4. Keep only the configured class
            List<Detection> medicineDetections = filterMedicineDetections(allDetections);
            logger.info("Medicine类别检测框数量: {}", medicineDetections.size());

            // 5. Nothing of that class -> no crop
            if (medicineDetections.isEmpty()) {
                logger.info("未检测到medicine目标");
                return null;
            }

            // 6. Non-maximum suppression
            logger.info("开始NMS处理...");
            List<Detection> nmsDetections = nonMaxSuppression(medicineDetections);
            logger.info("NMS后检测框数量: {}", nmsDetections.size());

            // 7. NMS output is ordered by descending confidence, so index 0 is the best box
            if (nmsDetections.isEmpty()) {
                logger.info("NMS后无有效检测框");
                return null;
            }

            Detection bestDetection = nmsDetections.get(0);
            logger.info("最佳检测框: classId={}, confidence={}, bbox=[{},{},{},{}]",
                    bestDetection.classId, bestDetection.confidence,
                    bestDetection.bbox.x, bestDetection.bbox.y,
                    bestDetection.bbox.width, bestDetection.bbox.height);

            // 8. Map the box from model-input coordinates back to the original image
            Rect originalBBox = scaleBoundingBox(bestDetection.bbox,
                    inputImage.cols(), inputImage.rows());
            logger.info("原始图像尺寸转换完成: [{},{},{},{}]",
                    originalBBox.x, originalBBox.y, originalBBox.width, originalBBox.height);

            // 9. Draw the box on the original image
            logger.info("在原始图像上绘制检测边框");
            drawBoundingBox(inputImage, originalBBox);

            // 10. Crop (view into inputImage, includes the border drawn above)
            Mat result = cropImage(inputImage, originalBBox);

            if (result != null) {
                logger.info("裁剪成功,结果图像尺寸: {}x{}", result.cols(), result.rows());
            } else {
                logger.info("裁剪失败");
            }

            logger.info("=== 检测完成 ===");
            return result;

        } catch (Exception e) {
            logger.error("检测过程中出现异常", e);
            return null;
        }
    }

    /** Logs up to {@link #DEBUG_TOP_N} raw detections when debug logging is enabled. */
    private void logFirstDetections(List<Detection> detections) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        int count = Math.min(detections.size(), DEBUG_TOP_N);
        for (int i = 0; i < count; i++) {
            Detection d = detections.get(i);
            logger.debug("检测框 {}: classId={}, confidence={}, bbox=[{},{},{},{}]",
                    i, d.classId, d.confidence, d.bbox.x, d.bbox.y, d.bbox.width, d.bbox.height);
        }
    }

    /**
     * Draws {@code bbox} onto {@code image}: white (255) on single-channel images, green
     * (BGR 0,255,0) on 3-channel images. Images with any other channel count are left untouched.
     */
    private void drawBoundingBox(Mat image, Rect bbox) {
        Point topLeft = new Point(bbox.x, bbox.y);
        Point bottomRight = new Point(bbox.x + bbox.width, bbox.y + bbox.height);
        if (image.channels() == 1) {
            Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(255), BORDER_THICKNESS);
            logger.info("在灰度图上绘制白色边框");
        } else if (image.channels() == 3) {
            Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 255, 0), BORDER_THICKNESS);
            logger.info("在彩色图上绘制绿色边框");
        }
    }

    /**
     * Converts the image into the model input: resized to MODEL_WIDTH x MODEL_HEIGHT, scaled to
     * [0, 1] (pixel / 255), laid out as planar CHW in R, G, B order. No cvtColor is used.
     *
     * <p>NOTE(review): the resize stretches the image (no letterbox); scaleBoundingBox undoes
     * exactly this stretch. Ultralytics itself letterboxes, so accuracy on non-square inputs may
     * differ from Python predictions.
     *
     * @return a float array of length 3 * MODEL_WIDTH * MODEL_HEIGHT, or {@code null} when the
     *     image has neither 1 nor 3 channels
     */
    private float[] manualPreprocess(Mat image) {
        logger.debug("开始手动预处理");

        Mat resized = new Mat();
        Mat resizedFloat = new Mat();
        try {
            // 1. Resize to the model input size
            Imgproc.resize(image, resized, new Size(MODEL_WIDTH, MODEL_HEIGHT));
            logger.debug("调整尺寸完成: {}x{} -> {}x{}",
                    image.cols(), image.rows(), MODEL_WIDTH, MODEL_HEIGHT);

            int channels = resized.channels();
            logger.debug("调整后图像通道数: {}", channels);
            if (channels != 1 && channels != 3) {
                logger.error("不支持的图像通道数: {}", channels);
                return null;
            }

            // 2. Copy every pixel out with a single JNI call (instead of one Mat.get per pixel).
            // convertTo allocates a fresh, continuous CV_32F Mat, as required by the bulk get.
            resized.convertTo(resizedFloat, CvType.CV_32F);
            int totalPixels = MODEL_WIDTH * MODEL_HEIGHT;
            float[] interleaved = new float[totalPixels * channels];
            resizedFloat.get(0, 0, interleaved);

            float[] floatValues = new float[3 * totalPixels];
            if (channels == 1) {
                processSingleChannel(interleaved, floatValues);
            } else {
                processThreeChannel(interleaved, floatValues);
            }

            // 3. Print the first values for debugging
            logger.debug("预处理后前10个像素值 (CHW格式):");
            for (int i = 0; i < 10 && i < floatValues.length; i++) {
                logger.debug("  floatValues[{}] = {}", i, floatValues[i]);
            }

            return floatValues;
        } finally {
            // 4. Release the temporary Mats on every path
            resized.release();
            resizedFloat.release();
        }
    }

    /**
     * Replicates a single-channel (gray) image into all three CHW planes, scaled to [0, 1].
     *
     * @param gray   MODEL_WIDTH * MODEL_HEIGHT gray values in row-major order
     * @param output destination of length 3 * MODEL_WIDTH * MODEL_HEIGHT
     */
    private void processSingleChannel(float[] gray, float[] output) {
        int totalPixels = MODEL_WIDTH * MODEL_HEIGHT;
        for (int i = 0; i < totalPixels; i++) {
            float value = gray[i] / 255.0f;
            output[i] = value;
            output[i + totalPixels] = value;
            output[i + 2 * totalPixels] = value;
        }
    }

    /**
     * Converts interleaved BGR pixels (OpenCV default) into planar RGB CHW scaled to [0, 1].
     *
     * @param bgr    MODEL_WIDTH * MODEL_HEIGHT * 3 interleaved values (B, G, R per pixel)
     * @param output destination of length 3 * MODEL_WIDTH * MODEL_HEIGHT
     */
    private void processThreeChannel(float[] bgr, float[] output) {
        int totalPixels = MODEL_WIDTH * MODEL_HEIGHT;
        for (int i = 0; i < totalPixels; i++) {
            int base = 3 * i;
            float b = bgr[base] / 255.0f;
            float g = bgr[base + 1] / 255.0f;
            float r = bgr[base + 2] / 255.0f;

            // YOLO models are trained on RGB input, so the first plane must be R, not B.
            output[i] = r;
            output[i + totalPixels] = g;
            output[i + 2 * totalPixels] = b;
        }
    }

    /**
     * Runs the session on a [1, 3, MODEL_HEIGHT, MODEL_WIDTH] float tensor fed to the first input.
     * The caller owns (and must close) the returned Result.
     */
    private OrtSession.Result inference(float[] imageData) throws OrtException {
        logger.debug("开始推理,输入数据长度: {}", imageData.length);

        long[] shape = {1, 3, MODEL_HEIGHT, MODEL_WIDTH};

        String inputName = session.getInputInfo().keySet().iterator().next();
        logger.debug("输入节点名称: {}", inputName);

        // The input tensor is only needed during run(); try-with-resources frees it even if run() throws.
        try (OnnxTensor inputTensor =
                     OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape)) {
            Map<String, OnnxTensor> inputs = Collections.singletonMap(inputName, inputTensor);
            return session.run(inputs);
        }
    }

    /**
     * Reads the first output tensor and decodes it; returns an empty list (after logging) when the
     * output is not a float[][][] or decoding fails.
     */
    private List<Detection> parseDetections(OrtSession.Result results) throws OrtException {
        List<Detection> detections = new ArrayList<>();

        try {
            OnnxTensor outputTensor = (OnnxTensor) results.get(0);
            Object outputValue = outputTensor.getValue();

            logger.info("输出数据类型: {}", outputValue.getClass().getName());

            if (outputValue instanceof float[][][]) {
                detections = parse3DOutput((float[][][]) outputValue);
            } else {
                logger.error("未知的输出类型: {}", outputValue.getClass().getName());
            }

        } catch (Exception e) {
            logger.error("解析检测结果失败", e);
        }

        return detections;
    }

    /**
     * Decodes a raw output of shape [batch, 4 + numClasses, numPredictions] into boxes in
     * model-input pixel coordinates (0..MODEL_WIDTH / 0..MODEL_HEIGHT).
     *
     * <p>For each prediction the highest-scoring class row is chosen (row 4 / class 0 for a
     * single-class model); predictions whose score is not above confidenceThreshold are dropped.
     * No NMS is applied here.
     */
    private List<Detection> parse3DOutput(float[][][] outputData) {
        List<Detection> detections = new ArrayList<>();

        if (outputData.length == 0 || outputData[0].length == 0) {
            logger.error("输出为空,无法解析");
            return detections;
        }

        int batchSize = outputData.length;
        int numFeatures = outputData[0].length;
        int numPredictions = outputData[0][0].length;

        logger.info("3D输出形状: [{}, {}, {}]", batchSize, numFeatures, numPredictions);

        if (numFeatures <= FIRST_SCORE_ROW) {
            logger.error("输出特征维度不足 (需要 > {}): {}", FIRST_SCORE_ROW, numFeatures);
            return detections;
        }

        if (logger.isDebugEnabled()) {
            logScoreDiagnostics(outputData[0]);
        }

        // The exported model is expected to emit pixel coordinates. Decide once for the whole
        // output (not per box) whether they are normalized, so a tiny box near the origin whose
        // values are all <= 1 is not mistaken for a normalized one.
        boolean normalized = coordinatesAreNormalized(outputData);

        int validCount = 0;
        for (float[][] rows : outputData) {
            for (int i = 0; i < rows[0].length; i++) {
                int scoreRow = bestScoreRow(rows, i);
                float confidence = rows[scoreRow][i];
                // Written as !(x > t) so that NaN scores are dropped as well.
                if (!(confidence > confidenceThreshold)) {
                    continue;
                }
                validCount++;

                float cx = rows[0][i];
                float cy = rows[1][i];
                float w = rows[2][i];
                float h = rows[3][i];
                if (normalized) {
                    cx *= MODEL_WIDTH;
                    cy *= MODEL_HEIGHT;
                    w *= MODEL_WIDTH;
                    h *= MODEL_HEIGHT;
                }

                // Corners clamped to the model input; x2/y2 are exclusive ends (width = x2 - x1).
                int x1 = Math.max(0, Math.round(cx - w / 2));
                int y1 = Math.max(0, Math.round(cy - h / 2));
                int x2 = Math.min(MODEL_WIDTH, Math.round(cx + w / 2));
                int y2 = Math.min(MODEL_HEIGHT, Math.round(cy + h / 2));

                if (x2 > x1 && y2 > y1) {
                    int classId = scoreRow - FIRST_SCORE_ROW;
                    detections.add(new Detection(classId, confidence,
                            new Rect(x1, y1, x2 - x1, y2 - y1)));
                }
            }
        }

        logger.info("置信度>{}的预测数量: {}", confidenceThreshold, validCount);
        logger.info("有效检测框总数: {} (坐标有效)", detections.size());

        return detections;
    }

    /** Returns the row index (>= FIRST_SCORE_ROW) holding the highest class score of prediction i. */
    private static int bestScoreRow(float[][] rows, int i) {
        int best = FIRST_SCORE_ROW;
        for (int row = FIRST_SCORE_ROW + 1; row < rows.length; row++) {
            if (rows[row][i] > rows[best][i]) {
                best = row;
            }
        }
        return best;
    }

    /** Returns true when no cx/cy/w/h value in the output exceeds 1 (normalized coordinates). */
    private static boolean coordinatesAreNormalized(float[][][] outputData) {
        for (float[][] rows : outputData) {
            for (int row = 0; row < FIRST_SCORE_ROW; row++) {
                for (float v : rows[row]) {
                    if (v > 1) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /** Debug-only statistics about the best class score of every prediction of one image. */
    private void logScoreDiagnostics(float[][] rows) {
        int numPredictions = rows[0].length;
        float[] scores = new float[numPredictions];
        int[] confRanges = new int[11]; // buckets 0-0.1, 0.1-0.2, ..., 1.0
        int highConfCount = 0;
        float maxConf = 0;

        for (int i = 0; i < numPredictions; i++) {
            float conf = rows[bestScoreRow(rows, i)][i];
            scores[i] = conf;
            if (conf > maxConf) {
                maxConf = conf;
            }
            int rangeIdx = (int) (conf * 10);
            if (rangeIdx >= 0 && rangeIdx <= 10) {
                confRanges[rangeIdx]++;
            }
            if (conf > 0.5) {
                highConfCount++;
            }
        }

        logger.debug("最大置信度: {}", maxConf);
        logger.debug("置信度分布 (0-1.0, 步长0.1): {}", Arrays.toString(confRanges));
        logger.debug("置信度>0.5的预测数量: {}", highConfCount);

        // Only predictions above 10% of the maximum score are considered "top" candidates.
        List<Prediction> topPredictions = new ArrayList<>();
        for (int i = 0; i < numPredictions; i++) {
            if (scores[i] > maxConf * 0.1) {
                topPredictions.add(new Prediction(i, scores[i],
                        rows[0][i], rows[1][i], rows[2][i], rows[3][i]));
            }
        }
        topPredictions.sort((a, b) -> Float.compare(b.confidence, a.confidence));

        logger.debug("高置信度预测数量 (大于最大值的10%): {}", topPredictions.size());
        for (int i = 0; i < Math.min(topPredictions.size(), DEBUG_TOP_N); i++) {
            Prediction p = topPredictions.get(i);
            logger.debug("高置信度预测 {}: idx={}, conf={}, cx={}, cy={}, w={}, h={}",
                    i, p.index, p.confidence, p.cx, p.cy, p.w, p.h);
        }
    }

    /** Returns the detections whose classId equals {@link #medicineClassId}. */
    private List<Detection> filterMedicineDetections(List<Detection> allDetections) {
        List<Detection> medicineDetections = new ArrayList<>();

        logger.info("筛选Medicine类别 (classId={})", medicineClassId);

        for (Detection detection : allDetections) {
            if (detection.classId == medicineClassId) {
                medicineDetections.add(detection);
            }
        }

        return medicineDetections;
    }

    /**
     * Greedy class-agnostic NMS. Sorts {@code detections} in place by descending confidence and
     * returns the survivors in that order; a box is dropped when its IoU with an already kept box
     * exceeds iouThreshold. Lists of size 0 or 1 are returned as-is.
     */
    private List<Detection> nonMaxSuppression(List<Detection> detections) {
        if (detections.size() <= 1) {
            return detections;
        }

        logger.debug("开始NMS,输入检测框数量: {}", detections.size());

        detections.sort((a, b) -> Float.compare(b.confidence, a.confidence));

        List<Detection> result = new ArrayList<>();
        boolean[] suppressed = new boolean[detections.size()];

        for (int i = 0; i < detections.size(); i++) {
            if (suppressed[i]) continue;

            Detection current = detections.get(i);
            result.add(current);

            for (int j = i + 1; j < detections.size(); j++) {
                if (suppressed[j]) continue;

                float iou = calculateIoU(current.bbox, detections.get(j).bbox);
                if (iou > iouThreshold) {
                    suppressed[j] = true;
                }
            }
        }

        logger.debug("NMS完成,输出检测框数量: {}", result.size());
        return result;
    }

    /** Intersection-over-union of two rectangles (0 when they do not overlap). */
    private float calculateIoU(Rect rect1, Rect rect2) {
        int x1 = Math.max(rect1.x, rect2.x);
        int y1 = Math.max(rect1.y, rect2.y);
        int x2 = Math.min(rect1.x + rect1.width, rect2.x + rect2.width);
        int y2 = Math.min(rect1.y + rect1.height, rect2.y + rect2.height);

        if (x2 <= x1 || y2 <= y1) {
            return 0.0f;
        }

        int interArea = (x2 - x1) * (y2 - y1);
        int unionArea = rect1.width * rect1.height + rect2.width * rect2.height - interArea;

        return unionArea > 0 ? (float) interArea / unionArea : 0.0f;
    }

    /**
     * Maps a box from model-input coordinates to the original image (inverse of the stretch
     * resize) and clamps it so it lies inside the image with width/height of at least 1.
     */
    private Rect scaleBoundingBox(Rect modelBBox, int originalWidth, int originalHeight) {
        float scaleX = (float) originalWidth / MODEL_WIDTH;
        float scaleY = (float) originalHeight / MODEL_HEIGHT;

        int x = Math.round(modelBBox.x * scaleX);
        int y = Math.round(modelBBox.y * scaleY);
        int width = Math.round(modelBBox.width * scaleX);
        int height = Math.round(modelBBox.height * scaleY);

        x = Math.max(0, Math.min(x, originalWidth - 1));
        y = Math.max(0, Math.min(y, originalHeight - 1));
        width = Math.max(1, Math.min(width, originalWidth - x));
        height = Math.max(1, Math.min(height, originalHeight - y));

        return new Rect(x, y, width, height);
    }

    /**
     * Returns a sub-matrix view of {@code image} covering {@code bbox} grown by 5 px on each side
     * (clipped to the image), or {@code null} if the region is empty or cropping fails. The view
     * shares pixel memory with {@code image}.
     */
    private Mat cropImage(Mat image, Rect bbox) {
        try {
            int padding = 5;
            int x = Math.max(0, bbox.x - padding);
            int y = Math.max(0, bbox.y - padding);
            int width = Math.min(bbox.width + 2 * padding, image.cols() - x);
            int height = Math.min(bbox.height + 2 * padding, image.rows() - y);

            if (width > 0 && height > 0) {
                return new Mat(image, new Rect(x, y, width, height));
            }
        } catch (Exception e) {
            logger.error("裁剪失败", e);
        }
        return null;
    }

    /** Whether the ONNX session was created successfully. */
    public boolean isLoaded() {
        return isLoaded;
    }

    /** Closes the session and environment (if any) and marks the detector as not loaded. */
    public void release() {
        try {
            if (session != null) {
                session.close();
                session = null;
            }
            if (environment != null) {
                environment.close();
                environment = null;
            }
            isLoaded = false;
            logger.info("YOLO模型资源已释放");
        } catch (Exception e) {
            logger.error("释放资源失败", e);
        }
    }

    /** One decoded box in model-input pixel coordinates. */
    private static class Detection {
        final int classId;
        final float confidence;
        final Rect bbox;

        Detection(int classId, float confidence, Rect bbox) {
            this.classId = classId;
            this.confidence = confidence;
            this.bbox = bbox;
        }
    }

    /** Raw prediction values, used only by the debug diagnostics. */
    private static class Prediction {
        final int index;
        final float confidence;
        final float cx, cy, w, h;

        Prediction(int index, float confidence, float cx, float cy, float w, float h) {
            this.index = index;
            this.confidence = confidence;
            this.cx = cx;
            this.cy = cy;
            this.w = w;
            this.h = h;
        }
    }
}

 

Maven 依赖(pom.xml):

        <!-- OpenCV Java bindings: a local jar installed into the local Maven repository under the
             coordinates chosen in the command below (com.acts:opencv:4.11).
             NOTE(review): presumably the native OpenCV library must also be on java.library.path
             and loaded before Mat/Imgproc are used; confirm in the application startup code. -->
        <!-- opencv mvn install:install-file -Dfile=C:/opencv/opencv/build/java/opencv-4110.jar -DgroupId=com.acts -DartifactId=opencv -Dversion=4.11 -Dpackaging=jar-->
        <dependency>
            <groupId>com.acts</groupId>
            <artifactId>opencv</artifactId>
            <version>4.11</version>
        </dependency>

        <!-- ONNX Runtime Java API (ai.onnxruntime.*), used to load and run the exported .onnx model -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.16.0</version>
        </dependency>

 

posted @ 2025-12-08 17:54  别动我的猫  阅读(7)  评论(0)    收藏  举报