YOLOv8 OBB Model Training and ONNX Inference
1. Model Overview
YOLOv8 OBB (Oriented Bounding Box) is the rotated-bounding-box variant of YOLOv8, designed to detect rotated or tilted objects. This post briefly covers data annotation, dataset preparation, training, model export, and inference; a breakdown of the architecture will be added later.
2. Data Annotation
There are many annotation tools. Here are two that support rotated boxes and are well suited to OBB annotation: roLabelImg (which additionally requires labelImg) and X-AnyLabeling. I use X-AnyLabeling, available here:
https://github.com/CVHub520/X-AnyLabeling

Then export in OBB format. (Exporting this way is one step, with no further conversion needed. During export you will be asked for a class file, a plain .txt you write yourself with one English class name per line.)
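For example, a class file with two classes might look like this (the line order defines the numeric class index, 0 and 1 here; these names match the ones used in the inference example later in this post):

```
dianzu
dianrong
```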

You will get a labels folder containing one .txt file per image; these are the labels used for training later. The format is as follows:

Each line is: class_index x1 y1 x2 y2 x3 y3 x4 y4
where x1,y1 through x4,y4 are the four corner points of the rotated box (normalized to [0, 1]), usually listed clockwise (counter-clockwise should also work; the direct export shown above is fine as-is).
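As a concrete illustration of this format, here is a minimal sketch (the label line and image size below are made up) that parses one line and converts the normalized corners back to pixel coordinates:

```python
def parse_obb_label(line, img_w, img_h):
    """Parse one OBB label line: 'class x1 y1 x2 y2 x3 y3 x4 y4' (normalized)."""
    parts = line.split()
    cls = int(parts[0])
    coords = list(map(float, parts[1:9]))
    # De-normalize to pixel coordinates: (x1, y1), ..., (x4, y4)
    points = [(coords[i] * img_w, coords[i + 1] * img_h) for i in range(0, 8, 2)]
    return cls, points

# Made-up label line for a 640x480 image
cls, pts = parse_obb_label("0 0.25 0.25 0.75 0.25 0.75 0.75 0.25 0.75", 640, 480)
print(cls, pts)  # 0 [(160.0, 120.0), (480.0, 120.0), (480.0, 360.0), (160.0, 360.0)]
```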
3. Building the Dataset
There are many scripts for building the dataset. First, look at the directory structure the finished dataset should have (this matters):

Expanded, it looks like this (only a few images are shown for brevity):

From the annotation step we already have the labeled data (a folder holding the images and their .txt files), e.g. an images_labels folder with contents like this:

Now use the following Python script to build the dataset:
```python
import os
import random
import shutil


def split_dataset(image_folder_path, output_folder_path, train_ratio=0.8):
    """Split images and labels into train/val sets.

    Args:
        image_folder_path: folder containing images and their .txt labels
        output_folder_path: output dataset folder
        train_ratio: fraction of images used for training (default 0.8)
    """
    # Create the output folder structure
    dirs = ['images/train', 'images/val', 'labels/train', 'labels/val']
    for dir_name in dirs:
        os.makedirs(os.path.join(output_folder_path, dir_name), exist_ok=True)

    # Supported image extensions
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']

    # Collect all image files
    image_files = []
    for file in os.listdir(image_folder_path):
        file_lower = file.lower()
        if any(file_lower.endswith(ext) for ext in image_extensions):
            image_files.append(file)

    # Shuffle, then split at the train_ratio point
    random.shuffle(image_files)
    split_point = int(len(image_files) * train_ratio)
    train_images = image_files[:split_point]
    val_images = image_files[split_point:]

    def copy_split(images, split):
        """Copy the images of one split, plus their labels when present."""
        for img_file in images:
            img_path = os.path.join(image_folder_path, img_file)
            # The label file shares the base name, with extension .txt
            base_name = os.path.splitext(img_file)[0]
            label_file = base_name + '.txt'
            label_path = os.path.join(image_folder_path, label_file)
            shutil.copy2(img_path, os.path.join(output_folder_path, 'images', split, img_file))
            if os.path.exists(label_path):
                shutil.copy2(label_path, os.path.join(output_folder_path, 'labels', split, label_file))
            else:
                print(f"Warning: {label_file} does not exist")

    copy_split(train_images, 'train')
    copy_split(val_images, 'val')

    print("Dataset split complete!")
    print(f"Train set: {len(train_images)} images")
    print(f"Val set: {len(val_images)} images")
    print(f"Output path: {output_folder_path}")


if __name__ == "__main__":
    image_folder = r"images_labels"  # folder containing images and .txt labels
    output_folder = r"train"         # output dataset folder
    split_dataset(image_folder, output_folder)
```
After running it you get the dataset structure shown earlier (adjust the train/val proportions via the train_ratio parameter as needed). The dataset split is now done.
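After the split, it is worth a quick sanity check that every image in each split has a matching label file. A minimal sketch (pass it the output folder used above, e.g. "train"):

```python
import os


def check_split(output_folder):
    """Return {split: (image_count, [images missing a label])} for a split dataset."""
    report = {}
    for split in ("train", "val"):
        imgs = os.listdir(os.path.join(output_folder, "images", split))
        lbls = set(os.listdir(os.path.join(output_folder, "labels", split)))
        missing = [f for f in imgs if os.path.splitext(f)[0] + ".txt" not in lbls]
        report[split] = (len(imgs), missing)
    return report
```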
4. Configuration Files
We also need the yolov8-obb.yaml file, shown below:
```yaml
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Ultralytics YOLOv8-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolov8
# Task docs: https://docs.ultralytics.com/tasks/obb

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n-obb summary: 144 layers, 3228867 parameters, 3228851 gradients, 9.1 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s-obb summary: 144 layers, 11452739 parameters, 11452723 gradients, 29.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m-obb summary: 184 layers, 26463235 parameters, 26463219 gradients, 81.5 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l-obb summary: 224 layers, 44540355 parameters, 44540339 gradients, 169.4 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x-obb summary: 224 layers, 69555651 parameters, 69555635 gradients, 264.3 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
  - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5)
```
Besides that file, we also need detect_obb.yaml, which configures the classes and data paths (set the paths and class names to match your setup):
```yaml
# Ultralytics YOLO 🚀, AGPL-3.0 license
# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/
# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── dota1 ← downloads here (2GB)

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./data/defect_obb/defect_obb # dataset root dir
train: images/train # train images (relative to 'path') 1411 images
val: images/val # val images (relative to 'path') 458 images
#test: images/test # test images (optional) 937 images

# Classes for DOTA 1.0
names:
  0: defect
```
5. Pretrained Weights
Download them from https://docs.ultralytics.com/models/yolov8/#models; from n to x the models grow progressively larger, so choose one to match your task.

They are also available on Baidu Netdisk: link: https://pan.baidu.com/s/1kHgiXYCOC8w4YBcND9aTiQ extraction code: a8vc
With that, the dataset and configuration files are all ready.
6. Training
Before training, install the ultralytics package (or clone the source):
pip install ultralytics
or https://github.com/ultralytics/ultralytics
Then train with the following code:
```python
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-obb.yaml')
    model.load('model/yolov8n-obb.pt')  # loading pretrain weights
    model.train(
        data='data/defect_obb/defect_obb.yaml',
        cache=False,
        imgsz=640,
        epochs=50,
        batch=2,
        close_mosaic=10,
        workers=0,
        device='0',
        optimizer='SGD',  # using SGD
        project='runs/train',
        name='exp',
    )
```
The parameters above are not covered in detail here.
7. Exporting to ONNX
After training, we can export the .pt model to ONNX format:
```python
from ultralytics import YOLO

model = YOLO("path/to/best.pt")  # load a custom-trained model

# Export the model
model.export(format="onnx")
```
8. ONNX Inference
Here is a version found online, adapted with my own modifications.
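Before reading the full script, it helps to know the raw ONNX output layout it decodes. For a model with nc classes, the output has shape (1, 4 + nc + 1, N): each of the N candidate columns holds cx, cy, w, h, the nc class scores, and a rotation angle in radians. A minimal numpy sketch using a dummy tensor (shapes assumed for a single-class 640x640 model; a real tensor would come from onnxruntime):

```python
import numpy as np

# Dummy output shaped like a single-class (nc=1) YOLOv8-OBB head at 640x640:
# (1, 4 + nc + 1, 8400). Real values would come from session.run(...).
nc = 1
output = np.zeros((1, 4 + nc + 1, 8400), dtype=np.float32)
output[0, :, 0] = [320.0, 240.0, 50.0, 20.0, 0.9, 0.3]  # one fake detection

preds = output[0].T           # -> (8400, 6): one row per candidate
row = preds[0]
cx, cy, w, h = row[:4]        # box center and size (input-image scale)
scores = row[4:4 + nc]        # per-class scores
angle = row[4 + nc]           # rotation angle in radians
cls_id = int(np.argmax(scores))
print(cls_id, float(scores[cls_id]), (cx, cy, w, h), float(angle))
```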
```python
import os
import cv2
import math
import random
import numpy as np
import onnxruntime as ort


class RotatedBOX:
    def __init__(self, box, score, class_index):
        self.box = box
        self.score = score
        self.class_index = class_index


class ONNXInfer:
    def __init__(self, onnx_model, class_names, device='auto', conf_thres=0.5, nms_thres=0.4) -> None:
        self.onnx_model = onnx_model
        self.class_names = class_names
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        self.device = self._select_device(device)
        self.session_model = ort.InferenceSession(
            self.onnx_model,
            providers=self.device,
            sess_options=self._get_session_options()
        )

    def _select_device(self, device):
        """
        Select the appropriate device.
        :param device: 'auto', 'cuda', or 'cpu'.
        :return: List of providers.
        """
        if device == 'cuda' or (device == 'auto' and ort.get_device() == 'GPU'):
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session_options(self):
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        sess_options.intra_op_num_threads = 4
        return sess_options

    def preprocess(self, img):
        """
        Preprocess the image for inference.
        :param img: Input image.
        :return: Preprocessed image blob, padded image, original width and height.
        """
        # Pad the image to a square, then resize it to the model input size
        height, width = img.shape[:2]
        length = max(height, width)
        image = np.zeros((length, length, 3), np.uint8)
        image[0:height, 0:width] = img
        input_shape = self.session_model.get_inputs()[0].shape[2:]
        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255, size=tuple(input_shape), swapRB=True)
        return blob, image, width, height

    def predict(self, img):
        """
        Perform inference on the image.
        :param img: Input image.
        :return: Inference results.
        """
        blob, resized_image, orig_width, orig_height = self.preprocess(img)
        inputs = {self.session_model.get_inputs()[0].name: blob}
        outputs = self.session_model.run(None, inputs)
        return self.postprocess(outputs, resized_image, orig_width, orig_height)

    def postprocess(self, outputs, resized_image, orig_width, orig_height):
        """
        Postprocess the model output.
        :param outputs: Model outputs.
        :param resized_image: Padded image used for inference.
        :param orig_width: Original image width.
        :param orig_height: Original image height.
        :return: List of RotatedBOX objects.
        """
        output_data = outputs[0]
        input_shape = self.session_model.get_inputs()[0].shape[2:]
        x_factor = resized_image.shape[1] / float(input_shape[1])
        y_factor = resized_image.shape[0] / float(input_shape[0])

        # (1, 4 + nc + 1, N) -> (N, 4 + nc + 1): one row per candidate
        flattened_output = output_data.flatten()
        reshaped_output = np.reshape(
            flattened_output, (output_data.shape[1], output_data.shape[2])).T

        detected_boxes = []
        confidences = []
        rotated_boxes = []
        num_classes = len(self.class_names)
        for detection in reshaped_output:
            class_scores = detection[4:4 + num_classes]
            class_id = np.argmax(class_scores)
            confidence_score = class_scores[class_id]
            if confidence_score > self.conf_thres:
                cx, cy, width, height = detection[:4] * \
                    [x_factor, y_factor, x_factor, y_factor]
                angle = detection[4 + num_classes]
                # Regularize the angle to avoid ambiguity around pi
                if 0.5 * math.pi <= angle <= 0.75 * math.pi:
                    angle -= math.pi
                box = ((cx, cy), (width, height), angle * 180 / math.pi)
                rotated_box = RotatedBOX(box, confidence_score, class_id)
                # Axis-aligned enclosing rect of the rotated box, for NMS
                detected_boxes.append(cv2.boundingRect(cv2.boxPoints(box)))
                rotated_boxes.append(rotated_box)
                confidences.append(confidence_score)

        nms_indices = cv2.dnn.NMSBoxes(
            detected_boxes, confidences, self.conf_thres, self.nms_thres)
        remain_boxes = [rotated_boxes[i] for i in nms_indices.flatten()]
        return remain_boxes

    def generate_colors(self, num_classes):
        """
        Generate a list of distinct colors for each class.
        :param num_classes: Number of classes.
        :return: List of RGB color tuples.
        """
        colors = []
        for _ in range(num_classes):
            colors.append((random.randint(0, 255), random.randint(0, 255),
                           random.randint(0, 255)))
        return colors

    def drawshow(self, original_image, detected_boxes, class_labels):
        """
        Draw detected bounding boxes and labels on the image and display it.
        :param original_image: The input image on which to draw the boxes.
        :param detected_boxes: List of detected RotatedBOX objects.
        :param class_labels: List of class labels.
        """
        # Generate random colors for each class
        num_classes = len(class_labels)
        colors = self.generate_colors(num_classes)
        for detected_box in detected_boxes:
            box = detected_box.box
            points = cv2.boxPoints(box)
            points = np.intp(points)  # np.int0 was removed in NumPy 2.0
            class_id = detected_box.class_index
            # Draw the bounding box with the color for the class
            color = colors[class_id]
            cv2.polylines(original_image, [points], isClosed=True, color=color, thickness=2)
            # Put the class label text with the same color
            cv2.putText(original_image, class_labels[class_id],
                        (points[0][0], points[0][1]),
                        cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)
        # Display the image with drawn boxes
        cv2.imshow("Detected Objects", original_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def get_image_paths(folder_path, extension=".png", is_use_extension=False):
    image_paths = []
    # Walk the directory tree
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            # Check the file extension
            if file.endswith(extension) or is_use_extension == False:
                image_paths.append(os.path.join(root, file))
    return image_paths


def get_box_list_from_predictions(predictions, confidence_threshold=0.65):
    '''
    predictions[0] looks like:
    {'box': ((395.25, 156.48), (28.98, 60.45), 89.45), 'score': 0.9875654, 'class_index': 0}
    '''
    obb_boxes_list = []
    for prediction in predictions:
        prediction_dict = prediction.__dict__
        prediction_box = prediction_dict['box']
        conf = prediction_dict['score']
        class_index = prediction_dict['class_index']
        if conf > confidence_threshold:
            obb_boxes_list.append([prediction_box, conf, class_index])
    return obb_boxes_list


def draw_rectangles_on_image(image, rectangles, color=(0, 255, 0), thickness=2, show_score=True):
    """
    Draw axis-aligned rectangles on an image.
    """
    result = image.copy()
    for rect in rectangles:
        x, y, w, h, score = rect
        cv2.rectangle(result, (x, y), (x + w, y + h), color, thickness)
        if show_score:
            label = f"{score:.3f}"
            cv2.putText(result, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return result


def draw_rotated_boxes_minimal(image, image_obb_boxes_list):
    """
    Minimal version: draw only the rotated boxes and class IDs.
    """
    image_with_boxes = image.copy()
    for pred in image_obb_boxes_list:
        ((cx, cy), (w, h), angle), confidence, class_id = pred
        # Compute the rotated-rect corner points
        rect = ((cx, cy), (w, h), angle)
        box_points = cv2.boxPoints(rect)
        box_points = np.intp(box_points)
        # Draw the rotated box (green by default)
        cv2.drawContours(image_with_boxes, [box_points], 0, (0, 255, 0), 2)
        # Mark the class ID at the box center
        cv2.putText(image_with_boxes, str(class_id), (int(cx) - 10, int(cy) + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
    return image_with_boxes


def get_obb_onnx_model(model_path=r"model/obb_best.onnx", class_names=["dianzu", "dianrong"]):
    model = ONNXInfer(onnx_model=model_path, class_names=class_names)
    return model


if __name__ == '__main__':
    obb_model = get_obb_onnx_model(model_path=r"model/obb_best.onnx",
                                   class_names=["dianzu", "dianrong"])
    image_path = r'image.jpg'
    detect_image = cv2.imread(image_path, 1)  # image to detect
    predictions = obb_model.predict(detect_image)
    image_obb_boxes_list = get_box_list_from_predictions(predictions, confidence_threshold=0.6)
    print("image_obb_boxes_list", image_obb_boxes_list)
    # Each entry: [((cx, cy), (w, h), angle_deg), confidence, class_index]
    draw_image = draw_rotated_boxes_minimal(detect_image, image_obb_boxes_list)
    cv2.imwrite("./draw_image.png", draw_image)
```
9. Test Results
Here YOLOv8 OBB is used to check whether the resistors on a product are tilted (afterwards, inspecting the angle alone tells whether a part is tilted). The rendered results look like this:
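The angle check itself can be sketched as below. This is my own illustration, not part of the original pipeline: given the angle (in degrees) from the OpenCV-style rotated box ((cx, cy), (w, h), angle) produced by the inference code, a straight component sits near 0 or 90 degrees, so only the deviation from the nearest right angle matters. The 5-degree tolerance is an assumed value:

```python
def is_tilted(angle_deg, tolerance_deg=5.0):
    """Return True if the box angle deviates from the nearest multiple of
    90 degrees by more than tolerance_deg (hypothetical threshold)."""
    deviation = abs(angle_deg) % 90.0
    deviation = min(deviation, 90.0 - deviation)
    return deviation > tolerance_deg

print(is_tilted(89.4))   # False: essentially vertical
print(is_tilted(12.0))   # True: clearly tilted
```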

Summary: this post walked through the full YOLOv8-OBB workflow, from data annotation and configuration to model export and inference, with a real test on NG PCB products. A detailed look at the architecture will be added later.
If anything here is lacking or wrong, feel free to point it out. If you found this useful, please give it a like before you go. Thanks!
