java集成yolo11 onnx模型推理
先使用 Python 脚本将 YOLO11 训练得到的权重文件(best.pt)导出为 ONNX 模型,注意 OpenCV 版本需 >= 4.8.0
# Export the trained YOLO11 weights to a deployment format.
from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/best.pt")

# 1. Export to ONNX (portable across platforms).
model.export(
    format="onnx",
    simplify=True,   # simplify the exported graph
    opset=12,        # pinned ONNX opset version
    dynamic=False,   # fixed input shape
    # Keep FP32. The Java detector feeds a float32 FloatBuffer and only parses
    # float[][][] output, so an FP16 model (what half=True yields when exporting
    # on a GPU) would be rejected at inference time.
    half=False,
    batch=1,         # fixed batch size
    imgsz=640,       # fixed input size; must match MODEL_WIDTH/MODEL_HEIGHT in Java
)

# 2. Export to TensorRT (GPU acceleration): half precision, 4 GB workspace.
# model.export(format="engine", half=True, workspace=4)

# 3. Export to OpenVINO (Intel CPU optimisation).
# model.export(format="openvino")

# 4. Export to CoreML (iOS/Mac).
# model.export(format="coreml")
JAVA中编写工具类,调用 onnx 模型进行推理
import ai.onnxruntime.*; import org.opencv.core.*; import org.opencv.imgproc.Imgproc; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import java.nio.FloatBuffer; import java.util.*; /** * YOLO ONNX 目标检测器 - 最终版本,不使用cvtColor */ @Component public class YoloOnnxDetector { private static final Logger logger = LoggerFactory.getLogger(YoloOnnxDetector.class); // 模型配置 private static final int MODEL_WIDTH = 640; private static final int MODEL_HEIGHT = 640; @Value("${yolo.confidence_level}") private float CONFIDENCE_THRESHOLD; //置信度 @Value("${yolo.iou_threshold}") //IOU阈值 private float IOU_THRESHOLD; // ONNX Runtime 对象 private OrtEnvironment environment; private OrtSession session; @Value("${yolo.bestOnnxDir}") private String modelPath; @Value("${yolo.medicineClassId:0}") private int medicineClassId; private boolean isLoaded = false; @PostConstruct public void init() { try { loadModel(); logger.info("YOLO ONNX模型初始化成功"); } catch (Exception e) { logger.error("YOLO ONNX模型初始化失败: {}", e.getMessage()); } } @PreDestroy public void destroy() { release(); } private void loadModel() { try { logger.info("加载ONNX模型: {}", modelPath); environment = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions options = new OrtSession.SessionOptions(); options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); session = environment.createSession(modelPath, options); isLoaded = true; printModelInfo(); logger.info("模型加载成功"); } catch (Exception e) { logger.error("模型加载失败", e); isLoaded = false; } } private void printModelInfo() { try { logger.info("=== 模型信息 ==="); Map<String, NodeInfo> inputInfo = session.getInputInfo(); for (Map.Entry<String, NodeInfo> entry : inputInfo.entrySet()) { String name = entry.getKey(); NodeInfo info = entry.getValue(); TensorInfo tensorInfo = (TensorInfo) 
info.getInfo(); logger.info("输入节点: {}, 形状: {}", name, Arrays.toString(tensorInfo.getShape())); } Map<String, NodeInfo> outputInfo = session.getOutputInfo(); for (Map.Entry<String, NodeInfo> entry : outputInfo.entrySet()) { String name = entry.getKey(); NodeInfo info = entry.getValue(); TensorInfo tensorInfo = (TensorInfo) info.getInfo(); logger.info("输出节点: {}, 形状: {}", name, Arrays.toString(tensorInfo.getShape())); } } catch (Exception e) { logger.error("获取模型信息失败", e); } } public Mat detectAndCrop(Mat inputImage) { if (!isLoaded) { logger.error("模型未加载"); return null; } if (inputImage == null || inputImage.empty()) { logger.error("输入图像为空"); return null; } try { logger.info("=== 开始检测 ==="); logger.info("输入图像尺寸: {}x{}, 通道数: {}", inputImage.cols(), inputImage.rows(), inputImage.channels()); // 1. 预处理图像 - 不使用cvtColor logger.info("开始预处理..."); float[] preprocessedData = manualPreprocess(inputImage); if (preprocessedData == null) { logger.error("预处理失败"); return null; } // 2. 执行推理 logger.info("开始推理..."); OrtSession.Result results = inference(preprocessedData); // 3. 解析检测结果 List<Detection> allDetections = parseDetections(results); logger.info("解析完成,原始检测框总数: {}", allDetections.size()); // 打印所有检测框信息 for (int i = 0; i < Math.min(allDetections.size(), 10); i++) { Detection d = allDetections.get(i); } // 4. 筛选medicine类别的检测框 List<Detection> medicineDetections = filterMedicineDetections(allDetections); logger.info("Medicine类别检测框数量: {}", medicineDetections.size()); // 5. 如果没有检测到medicine,返回null if (medicineDetections.isEmpty()) { logger.info("未检测到medicine目标"); return null; } // 6. 非极大值抑制(NMS) logger.info("开始NMS处理..."); List<Detection> nmsDetections = nonMaxSuppression(medicineDetections); logger.info("NMS后检测框数量: {}", nmsDetections.size()); // 7. 
获取置信度最高的检测框 if (nmsDetections.isEmpty()) { logger.info("NMS后无有效检测框"); return null; } Detection bestDetection = nmsDetections.get(0); logger.info("最佳检测框: classId={}, confidence={}, bbox=[{},{},{},{}]", bestDetection.classId, bestDetection.confidence, bestDetection.bbox.x, bestDetection.bbox.y, bestDetection.bbox.width, bestDetection.bbox.height); // 8. 转换检测框坐标到原始图像尺寸 Rect originalBBox = scaleBoundingBox(bestDetection.bbox, inputImage.cols(), inputImage.rows()); logger.info("原始图像尺寸转换完成: [{},{},{},{}]", originalBBox.x, originalBBox.y, originalBBox.width, originalBBox.height); // 9. 在原始图像上绘制绿色边框(新增代码) // 由于inputImage是灰度图,我们绘制白色边框(灰度值255) // 如果要绘制彩色边框,需要先将灰度图转换为彩色图 logger.info("在原始图像上绘制检测边框"); // 检查图像通道数,如果是单通道灰度图,使用白色(255) // 如果是彩色图,可以使用绿色(0, 255, 0) if (inputImage.channels() == 1) { // 灰度图 - 使用白色边框 Imgproc.rectangle(inputImage, new Point(originalBBox.x, originalBBox.y), new Point(originalBBox.x + originalBBox.width, originalBBox.y + originalBBox.height), new Scalar(255), // 白色 3); // 边框厚度 logger.info("在灰度图上绘制白色边框"); } else if (inputImage.channels() == 3) { // 彩色图 - 使用绿色边框 Imgproc.rectangle(inputImage, new Point(originalBBox.x, originalBBox.y), new Point(originalBBox.x + originalBBox.width, originalBBox.y + originalBBox.height), new Scalar(0, 255, 0), // BGR格式:绿色 3); // 边框厚度 logger.info("在彩色图上绘制绿色边框"); } // 10. 裁剪图像 Mat result = cropImage(inputImage, originalBBox); if (result != null) { logger.info("裁剪成功,结果图像尺寸: {}x{}", result.cols(), result.rows()); } else { logger.info("裁剪失败"); } logger.info("=== 检测完成 ==="); return result; } catch (Exception e) { logger.error("检测过程中出现异常", e); return null; } } /** * 手动预处理方法 - 完全不使用cvtColor */ private float[] manualPreprocess(Mat image) { logger.debug("开始手动预处理"); // 1. 
调整尺寸到模型输入尺寸 Mat resized = new Mat(); Imgproc.resize(image, resized, new Size(MODEL_WIDTH, MODEL_HEIGHT)); logger.debug("调整尺寸完成: {}x{} -> {}x{}", image.cols(), image.rows(), MODEL_WIDTH, MODEL_HEIGHT); int channels = resized.channels(); logger.debug("调整后图像通道数: {}", channels); // 2. 准备输出数组 int totalPixels = MODEL_WIDTH * MODEL_HEIGHT; float[] floatValues = new float[3 * totalPixels]; if (channels == 1) { // 单通道图像 - 手动转换为三通道 processSingleChannel(resized, floatValues); } else if (channels == 3) { // 三通道图像 - 手动处理 processThreeChannel(resized, floatValues); } else { logger.error("不支持的图像通道数: {}", channels); resized.release(); return null; } // 3. 打印前几个像素值用于调试 logger.debug("预处理后前10个像素值 (CHW格式):"); for (int i = 0; i < 10 && i < floatValues.length; i++) { logger.debug(" floatValues[{}] = {}", i, floatValues[i]); } // 4. 释放临时Mat resized.release(); return floatValues; } /** * 处理单通道图像 */ private void processSingleChannel(Mat mat, float[] output) { int totalPixels = MODEL_WIDTH * MODEL_HEIGHT; // 使用get方法获取像素值 for (int h = 0; h < MODEL_HEIGHT; h++) { for (int w = 0; w < MODEL_WIDTH; w++) { double[] pixel = mat.get(h, w); float value = 0.0f; if (pixel != null && pixel.length > 0) { value = (float) pixel[0] / 255.0f; } // CHW格式: 先所有B通道,再所有G通道,再所有R通道 // 对于灰度图,三个通道值相同 int index = h * MODEL_WIDTH + w; output[index] = value; // B通道 output[index + totalPixels] = value; // G通道 output[index + 2 * totalPixels] = value; // R通道 } } } /** * 处理三通道图像 */ private void processThreeChannel(Mat mat, float[] output) { int totalPixels = MODEL_WIDTH * MODEL_HEIGHT; // 假设输入是BGR格式(OpenCV默认) for (int h = 0; h < MODEL_HEIGHT; h++) { for (int w = 0; w < MODEL_WIDTH; w++) { double[] pixel = mat.get(h, w); float b = 0.0f, g = 0.0f, r = 0.0f; if (pixel != null && pixel.length >= 3) { // OpenCV默认是BGR格式 b = (float) pixel[0] / 255.0f; // B通道 g = (float) pixel[1] / 255.0f; // G通道 r = (float) pixel[2] / 255.0f; // R通道 } // CHW格式 int index = h * MODEL_WIDTH + w; output[index] = b; // B通道 output[index + totalPixels] = 
g; // G通道 output[index + 2 * totalPixels] = r; // R通道 } } } private OrtSession.Result inference(float[] imageData) throws OrtException { logger.debug("开始推理,输入数据长度: {}", imageData.length); // 创建输入张量 [1, 3, 640, 640] FloatBuffer floatBuffer = FloatBuffer.wrap(imageData); long[] shape = {1, 3, MODEL_HEIGHT, MODEL_WIDTH}; // 获取输入节点名称 String inputName = session.getInputInfo().keySet().iterator().next(); logger.debug("输入节点名称: {}", inputName); OnnxTensor inputTensor = OnnxTensor.createTensor(environment, floatBuffer, shape); // 执行推理 Map<String, OnnxTensor> inputs = Collections.singletonMap(inputName, inputTensor); OrtSession.Result results = session.run(inputs); // 关闭输入张量 inputTensor.close(); return results; } private List<Detection> parseDetections(OrtSession.Result results) throws OrtException { List<Detection> detections = new ArrayList<>(); try { // 获取输出张量 OnnxTensor outputTensor = (OnnxTensor) results.get(0); Object outputValue = outputTensor.getValue(); logger.info("输出数据类型: {}", outputValue.getClass().getName()); if (outputValue instanceof float[][][]) { float[][][] outputData = (float[][][]) outputValue; detections = parse3DOutput(outputData); } else { logger.error("未知的输出类型: {}", outputValue.getClass().getName()); } } catch (Exception e) { logger.error("解析检测结果失败", e); } return detections; } /** * 解析3D输出 */ private List<Detection> parse3DOutput(float[][][] outputData) { List<Detection> detections = new ArrayList<>(); // 输出形状: [1, 5, 8400] int batchSize = outputData.length; // 1 int numFeatures = outputData[0].length; // 5 int numPredictions = outputData[0][0].length; // 8400 logger.info("3D输出形状: [{}, {}, {}]", batchSize, numFeatures, numPredictions); // 统计置信度分布 int[] confRanges = new int[11]; // 0-0.1, 0.1-0.2, ..., 1.0 int highConfCount = 0; // 先收集所有置信度,找出最大值 float maxConf = 0; for (int i = 0; i < numPredictions; i++) { float conf = outputData[0][4][i]; if (conf > maxConf) maxConf = conf; } logger.info("最大置信度: {}", maxConf); for (int i = 0; i < numPredictions; i++) 
{ float conf = outputData[0][4][i]; int rangeIdx = (int)(conf * 10); if (rangeIdx >= 0 && rangeIdx <= 10) { confRanges[rangeIdx]++; } if (conf > 0.5) { highConfCount++; } } logger.info("置信度分布 (0-1.0, 步长0.1): {}", Arrays.toString(confRanges)); logger.info("置信度>0.5的预测数量: {}", highConfCount); // 打印置信度最高的10个预测 List<Prediction> topPredictions = new ArrayList<>(); for (int i = 0; i < numPredictions; i++) { float conf = outputData[0][4][i]; if (conf > maxConf * 0.1) { // 只记录置信度大于最大值的10%的 topPredictions.add(new Prediction(i, conf, outputData[0][0][i], outputData[0][1][i], outputData[0][2][i], outputData[0][3][i])); } } // 按置信度排序 topPredictions.sort((a, b) -> Float.compare(b.confidence, a.confidence)); logger.info("高置信度预测数量 (大于最大值的10%): {}", topPredictions.size()); for (int i = 0; i < Math.min(topPredictions.size(), 10); i++) { Prediction p = topPredictions.get(i); logger.info("高置信度预测 {}: idx={}, conf={}, cx={}, cy={}, w={}, h={}", i, p.index, p.confidence, p.cx, p.cy, p.w, p.h); } // 解析所有预测 int validCount = 0; for (int b = 0; b < batchSize; b++) { for (int i = 0; i < numPredictions; i++) { // 获取边界框坐标和置信度 float cx = outputData[b][0][i]; float cy = outputData[b][1][i]; float w = outputData[b][2][i]; float h = outputData[b][3][i]; float confidence = outputData[b][4][i]; // 根据之前的输出,坐标值看起来很大,可能是像素坐标而不是归一化坐标 // 尝试将坐标除以640进行归一化 if (cx > 1 || cy > 1 || w > 1 || h > 1) { cx = cx / MODEL_WIDTH; cy = cy / MODEL_HEIGHT; w = w / MODEL_WIDTH; h = h / MODEL_HEIGHT; } // 过滤低置信度的检测框 if (confidence > CONFIDENCE_THRESHOLD) { validCount++; // 转换为像素坐标 int x1 = Math.round((cx - w / 2) * MODEL_WIDTH); int y1 = Math.round((cy - h / 2) * MODEL_HEIGHT); int x2 = Math.round((cx + w / 2) * MODEL_WIDTH); int y2 = Math.round((cy + h / 2) * MODEL_HEIGHT); // 确保边界有效 x1 = Math.max(0, x1); y1 = Math.max(0, y1); x2 = Math.min(MODEL_WIDTH - 1, x2); y2 = Math.min(MODEL_HEIGHT - 1, y2); if (x2 > x1 && y2 > y1) { int classId = 0; // 单类别模型 detections.add(new Detection(classId, confidence, new Rect(x1, y1, x2 - 
x1, y2 - y1))); } } } } logger.info("置信度>{}的预测数量: {}", CONFIDENCE_THRESHOLD, validCount); logger.info("有效检测框总数: {} (坐标有效)", detections.size()); return detections; } private List<Detection> filterMedicineDetections(List<Detection> allDetections) { List<Detection> medicineDetections = new ArrayList<>(); logger.info("筛选Medicine类别 (classId={})", medicineClassId); for (Detection detection : allDetections) { if (detection.classId == medicineClassId) { medicineDetections.add(detection); } } return medicineDetections; } private List<Detection> nonMaxSuppression(List<Detection> detections) { if (detections.size() <= 1) { return detections; } logger.debug("开始NMS,输入检测框数量: {}", detections.size()); // 按置信度降序排序 detections.sort((a, b) -> Float.compare(b.confidence, a.confidence)); List<Detection> result = new ArrayList<>(); boolean[] suppressed = new boolean[detections.size()]; for (int i = 0; i < detections.size(); i++) { if (suppressed[i]) continue; Detection current = detections.get(i); result.add(current); for (int j = i + 1; j < detections.size(); j++) { if (suppressed[j]) continue; Detection other = detections.get(j); float iou = calculateIoU(current.bbox, other.bbox); if (iou > IOU_THRESHOLD) { suppressed[j] = true; } } } logger.debug("NMS完成,输出检测框数量: {}", result.size()); return result; } private float calculateIoU(Rect rect1, Rect rect2) { int x1 = Math.max(rect1.x, rect2.x); int y1 = Math.max(rect1.y, rect2.y); int x2 = Math.min(rect1.x + rect1.width, rect2.x + rect2.width); int y2 = Math.min(rect1.y + rect1.height, rect2.y + rect2.height); if (x2 <= x1 || y2 <= y1) { return 0.0f; } int interArea = (x2 - x1) * (y2 - y1); int unionArea = rect1.width * rect1.height + rect2.width * rect2.height - interArea; return unionArea > 0 ? 
(float) interArea / unionArea : 0.0f; } private Rect scaleBoundingBox(Rect modelBBox, int originalWidth, int originalHeight) { float scaleX = (float) originalWidth / MODEL_WIDTH; float scaleY = (float) originalHeight / MODEL_HEIGHT; int x = Math.round(modelBBox.x * scaleX); int y = Math.round(modelBBox.y * scaleY); int width = Math.round(modelBBox.width * scaleX); int height = Math.round(modelBBox.height * scaleY); // 边界检查 x = Math.max(0, Math.min(x, originalWidth - 1)); y = Math.max(0, Math.min(y, originalHeight - 1)); width = Math.max(1, Math.min(width, originalWidth - x)); height = Math.max(1, Math.min(height, originalHeight - y)); return new Rect(x, y, width, height); } private Mat cropImage(Mat image, Rect bbox) { try { // 添加边界扩展 int padding = 5; int x = Math.max(0, bbox.x - padding); int y = Math.max(0, bbox.y - padding); int width = Math.min(bbox.width + 2 * padding, image.cols() - x); int height = Math.min(bbox.height + 2 * padding, image.rows() - y); if (width > 0 && height > 0) { return new Mat(image, new Rect(x, y, width, height)); } } catch (Exception e) { logger.error("裁剪失败", e); } return null; } public boolean isLoaded() { return isLoaded; } public void release() { try { if (session != null) { session.close(); session = null; } if (environment != null) { environment.close(); environment = null; } isLoaded = false; logger.info("YOLO模型资源已释放"); } catch (Exception e) { logger.error("释放资源失败", e); } } private static class Detection { int classId; float confidence; Rect bbox; Detection(int classId, float confidence, Rect bbox) { this.classId = classId; this.confidence = confidence; this.bbox = bbox; } } private static class Prediction { int index; float confidence; float cx, cy, w, h; Prediction(int index, float confidence, float cx, float cy, float w, float h) { this.index = index; this.confidence = confidence; this.cx = cx; this.cy = cy; this.w = w; this.h = h; } } }
添加 Maven 依赖
<!-- OpenCV Java bindings. Install the jar into the local Maven repository first:
     mvn install:install-file -Dfile=C:/opencv/opencv/build/java/opencv-4110.jar -DgroupId=com.acts -DartifactId=opencv -Dversion=4.11 -Dpackaging=jar
     NOTE(review): presumably the matching OpenCV native library must also be loadable at runtime; confirm. -->
<dependency>
    <groupId>com.acts</groupId>
    <artifactId>opencv</artifactId>
    <version>4.11</version>
</dependency>
<!-- ONNX Runtime Java API -->
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.16.0</version>
</dependency>

浙公网安备 33010602011771号