How to convert MMRotate detection results to DOTA and FAIR1M format

Background: while using mmrotate, I needed to export the inferred detection results, and found that this functionality is missing:
From the demo i know show_result_pyplot can plot the inferred results, I would like to ask how to convert inferred results to DOTA format, is there a related function? Or do you need to handle the result directly?
thanks!
Hi @jsxyhelu, we did not provide the corresponding script to the user. You need to convert the results to DOTA format by yourself. Welcome to submit your script to help more people.
 
So let's design and implement this ourselves:
 
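1. Data format
The output format of mmrotate is:

Each row is: x, y, w, h, theta, score.

The target format is DOTA, which stores annotations in txt files,

where one annotation box is written as: x1, y1, x2, y2, x3, y3, x4, y4, classname, difficult. Note that the coordinates are not normalized.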
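To make the target layout concrete, here is a minimal sketch of writing and reading one DOTA annotation line. The helper names `format_dota_line` and `parse_dota_line` are hypothetical (they are not part of mmrotate or DOTA_devkit), and rounding to one decimal place is an assumption made only for readability.

```python
def format_dota_line(points, classname, difficult=0):
    """Build one DOTA line from 8 absolute pixel coordinates.

    points: x1, y1, x2, y2, x3, y3, x4, y4 (not normalized).
    """
    if len(points) != 8:
        raise ValueError("expected 8 coordinates (4 corner points)")
    coords = " ".join(f"{p:.1f}" for p in points)
    return f"{coords} {classname} {difficult}"


def parse_dota_line(line):
    """Split one DOTA line back into (points, classname, difficult)."""
    parts = line.split()
    points = [float(v) for v in parts[:8]]
    return points, parts[8], int(parts[9])


line = format_dota_line([0, 0, 10, 0, 10, 5, 0, 5], "ship", 0)
print(line)
print(parse_dota_line(line))
```

Fields are separated by single spaces, so the class name must be followed by a space before the difficult flag.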

2. Batch processing and saving

First, save the result:

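import os
import mmcv
import numpy as np
from mmdet.apis import inference_detector  # same import as in the mmrotate demo
 
src_label_root = '/root/mmrotate/demo/ssdd_tiny/images/'  # input images
dst_label_root = '/root/mmrotate/demo/ssdd_tiny/dst/'     # raw results, one txt per image
os.makedirs(dst_label_root, exist_ok=True)
 
 
# model and cfg are assumed to be defined by the earlier setup cells
model.cfg = cfg
for i, src_label_name in enumerate(os.listdir(src_label_root)):
    src_label_path = os.path.join(src_label_root, src_label_name)  # input path
    dst_label_path = os.path.join(dst_label_root, os.path.splitext(src_label_name)[0] + ".txt")
    img = mmcv.imread(src_label_path)
    result = inference_detector(model, img)
    # result[0] holds class 0 only: rows of x, y, w, h, theta, score
    np.savetxt(dst_label_path, result[0], delimiter=',')
    print(dst_label_path)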

Then convert the format. For the single-channel images used here (where only one class is handled), the code is:

import math

def rota(x, y, w, h, a):  # center x, center y, box width, box height, rotation angle in radians
    center_x1 = x
    center_y1 = y
    x1, y1 = x - w / 2, y - h / 2  # top-left before rotation
    x2, y2 = x + w / 2, y - h / 2  # top-right before rotation
    x3, y3 = x + w / 2, y + h / 2  # bottom-right before rotation
    x4, y4 = x - w / 2, y + h / 2  # bottom-left before rotation
    px1 = (x1 - center_x1) * math.cos(a) - (y1 - center_y1) * math.sin(a) + center_x1  # top-left after rotation
    py1 = (x1 - center_x1) * math.sin(a) + (y1 - center_y1) * math.cos(a) + center_y1
    px2 = (x2 - center_x1) * math.cos(a) - (y2 - center_y1) * math.sin(a) + center_x1  # top-right after rotation
    py2 = (x2 - center_x1) * math.sin(a) + (y2 - center_y1) * math.cos(a) + center_y1
    px3 = (x3 - center_x1) * math.cos(a) - (y3 - center_y1) * math.sin(a) + center_x1  # bottom-right after rotation
    py3 = (x3 - center_x1) * math.sin(a) + (y3 - center_y1) * math.cos(a) + center_y1
    px4 = (x4 - center_x1) * math.cos(a) - (y4 - center_y1) * math.sin(a) + center_x1  # bottom-left after rotation
    py4 = (x4 - center_x1) * math.sin(a) + (y4 - center_y1) * math.cos(a) + center_y1
 
    return px1, py1, px2, py2, px3, py3, px4, py4  # the four rotated points: top-left, top-right, bottom-right, bottom-left

def mmrotate2dota(src_img_root, src_label_root, dst_label_root, class_map, score_thr=0.3):
    # src_img_root is kept in the signature but is not used here
    os.makedirs(dst_label_root, exist_ok=True)
    # iterate over all txt files
    for i, src_label_name in enumerate(os.listdir(src_label_root)):
        src_label_path = os.path.join(src_label_root, src_label_name)  # input path
        dst_label_path = os.path.join(dst_label_root, src_label_name)  # output path
        dst_label_list = []
        with open(src_label_path, 'r') as fr:
            txtlines = fr.readlines()  # raw data
        for line in txtlines:
            oneline = line.strip().split(",")
            x = float(oneline[0])
            y = float(oneline[1])
            w = float(oneline[2])
            h = float(oneline[3])
            a = float(oneline[4])
            score = float(oneline[5])
            px1, py1, px2, py2, px3, py3, px4, py4 = rota(x, y, w, h, a)
            # target format: x1 y1 x2 y2 x3 y3 x4 y4 classname difficult
            fields = [px1, py1, px2, py2, px3, py3, px4, py4, class_map['0'], "1"]
            dstline = " ".join(str(v) for v in fields)
            if score >= score_thr:
                dst_label_list.append(dstline)
        with open(dst_label_path, 'w') as fw:
            fw.writelines([line + '\n' for line in dst_label_list])  # add newlines
        print(dst_label_path)
    print('convert done')

This gives an initial comparison result, which looks correct on visual inspection.
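Besides visual inspection, the rotation can also be checked numerically. The sketch below is a compact restatement of the same corner rotation (the name `obb_to_poly` is hypothetical), assuming theta is in radians. It verifies that the polygon keeps the box's side lengths and center for an arbitrary angle.

```python
import math


def obb_to_poly(cx, cy, w, h, theta):
    """Rotate the four corners of a (cx, cy, w, h) box by theta around its center."""
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    # corner offsets: top-left, top-right, bottom-right, bottom-left
    offsets = [(-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)]
    poly = []
    for dx, dy in offsets:
        poly.append(cx + dx * cos_t - dy * sin_t)
        poly.append(cy + dx * sin_t + dy * cos_t)
    return poly


def side_lengths(poly):
    """Lengths of the four polygon edges, in corner order."""
    pts = list(zip(poly[0::2], poly[1::2]))
    return [math.dist(pts[i], pts[(i + 1) % 4]) for i in range(4)]


poly = obb_to_poly(100.0, 50.0, 40.0, 20.0, 0.7)
sides = side_lengths(poly)
center = (sum(poly[0::2]) / 4, sum(poly[1::2]) / 4)
print(sides, center)
```

The edges should alternate between w and h, and the mean of the corners should equal the original center. If either fails, the angle unit or the field order is the first thing to check.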

 

The boxes are plotted with DOTA's own tool (DOTA_devkit).

 

For details, see https://www.kaggle.com/code/jsxyhelu2019/ddd-mmrotate-result2dota

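3. Obtaining batch results

The processing so far handles only a single class; batch data with several classes has to be handled differently.

The conversion above also contains errors that need to be fixed.

By imitating the existing examples, we can load an existing pt checkpoint and obtain the inference results.

The result is organized like this: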

 

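There are 37 arrays in total, one per class, each holding the predicted positions for that class.

So when writing the results out, the class has to be encoded as well.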
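One way to encode the class is to map each array's position in the result list to a class name while writing DOTA lines. The following is a minimal sketch, not the code used in this post. `CLASS_NAMES` is a hypothetical placeholder list, `result_to_dota_lines` and `obb_to_poly` are hypothetical helpers, theta is assumed to be in radians, and the default difficult flag of 0 is an assumption.

```python
import math

# Hypothetical class list; replace it with the classes the model was trained on.
CLASS_NAMES = ["plane", "ship", "vehicle"]


def obb_to_poly(cx, cy, w, h, theta):
    """Rotate the four box corners by theta around the box center."""
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    offsets = [(-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)]
    poly = []
    for dx, dy in offsets:
        poly.append(cx + dx * cos_t - dy * sin_t)
        poly.append(cy + dx * sin_t + dy * cos_t)
    return poly


def result_to_dota_lines(result, class_names, score_thr=0.3, difficult=0):
    """Turn a per-class list of (x, y, w, h, theta, score) rows into DOTA lines."""
    lines = []
    for class_index, class_dets in enumerate(result):
        for det in class_dets:  # works for nested lists and for numpy arrays
            cx, cy, w, h, theta, score = (float(v) for v in det)
            if score < score_thr:
                continue
            coords = " ".join(f"{p:.1f}" for p in obb_to_poly(cx, cy, w, h, theta))
            lines.append(f"{coords} {class_names[class_index]} {difficult}")
    return lines


# A fake three-class result: only class 1 has detections, one of them below threshold.
fake_result = [
    [],
    [[10.0, 10.0, 4.0, 2.0, 0.0, 0.9], [5.0, 5.0, 2.0, 2.0, 0.0, 0.1]],
    [],
]
lines = result_to_dota_lines(fake_result, CLASS_NAMES)
print(lines)
```

Keeping the class as a list index until the final write means the class list can be swapped without touching the stored detections.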

In addition, inference has to use the patch-based API:

from mmrotate.apis import inference_detector_by_patches
img = 'demo/dota_demo.jpg'
result = inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
 
def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """inference patches with the detector.
    Split huge image(s) into patches and inference them with the detector.
    Finally, merge patch results on one huge image by nms.
    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray or): Either an image file or loaded image.
        sizes (list): The sizes of patches.
        steps (list): The steps between two patches.
        ratios (list): Image resizing ratios for multi-scale detecting.
        merge_iou_thr (float): IoU threshold for merging results.
        bs (int): Batch size, must greater than or equal to 1.
    Returns:
        list[np.ndarray]: Detection results.
    """

So finally, for a single image:

# Use the detector to do inference
dst = []
from mmrotate.apis import inference_detector_by_patches
img = '/home/helu/workstation/Fair1m/fair1M_jpg_train_split_1280_200/images/1__1__0___0.jpg'
result =  inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
for index,typeresult in enumerate(result):
    if(typeresult.size!=0):
        for lineresult in typeresult:
            lineresult = np.append(lineresult,  np.float32(index))
            dst.append(lineresult)
            #print(index)
print(dst)
#show_result_pyplot(model, img, result, score_thr=0.3)

Batch processing to obtain the DOTA results:

 

 

import os
import mmcv
import numpy as np
from mmrotate.apis import inference_detector_by_patches

test_image_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/images/'
test_result_root = '/home/helu/workstation/Fair1m/fair1M_jpg_test_tiny/labelTxt/'
os.makedirs(test_result_root, exist_ok=True)
for i, test_image_name in enumerate(os.listdir(test_image_root)):
    dst = []  # all detections of the current image
    test_image_path = os.path.join(test_image_root, test_image_name)  # input path
    dst_label_path = os.path.join(test_result_root, os.path.splitext(test_image_name)[0] + ".txt")
    img = mmcv.imread(test_image_path)
    result = inference_detector_by_patches(model, img, [1024], [824], [1.0], 0.1)
    for index, typeresult in enumerate(result):
        if typeresult.size != 0:
            for lineresult in typeresult:
                # append the class index as a 7th column
                lineresult = np.append(lineresult, np.float32(index))
                dst.append(lineresult)
    # write once per image, after all classes have been collected
    np.savetxt(dst_label_path, dst, delimiter=',')
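The files written by this loop hold seven comma-separated values per row: x, y, w, h, theta, score, and the class index appended above. As a hedged sketch of the next step, the hypothetical helper `group_rows_by_class` below reads such rows back and groups them by class. The score threshold of 0.3 is an assumption carried over from the earlier conversion function.

```python
from collections import defaultdict


def group_rows_by_class(lines, score_thr=0.3):
    """Group saved detections by class index.

    Each line is assumed to hold 7 comma-separated floats:
    x, y, w, h, theta, score, class_index.
    """
    grouped = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. from images without detections
        x, y, w, h, theta, score, cls = (float(v) for v in line.split(","))
        if score >= score_thr:
            grouped[int(cls)].append((x, y, w, h, theta, score))
    return dict(grouped)


# Sample rows in the scientific notation that np.savetxt writes by default.
sample = [
    "1.0e+02,2.0e+02,4.0e+01,2.0e+01,0.0e+00,9.0e-01,3.0e+00",
    "5.0e+01,6.0e+01,1.0e+01,1.0e+01,5.0e-01,1.0e-01,3.0e+00",
    "7.0e+01,8.0e+01,3.0e+01,1.5e+01,1.0e-01,8.0e-01,0.0e+00",
]
grouped = group_rows_by_class(sample)
print(grouped)
```

Grouping by class index first makes it straightforward to write either per-image label files or per-class result files afterwards.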

 

 

The full code is at https://files.cnblogs.com/files/blogs/758212/MMRotat_infer.rar?t=1682895717&download=true

 

 

 

This still needs further modification; alternatively, converting the data would also work.

 

After converting to the FAIR1M data format and submitting for scoring, 30 epochs of training gave this score:

 

posted on 2023-05-01 07:04  jsxyhelu
