Fixing "RuntimeError: Only tuples, lists and Variables are supported...Here, received an input of unsupported type: DetDataSample" when exporting an mmdet model with mmdeploy

1. Environment

  • System:

  mmdeploy-mmdet:ubuntu20.04-cuda11.8-mmdeploy1.3.1 Docker image.

  • Software:

  mmdetection 3.3.0

  mmdeploy 1.3.1

  • Model being converted:

  Co-DETR, using the co_dino_5scale_r50_8xb2_1x_coco.py model config.

2. Problem description

Converting the model with the official Python API:

from mmdeploy.apis import torch2onnx
from mmdeploy.backend.sdk.export_info import export2SDK

img = 'demo/xxxx.jpg'
work_dir = 'work_dir/xxxx/onnx'
save_file = 'best_coco_bbox_mAP_epoch_12.onnx'
deploy_cfg = '../mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py'
model_cfg = 'work_dir/xxxx/co_dino_5scale_r50_8xb2_1x_coco.py'
model_checkpoint = 'work_dir/xxxx/best_coco_bbox_mAP_epoch_12.pth'
device = 'cpu'

# 1. convert model to onnx
print("Converting model to ONNX...")
torch2onnx(img, work_dir, save_file, deploy_cfg, model_cfg,
           model_checkpoint, device)
print("Finished converting model to ONNX.")
# 2. extract pipeline info for inference by MMDeploy SDK
print("Exporting MMDeploy SDK info...")
export2SDK(deploy_cfg, model_cfg, work_dir, pth=model_checkpoint,
           device=device)
print("Finished exporting MMDeploy SDK info.")

This fails with:

....

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: DetDataSample

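3. Solution

The error comes from how torch.onnx.export works: it runs one forward pass, traces the PyTorch operator calls, maps the PyTorch computation graph to an ONNX operator graph, and serializes the result as a static ONNX model. During tracing, the model's inputs and outputs have to be tensors, or tuples/lists of tensors. Any other object, such as a DetDataSample, cannot be flattened by the tracer, which raises the RuntimeError above.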
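The restriction on traced outputs can be reproduced in isolation. The sketch below uses torch.jit.trace as a stand-in for the tracing step of the export; the Box class and both modules are hypothetical names made up for illustration, and the exact exception message may differ from the one mmdeploy prints.

```python
import torch
from torch import nn


class Box:
    """Hypothetical stand-in for a custom container such as DetDataSample."""

    def __init__(self, value: torch.Tensor) -> None:
        self.value = value


class ReturnsObject(nn.Module):
    def forward(self, x: torch.Tensor):
        return Box(x * 2)  # custom object: the tracer cannot handle it


class ReturnsTuple(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * 2, x + 1  # tuple of tensors: traceable


def can_trace(module: nn.Module, example: torch.Tensor) -> bool:
    """Return True if torch.jit.trace accepts the module's outputs."""
    try:
        torch.jit.trace(module, example)
        return True
    except Exception as exc:  # typically a RuntimeError about unsupported outputs
        print(f"{type(module).__name__}: {type(exc).__name__}")
        return False


example = torch.rand(1, 3)
object_ok = can_trace(ReturnsObject(), example)
tuple_ok = can_trace(ReturnsTuple(), example)
print(object_ok, tuple_ok)
```

Only the module that returns plain tensors can be traced, which is why the fix has to unwrap the predictions before they leave the model.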

The Co-DETR source file projects/CO-DETR/codetr/codetr.py shows that the forward inference path is:

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input images.
            Each DetDataSample usually contain 'pred_instances'. And the
            `pred_instances` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert self.eval_module in ['detr', 'one-stage', 'two-stage']

        if self.use_lsj:
            for data_samples in batch_data_samples:
                img_metas = data_samples.metainfo
                input_img_h, input_img_w = img_metas['batch_input_shape']
                img_metas['img_shape'] = [input_img_h, input_img_w]

        img_feats = self.extract_feat(batch_inputs)
        if self.with_bbox and self.eval_module == 'one-stage':
            results_list = self.predict_bbox_head(
                img_feats, batch_data_samples, rescale=rescale)
        elif self.with_roi_head and self.eval_module == 'two-stage':
            results_list = self.predict_roi_head(
                img_feats, batch_data_samples, rescale=rescale)
        else:
            results_list = self.predict_query_head(
                img_feats, batch_data_samples, rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

The return type is list[:obj:`DetDataSample`], which is not a supported type. Add the following handling before the statement "batch_data_samples = self.add_pred_to_datasample....":

    # During ONNX export, return plain tensors instead of DetDataSample objects.
    if torch.onnx.is_in_onnx_export():
        det_bboxes = []
        det_labels = []
        for result in results_list:
            # Each dets row is (x1, y1, x2, y2, score).
            det_bboxes.append(
                torch.cat([result['bboxes'], result['scores'].view(-1, 1)], dim=1))
            det_labels.append(result['labels'])
        return det_bboxes, det_labels

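This matches the model output configuration in our config file:

onnx_config = dict(output_names=['dets', 'labels'], input_shape=None)

The dets output contains both the box coordinates and the confidence score, so the bboxes and scores in results_list have to be concatenated.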
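To make the shapes concrete, here is a minimal sketch of that concatenation on made-up values. The numbers are purely illustrative; in the real model the tensors come from results_list.

```python
import torch

# Made-up values for two detections; real ones come from results_list.
bboxes = torch.tensor([[10.0, 20.0, 110.0, 220.0],
                       [30.0, 40.0, 90.0, 160.0]])  # (num_instances, 4)
scores = torch.tensor([0.9, 0.6])                   # (num_instances,)

# Turn scores into a column and append it to the box coordinates.
dets = torch.cat([bboxes, scores.view(-1, 1)], dim=1)
print(dets.shape)  # torch.Size([2, 5])
```

Each row of dets is then (x1, y1, x2, y2, score).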

(End)

posted @ 2025-12-16 11:15  大师兄啊哈