yolov8 obb模型训练与onnx推理

1. 模型简介

   YOLOv8 OBB(Oriented Bounding Box)是 YOLOv8 的旋转边界框版本,专门用于检测旋转/倾斜的物体。本文主要对数据标注,数据集制作,训练,模型导出与推理进行简要说明,关于其结构解析后面补充。

2. 数据标注

  数据标注工具很多,这里列举两个能标注旋转框,适合obb标注的软件。roLabelImg(除了roLabelImg还需要装labelImg)与X-AnyLabeling。我这里使用的是X-AnyLabeling,下载地址如下:

  https://github.com/CVHub520/X-AnyLabeling

image

  然后点击导出obb格式就可以了(这样导出可以一步到位,不用再次转换等操作。导出时候会让你选择一个class Files的txt文件,自己编写一个就好了,每行为英文类别名就可以了)。

image

  会得到一个labels文件夹,里面有每张图对应的txt文件,也就是我们后面训练用到的数据,格式如下:

image

 分别是: 数字类别  x1  y1  x2  y2  x3  y3  x4  y4

  其中x1,y1,x2,y2,x3,y3,x4,y4为旋转框的四个角点坐标(归一化后的),一般是顺时针排列(逆时针应该也可以,这里按照上图的直接导出就可以了)。

3. 制作数据集

  制作数据集的脚本有很多,这里先看看数据集制作好后是什么结构(这很重要)。

image

   展开后如下(这里方便给图,只示范了几张图):

image

 

  首先通过前面的标注,我们已经有了标注数据集(包括txt与对应图片的文件夹),比如images_labels文件夹,内容如下:

企业微信截图_17689998051227

  接下来使用下面的python脚本,制作数据集了:

import os
import random
import shutil

def split_dataset(image_folder_path, output_folder_path, train_ratio=0.8):
    """
    划分图片和标签数据集

    参数:
    image_folder_path: 图片文件夹路径
    output_folder_path: 输出文件夹路径
    train_ratio: 训练集比例,默认0.8
    """

    # 创建输出文件夹结构
    dirs = [
        'images/train',
        'images/val',
        'labels/train',
        'labels/val'
    ]

    for dir_name in dirs:
        dir_path = os.path.join(output_folder_path, dir_name)
        os.makedirs(dir_path, exist_ok=True)

    # 支持的图片格式
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']

    # 获取所有图片文件
    image_files = []
    for file in os.listdir(image_folder_path):
        file_lower = file.lower()
        if any(file_lower.endswith(ext) for ext in image_extensions):
            image_files.append(file)

    # 随机打乱
    random.shuffle(image_files)

    # 计算划分点
    split_point = int(len(image_files) * train_ratio)
    train_images = image_files[:split_point]
    val_images = image_files[split_point:]

    # 复制训练集
    for img_file in train_images:
        # 构建文件路径
        img_path = os.path.join(image_folder_path, img_file)

        # 构建标签文件名(相同的基本名,扩展名改为.txt)
        base_name = os.path.splitext(img_file)[0]
        label_file = base_name + '.txt'
        label_path = os.path.join(image_folder_path, label_file)

        # 复制图片
        dest_img = os.path.join(output_folder_path, 'images/train', img_file)
        shutil.copy2(img_path, dest_img)

        # 复制标签(如果存在)
        if os.path.exists(label_path):
            dest_label = os.path.join(output_folder_path, 'labels/train', label_file)
            shutil.copy2(label_path, dest_label)
        else:
            print(f"警告: {label_file} 不存在")

    # 复制验证集
    for img_file in val_images:
        # 构建文件路径
        img_path = os.path.join(image_folder_path, img_file)

        # 构建标签文件名
        base_name = os.path.splitext(img_file)[0]
        label_file = base_name + '.txt'
        label_path = os.path.join(image_folder_path, label_file)

        # 复制图片
        dest_img = os.path.join(output_folder_path, 'images/val', img_file)
        shutil.copy2(img_path, dest_img)

        # 复制标签(如果存在)
        if os.path.exists(label_path):
            dest_label = os.path.join(output_folder_path, 'labels/val', label_file)
            shutil.copy2(label_path, dest_label)
        else:
            print(f"警告: {label_file} 不存在")

    print(f"数据集划分完成!")
    print(f"训练集: {len(train_images)} 张图片")
    print(f"验证集: {len(val_images)} 张图片")
    print(f"输出路径: {output_folder_path}")

# 使用示例
if __name__ == "__main__":
    # 示例路径
    image_folder = r"images_labels"  #存放image与txt的文件夹
    output_folder = r"train" #数据集保存的文件夹路径

    # 调用函数
    split_dataset(image_folder, output_folder)

  运行完,就能得到前面的那个结构数据集了(训练,验证的比例可以根据自己需要修改参数train_ratio)。至此,训练集数据划分完毕了。

4. 配置文件

  我们还需要yolov8-obb.yaml这个文件,如下:

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Ultralytics YOLOv8-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolov8
# Task docs: https://docs.ultralytics.com/tasks/obb

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n-obb summary: 144 layers, 3228867 parameters, 3228851 gradients, 9.1 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s-obb summary: 144 layers, 11452739 parameters, 11452723 gradients, 29.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m-obb summary: 184 layers, 26463235 parameters, 26463219 gradients, 81.5 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l-obb summary: 224 layers, 44540355 parameters, 44540339 gradients, 169.4 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x-obb summary: 224 layers, 69555651 parameters, 69555635 gradients, 264.3 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)

  - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5)
除了上面那个文件,我们还需要配置类别与数据路径的文件detect_obb.yaml,如下(路径与类别根据实际需要配置):
# Ultralytics YOLO 🚀, AGPL-3.0 license
# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University
# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/
# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── dota1  ← downloads here (2GB)

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./data/defect_obb/defect_obb  # dataset root dir
train: images/train  # train images (relative to 'path') 1411 images
val: images/val  # val images (relative to 'path') 458 images
#test: images/test  # test images (optional) 937 images

# Classes for DOTA 1.0
names:
  0: defect

5. 预训练权重

  去到这个地址https://docs.ultralytics.com/models/yolov8/#models去下载就好了,n到x,模型逐渐变大,根据任务需求选择。

image

   也可以到网盘下载:  链接: https://pan.baidu.com/s/1kHgiXYCOC8w4YBcND9aTiQ 提取码: a8vc

  至此,数据集与文件都准备完毕了。

 6. 训练

  训练之前,需要下载ultralytics 包或者下载源码也行。

pip install  ultralytics

  或者 https://github.com/ultralytics/ultralytics

  然后使用如下代码训练:

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-obb.yaml')
    model.load('model/yolov8n-obb.pt') # loading pretrain weights
    model.train(data='data/defect_obb/defect_obb.yaml',
                cache=False,
                imgsz=640,
                epochs=50,
                batch=2,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                project='runs/train',
                name='exp',
                )

  上面的参数这里就不再赘述了。

7. 导出ONNX

  训练完后,我们可以导出模型.pt格式为onnx格式

from ultralytics import YOLO

model = YOLO("path/to/best.pt")  # load a custom-trained model
# Export the model
model.export(format="onnx")

8. ONNX推理

  这里给一个网上的且结合自己修改后的版本。

import os

import cv2
import math
import random
import numpy as np
import onnxruntime as ort


class RotatedBOX:
    def __init__(self, box, score, class_index):
        self.box = box
        self.score = score
        self.class_index = class_index

class ONNXInfer:
    def __init__(self, onnx_model, class_names, device='auto', conf_thres=0.5, nms_thres=0.4) -> None:
        self.onnx_model = onnx_model
        self.class_names = class_names
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        self.device = self._select_device(device)

        # logger.info(f"Loading model on {self.device}...")
        self.session_model = ort.InferenceSession(
            self.onnx_model,
            providers=self.device,
            sess_options=self._get_session_options()
        )

    def _select_device(self, device):
        """
        Select the appropriate device.
        :param device: 'auto', 'cuda', or 'cpu'.
        :return: List of providers.
        """
        if device == 'cuda' or (device == 'auto' and ort.get_device() == 'GPU'):
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session_options(self):
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        sess_options.intra_op_num_threads = 4
        return sess_options

    def preprocess(self, img):
        """
        Preprocess the image for inference.
        :param img: Input image.
        :return: Preprocessed image blob, original image width, and original image height.
        """
        # logger.info(
        #     "Preprocessing input image to [1, channels, input_w, input_h] format")
        height, width = img.shape[:2]
        length = max(height, width)
        image = np.zeros((length, length, 3), np.uint8)
        image[0:height, 0:width] = img

        input_shape = self.session_model.get_inputs()[0].shape[2:]
        # logger.debug(f"Input shape: {input_shape}")

        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255, size=tuple(input_shape), swapRB=True)
        # logger.info(f"Preprocessed image blob shape: {blob.shape}")

        return blob, image, width, height

    def predict(self, img):
        """
        Perform inference on the image.
        :param img: Input image.
        :return: Inference results.
        """
        blob, resized_image, orig_width, orig_height = self.preprocess(img)
        inputs = {self.session_model.get_inputs()[0].name: blob}
        try:
            outputs = self.session_model.run(None, inputs)
        except Exception as e:
            # logger.error(f"Inference failed: {e}")
            raise
        return self.postprocess(outputs, resized_image, orig_width, orig_height)

    def postprocess(self, outputs, resized_image, orig_width, orig_height):
        """
        Postprocess the model output.
        :param outputs: Model outputs.
        :param resized_image: Resized image used for inference.
        :param orig_width: Original image width.
        :param orig_height: Original image height.
        :return: List of RotatedBOX objects.
        """
        output_data = outputs[0]
        # logger.info(
        #     f"Postprocessing output data with shape: {output_data.shape}")

        input_shape = self.session_model.get_inputs()[0].shape[2:]
        x_factor = resized_image.shape[1] / float(input_shape[1])
        y_factor = resized_image.shape[0] / float(input_shape[0])

        flattened_output = output_data.flatten()
        reshaped_output = np.reshape(
            flattened_output, (output_data.shape[1], output_data.shape[2])).T

        detected_boxes = []
        confidences = []
        rotated_boxes = []

        num_classes = len(self.class_names)

        for detection in reshaped_output:
            class_scores = detection[4:4 + num_classes]
            class_id = np.argmax(class_scores)
            confidence_score = class_scores[class_id]

            if confidence_score > self.conf_thres:
                cx, cy, width, height = detection[:4] * \
                    [x_factor, y_factor, x_factor, y_factor]
                angle = detection[4 + num_classes]

                if 0.5 * math.pi <= angle <= 0.75 * math.pi:
                    angle -= math.pi

                box = ((cx, cy), (width, height), angle * 180 / math.pi)
                rotated_box = RotatedBOX(box, confidence_score, class_id)

                detected_boxes.append(cv2.boundingRect(cv2.boxPoints(box)))
                rotated_boxes.append(rotated_box)
                confidences.append(confidence_score)

        nms_indices = cv2.dnn.NMSBoxes(
            detected_boxes, confidences, self.conf_thres, self.nms_thres)
        remain_boxes = [rotated_boxes[i] for i in nms_indices.flatten()]

        # logger.info(f"Detected {len(remain_boxes)} objects after NMS")
        return remain_boxes

    def generate_colors(self, num_classes):
        """
        Generate a list of distinct colors for each class.

        :param num_classes: Number of classes.
        :return: List of RGB color tuples.
        """
        colors = []
        for _ in range(num_classes):
            colors.append((random.randint(0, 255), random.randint(
                0, 255), random.randint(0, 255)))
        return colors

    def drawshow(self, original_image, detected_boxes, class_labels):
        """
        Draw detected bounding boxes and labels on the image and display it.

        :param original_image: The input image on which to draw the boxes.
        :param detected_boxes: List of detected RotatedBOX objects.
        :param class_labels: List of class labels.
        """
        # Generate random colors for each class
        num_classes = len(class_labels)
        colors = self.generate_colors(num_classes)

        for detected_box in detected_boxes:
            box = detected_box.box
            points = cv2.boxPoints(box)

            # Rescale the points back to the original image dimensions
            points[:, 0] = points[:, 0]
            points[:, 1] = points[:, 1]
            points = np.int0(points)

            class_id = detected_box.class_index

            # Draw the bounding box with the color for the class
            color = colors[class_id]
            cv2.polylines(original_image, [points],
                          isClosed=True, color=color, thickness=2)
            # Put the class label text with the same color
            cv2.putText(original_image, class_labels[class_id], (points[0][0], points[0][1]),
                        cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)

        # Display the image with drawn boxes
        cv2.imshow("Detected Objects", original_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def get_image_paths(folder_path, extension=".png", is_use_extension=False):
    image_paths = []

    # 遍历目录
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            # 检查文件扩展名
            if file.endswith(extension) or is_use_extension == False:
                # 构造完整的文件路径并添加到列表
                image_path = os.path.join(root, file)
                image_paths.append(image_path)

    return image_paths

#
def get_box_list_from_predictions(predictions,confidence_threshold=0.65):
    '''
    predictions[0] {'box': ((395.2504436016083, 156.47752561569214), (28.97803795337677, 60.449385666847235), 89.45477582379273), 'score': 0.9875654, 'class_index': 0}
    :param predictions:
    :return:
    '''
    obb_boxes_list = []
    for prediction in predictions:
        prediction_dict=prediction.__dict__
        prediction_box=prediction_dict['box']
        conf=prediction_dict['score']
        class_index=prediction_dict['class_index']
        if conf>confidence_threshold:
            obb_boxes_list.append([prediction_box,conf,class_index])

    return obb_boxes_list



def draw_rectangles_on_image(image, rectangles, color=(0, 255, 0), thickness=2, show_score=True):
    """
    在图像上绘制矩形
    """
    result = image.copy()

    for rect in rectangles:
        x, y, w, h, score = rect

        # 绘制矩形
        cv2.rectangle(result, (x, y), (x + w, y + h), color, thickness)

        # 显示分数
        if show_score:
            label = f"{score:.3f}"
            cv2.putText(result, label, (x, y - 5),cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    return result




def draw_rotated_boxes_minimal(image, image_obb_boxes_list):
    """
    最简版本:只绘制旋转矩形框和类别ID
    """
    image_with_boxes = image.copy()

    for pred in image_obb_boxes_list:
        ((cx, cy), (w, h), angle), confidence, class_id = pred

        # 计算旋转矩形顶点
        rect = ((cx, cy), (w, h), angle)
        box_points = cv2.boxPoints(rect)
        # box_points = np.int0(box_points)
        box_points = np.intp(box_points)

        # 绘制旋转矩形框(默认绿色)
        cv2.drawContours(image_with_boxes, [box_points], 0, (0, 255, 0), 2)

        # 在中心点标记类别ID
        cv2.putText(image_with_boxes, str(class_id),(int(cx) - 10, int(cy) + 5),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    return image_with_boxes


def get_obb_onnx_model(model_path = r"model/obb_best.onnx",class_names = ["dianzu","dianrong"]):
    model = ONNXInfer(onnx_model=model_path, class_names=class_names)
    return model



if __name__ == '__main__':
    obb_model=get_obb_onnx_model(model_path = r"model/obb_best.onnx",class_names = ["dianzu","dianrong"])

    image_path= r'image.jpg'
    detect_image=cv2.imread(image_path,1) #待检测图
    predictions = obb_model.predict(detect_image)
    image_obb_boxes_list = get_box_list_from_predictions(predictions,confidence_threshold=0.6)
    print("image_obb_boxes_list",image_obb_boxes_list)
    #[[((np.float64(401.2665241241455), np.float64(150.62496876716614)), (np.float64(27.15373088121414), np.float64(60.12933983802795)), np.float32(89.404434)), np.float32(0.9769711), np.int64(0)], [((np.float64(280.50368547439575), np.float64(85.80472011566162)), (np.float64(46.974865317344666), np.float64(97.59349856376647)), np.float32(89.835884)), np.float32(0.97330356), np.int64(0)], [((np.float64(313.74962425231934), np.float64(194.6945291519165)), (np.float64(21.952924823760984), np.float64(28.783482360839844)), np.float32(0.46413484)), np.float32(0.9403417), np.int64(1)], [((np.float64(431.709785079956), np.float64(261.3335613250732)), (np.float64(23.31856728196144), np.float64(30.72814937829971)), np.float32(-0.40711102)), np.float32(0.92906225), np.int64(1)], [((np.float64(164.32013969421385), np.float64(132.7377662181854)), (np.float64(72.90492920875549), np.float64(53.1005439043045)), np.float32(0.57003915)), np.float32(0.9269689), np.int64(0)], [((np.float64(254.18212633132933), np.float64(194.80640244483948)), (np.float64(23.747826969623564), np.float64(29.390544176101685)), np.float32(1.413964)), np.float32(0.91280127), np.int64(1)]]
    #[[(中心坐标x,中心坐标y 宽,高, 角度,置信度,数字类别,...
    draw_image=draw_rotated_boxes_minimal(detect_image, image_obb_boxes_list)
    cv2.imwrite("./draw_image.png",draw_image)

9. 测试效果

  这里通过yolov8 obb实现对产品中的电阻是否倾斜进行检测(后面只需要判断角度就能知道是否倾斜了)。效果绘制图如下:

29

 

  小结:本文主要对yolov8-obb进行了从数据标注到配置,再到模型导出与推理的完整流程,并对PCB板的NG产品进行了实际测试。关于结构会在后面详细补充。

 

 

  若存在不足或错误之处,欢迎指出,觉得有用请点个赞再走,谢谢!

  

posted @ 2026-01-21 21:44  wancy  阅读(216)  评论(0)    收藏  举报