Tinkering Notes [57]: Batch-Synthesizing Character-Level Segmentation Annotation Data

Abstract

Using LaTeX (precise typesetting), OpenCV (connected-component detection), and Pillow (data perturbation), batch-synthesize character-level segmentation data in the x-anylabeling format, with per-character bounding-box annotations, for training a character-level segmentation model.

Disclaimer

A human is the first author of this post; the lobster is the corresponding author. This post contains AI-generated content.

Preface

In OCR (optical character recognition) tasks, character-level segmentation is a key preprocessing step. Traditional connected-component methods have inherent weaknesses when handling touching characters and broken strokes. A ViT model, through global self-attention, can use contextual semantic information to decide which character each pixel belongs to, and can in principle overcome those limitations. However, training a ViT model requires large amounts of precisely pixel-level annotated data. Manual annotation is expensive, so synthetic data generation is an important way to supplement the training set.

Background

About Tectonic

[https://erasin.wang/latex-quick/]
[https://tectonic-typesetting.github.io/]
Tectonic is a more modern TeX engine written in Rust, driven internally by XeTeX and TeX Live. It downloads the support files it needs automatically, so you no longer have to manage resource files by hand, and compiling a document is very simple.
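The scripts below shell out to Tectonic in exactly this spirit. A minimal invocation sketch (assuming `tectonic` is on PATH; the helper names are illustrative, not part of the project scripts):

```python
import subprocess

def tectonic_cmd(tex_path: str) -> list:
    """Build the Tectonic invocation; the PDF lands next to the .tex file."""
    return ["tectonic", tex_path]

def compile_tex(tex_path: str) -> None:
    """Run Tectonic and surface its stderr on failure."""
    result = subprocess.run(tectonic_cmd(tex_path), capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"LaTeX compilation failed: {result.stderr}")
```

On the first run Tectonic fetches any missing packages into its cache, so no TeX Live installation is needed.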

Character Connected-Component Segmentation Algorithm

  1. Grayscale conversion + binarization with THRESH_BINARY_INV to extract the black text
  2. Morphological operations: opening to remove noise, closing to fill holes (3×3 rectangular kernel)
  3. Connected-component analysis with connectedComponentsWithStats (8-connectivity)
  4. Merge connected components to handle multi-stroke characters (e.g. "轴", "距")
  5. Coordinate conversion: local bbox + x_offset → global bbox
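Steps 4 and 5 above can be sketched as a small helper (a sketch only; the `stats` rows mimic the (x, y, w, h, area) layout returned by `connectedComponentsWithStats`, with the background row already removed):

```python
def merge_bboxes(stats, x_offset, min_area=5):
    """Merge per-component boxes (rows of x, y, w, h, area) into one bbox,
    skip tiny noise blobs, and shift the result from cell-local to
    global image coordinates via x_offset."""
    boxes = [(x, y, w, h) for x, y, w, h, area in stats if area >= min_area]
    if not boxes:
        return (0, 0, 0, 0)
    x_min = min(b[0] for b in boxes)
    y_min = min(b[1] for b in boxes)
    x_max = max(b[0] + b[2] for b in boxes)
    y_max = max(b[1] + b[3] for b in boxes)
    return (x_offset + x_min, y_min, x_max - x_min, y_max - y_min)

# Two strokes of one character inside a cell that starts at x_offset=80:
print(merge_bboxes([(5, 10, 20, 40, 300), (30, 8, 25, 45, 350)], 80))
# → (85, 8, 50, 45)
```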

Project

Overview

  • Precise LaTeX typesetting: solves the PIL baseline-alignment problem; full-width and half-width characters mix cleanly
  • Tianzige (田字格) / rizige (日字格) layout: Chinese characters fill the whole cell; punctuation occupies a designated quadrant
  • Precise OpenCV connected-component bboxes: follows the drawMissingRoiOpenCvBlob.cs algorithm to tightly enclose each character's black pixels
  • x-anylabeling format: self-contained imageData, no imagePath needed
  • rotation bounding boxes: rotated-box annotation supported

Requirements

  • Python 3.9+
  • Tectonic (LaTeX 编译器)
  • OpenCV (cv2)
  • Pillow
# Install Python dependencies
pip install opencv-python pillow

# Install Tectonic (macOS)
brew install tectonic

Generating the Data

python3 generate_synthetic_data.py

Output directory output/

output/
├── 轴距_u0028_mm_u0029_.png       # synthesized string image
├── 轴距_u0028_mm_u0029_.json      # x-anylabeling annotation
├── 额定载质量_u0028_kg_u0029_.png
├── 额定载质量_u0028_kg_u0029_.json
├── ABS型号_u002f_生产企业.png
└── ABS型号_u002f_生产企业.json

x-anylabeling JSON Annotation Format

{
  "version": "5.4.1",
  "flags": {},
  "shapes": [
    {
      "label": "轴",
      "points": [[3, 0], [70, 0], [70, 70], [3, 70]],
      "group_id": 0,
      "description": "char_0",
      "shape_type": "rotation",
      "flags": {},
      "attributes": {
        "char_index": 0,
        "char_type": "tianzi"
      }
    }
  ],
  "imagePath": null,
  "imageData": "/9j/4AAQ...",
  "imageHeight": 70,
  "imageWidth": 316
}
  • imageData: the Base64-encoded PNG image
  • shape_type: rotation (rotated box)
  • group_id: character index
  • attributes.char_type: tianzi (田字格) or rizi (日字格)
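Because `imageData` is self-contained, a downstream loader needs nothing but the JSON file. A minimal sketch (the `load_annotation` helper is illustrative, not part of the project scripts):

```python
import base64
import io
import json

from PIL import Image  # Pillow

def load_annotation(json_path):
    """Load an x-anylabeling JSON and decode its embedded PNG."""
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    img = Image.open(io.BytesIO(base64.b64decode(data["imageData"])))
    return img, data["shapes"]
```

Each returned shape carries the character label and its four corner points, ready to convert into training masks.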

Files

tree.log

.
├── generate_synthetic_data.py    # main generation script (LaTeX + OpenCV)
├── generate_latex.py             # LaTeX typesetting script (fallback)
├── drawMissingRoiOpenCvBlob.cs   # reference algorithm (C# OpenCV connected components)
├── fonts/
│   └── simsun.ttf                # SimSun font
├── strings.md                    # input string data source
├── output/                       # generated results
├── README.md                     # this file
├── CHANGELOG.md                  # changelog
└── TODO.md                       # task list

strings.md

轴距(mm)
额定载质量(kg)
ABS型号/生产企业

generate_latex.py

#!/usr/bin/env python3
"""
使用 LaTeX 进行精确排版生成 ViT 合成数据
- 全角字符:田字格排版(80x80)
- 半角字符:日字格排版(40x80)
- 输出 x-anylabeling 格式 JSON(imageData 存储)
"""

import json
import base64
import io
import os
import subprocess
import tempfile
import unicodedata
from PIL import Image

# ==================== Configuration ====================
FONT_DIR = "./fonts/"
OUTPUT_DIR = "output_latex"
TIKZ_BASELINE_RAISE = 3  # baseline raise for half-width characters (pt)

# LaTeX template
TEX_TEMPLATE = r"""\documentclass[border=0pt]{standalone}
\usepackage{xeCJK}
\usepackage{fontspec}
\usepackage{tikz}
\usepackage{xcolor}
\usepackage{calc}

% Use the SimSun font shipped with the project
\setCJKmainfont{simsun.ttf}[Path=./fonts/]
\setmainfont{simsun.ttf}[Path=./fonts/]

\begin{document}
\begin{tikzpicture}[x=1pt, y=1pt]

%(content)s

\end{tikzpicture}
\end{document}
"""


# ==================== Character type detection ====================
def get_char_type(c: str) -> str:
    """Return the character's type class"""
    cat = unicodedata.category(c)
    if "\u4e00" <= c <= "\u9fff":
        return "chinese"
    if "\uff01" <= c <= "\uff5e" or "\u3000" <= c <= "\u303f":
        return "fullwidth"
    if cat.startswith("P") or cat in ("Sm", "Sc", "Sk", "So"):
        return "halfwidth_punct"
    if cat.startswith("L"):
        return "halfwidth_letter"
    if cat.startswith("N"):
        return "halfwidth_digit"
    return "other"


def is_fullwidth_char(c: str) -> bool:
    """判断是否为全角字符(使用田字格)"""
    char_type = get_char_type(c)
    return char_type in ("chinese", "fullwidth")


# ==================== LaTeX code generation ====================
def make_tianzi_cell(char: str, x_offset: int) -> str:
    """Generate the LaTeX code for a tianzige (full-width) cell.
    The character occupies 80x80, centered at (40,40).
    """
    return (
        f"\\begin{{scope}}[shift={{({x_offset},0)}}]\n"
        f"  \\node[anchor=center, inner sep=0pt, outer sep=0pt] at (40,40) "
        f"{{\\fontsize{{76}}{{76}}\\selectfont {char}}};\n"
        f"\\end{{scope}}"
    )


def make_rizi_cell(char: str, x_offset: int, raise_pt: int = TIKZ_BASELINE_RAISE) -> str:
    """生成日字格(半角字符)的 LaTeX 代码
    半角字符占 40x80,居中于 (20,40),基线上移
    """
    return (
        f"\\begin{{scope}}[shift={{({x_offset},0)}}]\n"
        f"  \\node[anchor=center, inner sep=0pt, outer sep=0pt] at (20,40) "
        f"{{\\fontsize{{76}}{{76}}\\selectfont\\raisebox{{{raise_pt}pt}}{{{char}}}}};\n"
        f"\\end{{scope}}"
    )


def generate_latex(text: str) -> tuple:
    """Generate the complete LaTeX code for a string"""
    cells = []
    x_offset = 0
    char_positions = []  # position info for each character

    for i, c in enumerate(text):
        if is_fullwidth_char(c):
            cells.append(make_tianzi_cell(c, x_offset))
            char_positions.append({
                "char": c,
                "index": i,
                "type": "tianzi",
                "x_offset": x_offset,
                "width": 80,
                "center_x": x_offset + 40,
                "center_y": 40,
            })
            x_offset += 80
        else:
            cells.append(make_rizi_cell(c, x_offset))
            char_positions.append({
                "char": c,
                "index": i,
                "type": "rizi",
                "x_offset": x_offset,
                "width": 40,
                "center_x": x_offset + 20,
                "center_y": 40,
            })
            x_offset += 40

    content = "\n".join(cells)
    tex_code = TEX_TEMPLATE.replace("%(content)s", content)
    return tex_code, char_positions, x_offset


# ==================== Compilation and conversion ====================
def compile_latex(tex_code: str, output_prefix: str) -> str:
    """Compile the LaTeX source and convert it to PNG; return the PNG path"""
    tex_path = f"{output_prefix}.tex"
    pdf_path = f"{output_prefix}.pdf"
    png_path = f"{output_prefix}.png"

    # Write the .tex file
    with open(tex_path, "w", encoding="utf-8") as f:
        f.write(tex_code)

    # Compile with Tectonic
    result = subprocess.run(
        ["tectonic", tex_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"LaTeX compilation failed: {result.stderr}")

    # Convert PDF to PNG (using sips, macOS only)
    result = subprocess.run(
        ["sips", "-s", "format", "png", pdf_path, "--out", png_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"PDF-to-PNG conversion failed: {result.stderr}")

    return png_path


# ==================== Image handling ====================
def load_png(png_path: str) -> Image.Image:
    """Load a PNG image"""
    return Image.open(png_path)


def image_to_base64(img: Image.Image) -> str:
    """将 PIL Image 转为 Base64 字符串(PNG 格式)"""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# ==================== Bounding-box estimation ====================
def estimate_char_bbox(char_pos: dict, img_height: int = 80) -> tuple:
    """Estimate a character's bounding box (x, y, w, h)
    from the LaTeX layout rules
    """
    c = char_pos["char"]
    cx = char_pos["center_x"]
    cy = char_pos["center_y"]

    if char_pos["type"] == "tianzi":
        # 全角字符:约 76x76,居中
        w, h = 76, 76
        x = cx - w // 2
        y = cy - h // 2
    else:
        # 半角字符:根据字符类型估算
        char_type = get_char_type(c)
        if char_type in ("halfwidth_letter", "halfwidth_digit"):
            w, h = 38, 36
        elif c in "([{<":
            w, h = 38, 68
        elif c in ")]>}>":
            w, h = 38, 68
        elif c == "/":
            w, h = 38, 65
        elif c == "-":
            w, h = 38, 20
        else:
            w, h = 38, 50
        x = cx - w // 2
        y = cy - h // 2 + 3  # 基线偏移补偿

    return (x, y, w, h)


def bbox_to_rotation_points(x: int, y: int, w: int, h: int):
    """Convert (x, y, w, h) to the 4 corner points of a rotation shape"""
    return [
        [x, y],
        [x + w, y],
        [x + w, y + h],
        [x, y + h],
    ]


# ==================== JSON generation ====================
def generate_xanylabeling_json(img: Image.Image, text: str, char_positions: list) -> dict:
    """Build the x-anylabeling format JSON"""
    shapes = []

    for pos in char_positions:
        bx, by, bw, bh = estimate_char_bbox(pos, img.height)
        points = bbox_to_rotation_points(bx, by, bw, bh)

        shapes.append({
            "label": pos["char"],
            "points": points,
            "group_id": pos["index"],
            "description": f"char_{pos['index']}",
            "shape_type": "rotation",
            "flags": {},
            "attributes": {
                "char_index": pos["index"],
                "char_type": pos["type"],
            },
        })

    return {
        "version": "5.4.1",
        "flags": {},
        "shapes": shapes,
        "imagePath": None,
        "imageData": image_to_base64(img),
        "imageHeight": img.height,
        "imageWidth": img.width,
    }


# ==================== Main ====================
def process_string(text: str, output_dir: str):
    """Process one string: generate its image and annotation"""
    print(f"\nProcessing: {text}")

    # Generate the LaTeX code
    tex_code, char_positions, total_width = generate_latex(text)
    print(f"  Canvas width: {total_width}pt, characters: {len(char_positions)}")

    # Build a filesystem-safe file name
    safe_name = "".join(
        c if c.isalnum() else f"_u{ord(c):04x}_" for c in text
    )

    # Compile the LaTeX source
    prefix = os.path.join(output_dir, safe_name)
    png_path = compile_latex(tex_code, prefix)
    print(f"  Compiled: {png_path}")

    # Load the rendered image
    img = load_png(png_path)
    print(f"  Image size: {img.size}")

    # Build the JSON annotation
    json_data = generate_xanylabeling_json(img, text, char_positions)

    # Save the JSON
    json_path = f"{prefix}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)
    print(f"  Saved annotation: {json_path}")

    # Clean up temporary files
    for ext in [".tex", ".pdf", ".aux", ".log"]:
        tmp = f"{prefix}{ext}"
        if os.path.exists(tmp):
            os.remove(tmp)

    return img, json_data


def main():
    # Read the input strings
    with open("strings.md", "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]

    # Create the output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"{len(texts)} strings to generate")
    print("Typesetting with LaTeX + Tectonic")
    print(f"Half-width baseline raise: {TIKZ_BASELINE_RAISE}pt")

    for text in texts:
        try:
            process_string(text, OUTPUT_DIR)
        except Exception as e:
            print(f"  Error: {e}")

    print(f"\nDone! Output directory: {OUTPUT_DIR}/")


if __name__ == "__main__":
    main()

generate_synthetic_data.py

#!/usr/bin/env python3
"""
ViT 字符级分割合成数据生成器
- LaTeX 精确排版(解决PIL基线问题)
- OpenCV 连通域检测精确边界框
- 全角字符:田字格排版(80x80)
- 半角字符:日字格排版(40x80)
- 输出 x-anylabeling 格式 JSON(imageData 存储)
"""

import json
import base64
import io
import os
import subprocess
import unicodedata

import cv2
import numpy as np
from PIL import Image


# ==================== Configuration ====================
FONT_DIR = os.path.abspath("fonts")
OUTPUT_DIR = "output"


def get_char_type(c):
    cat = unicodedata.category(c)
    if "\u4e00" <= c <= "\u9fff":
        return "chinese"
    if "\uff01" <= c <= "\uff5e" or "\u3000" <= c <= "\u303f":
        return "fullwidth"
    if cat.startswith("P") or cat in ("Sm", "Sc", "Sk", "So"):
        return "halfwidth_punct"
    if cat.startswith("L"):
        return "halfwidth_letter"
    if cat.startswith("N"):
        return "halfwidth_digit"
    return "other"


def is_fullwidth_char(c):
    return get_char_type(c) in ("chinese", "fullwidth")


def generate_latex_tex(text):
    """生成LaTeX代码,每个字符独立scope"""
    cells = []
    x_offset = 0
    char_info = []
    
    for i, c in enumerate(text):
        if is_fullwidth_char(c):
            cell = (
                r"\begin{scope}[shift={(" + str(x_offset) + r",0)}]" + "\n"
                r"  \node[anchor=center, inner sep=0pt, outer sep=0pt] at (40,40) "
                r"{\fontsize{76}{76}\selectfont " + c + r"};" + "\n"
                r"\end{scope}"
            )
            cells.append(cell)
            char_info.append({
                "char": c, "index": i, "type": "tianzi",
                "x_offset": x_offset, "width": 80,
            })
            x_offset += 80
        else:
            cell = (
                r"\begin{scope}[shift={(" + str(x_offset) + r",0)}]" + "\n"
                r"  \node[anchor=center, inner sep=0pt, outer sep=0pt] at (20,40) "
                r"{\fontsize{76}{76}\selectfont\raisebox{3pt}{" + c + r"}};" + "\n"
                r"\end{scope}"
            )
            cells.append(cell)
            char_info.append({
                "char": c, "index": i, "type": "rizi",
                "x_offset": x_offset, "width": 40,
            })
            x_offset += 40
    
    tex = (
        r"\documentclass[border=0pt]{standalone}" + "\n"
        r"\usepackage{xeCJK}" + "\n"
        r"\usepackage{fontspec}" + "\n"
        r"\usepackage{tikz}" + "\n"
        r"\setCJKmainfont{simsun.ttf}[Path=" + FONT_DIR + r"/]" + "\n"
        r"\setmainfont{simsun.ttf}[Path=" + FONT_DIR + r"/]" + "\n"
        r"\begin{document}" + "\n"
        r"\begin{tikzpicture}[x=1pt, y=1pt]" + "\n"
        + "\n".join(cells) + "\n"
        r"\end{tikzpicture}" + "\n"
        r"\end{document}" + "\n"
    )
    return tex, char_info, x_offset


def compile_latex(tex_code, prefix):
    """编译LaTeX并转为PNG"""
    tex_path = prefix + ".tex"
    pdf_path = prefix + ".pdf"
    png_path = prefix + ".png"
    
    with open(tex_path, "w", encoding="utf-8") as f:
        f.write(tex_code)
    
    result = subprocess.run(["tectonic", tex_path], capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError("LaTeX编译失败: " + result.stderr)
    
    result = subprocess.run(
        ["sips", "-s", "format", "png", pdf_path, "--out", png_path],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        raise RuntimeError("PDF转PNG失败: " + result.stderr)
    
    return png_path


def detect_char_bboxes(img_path, char_info):
    """
    Detect each character's precise bbox with OpenCV connected components.
    Follows the drawMissingRoiOpenCvBlob.cs algorithm.
    """
    img = Image.open(img_path)

    # Flatten a transparent background onto white
    if img.mode == "RGBA":
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])
        img_rgb = background
    else:
        img_rgb = img.convert("RGB")

    # Convert for OpenCV processing
    img_cv = cv2.cvtColor(np.array(img_rgb), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # Binarize: THRESH_BINARY_INV extracts the black text
    _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY_INV)

    # Morphological operations (3x3 rectangular kernel)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

    # Run connected-component analysis on each character cell separately
    results = []
    
    for info in char_info:
        x_start = info["x_offset"]
        x_end = x_start + info["width"]
        
        # Extract the character cell
        char_binary = binary[:, x_start:x_end]
        
        # Connected-component analysis (8-connectivity)
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            char_binary, connectivity=8
        )
        
        if num_labels <= 1:
            results.append(dict(info, bbox=(0, 0, 0, 0)))
            continue
        
        # Merge all non-background components (handles multi-stroke characters)
        all_x, all_y, all_x_end, all_y_end = [], [], [], []
        
        for i in range(1, num_labels):
            x, y, w, h, area = stats[i]
            if area < 5:
                continue
            all_x.append(x)
            all_y.append(y)
            all_x_end.append(x + w)
            all_y_end.append(y + h)
        
        if not all_x:
            results.append(dict(info, bbox=(0, 0, 0, 0)))
            continue
        
        # Merge the bboxes (cf. MergeOverlappingRegionsWithGap)
        merged_x = min(all_x)
        merged_y = min(all_y)
        merged_w = max(all_x_end) - merged_x
        merged_h = max(all_y_end) - merged_y
        
        # Convert to global coordinates
        global_bbox = (
            int(x_start + merged_x),
            int(merged_y),
            int(merged_w),
            int(merged_h),
        )
        
        results.append(dict(info, bbox=global_bbox))
    
    return img_rgb, results


def image_to_base64(img):
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def bbox_to_rotation_points(x, y, w, h):
    return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]


def generate_json(img, text, results):
    shapes = []
    for r in results:
        bx, by, bw, bh = r["bbox"]
        points = bbox_to_rotation_points(bx, by, bw, bh)
        shapes.append({
            "label": r["char"],
            "points": points,
            "group_id": r["index"],
            "description": "char_" + str(r["index"]),
            "shape_type": "rotation",
            "flags": {},
            "attributes": {
                "char_index": r["index"],
                "char_type": r["type"],
            },
        })
    
    return {
        "version": "5.4.1",
        "flags": {},
        "shapes": shapes,
        "imagePath": None,
        "imageData": image_to_base64(img),
        "imageHeight": img.height,
        "imageWidth": img.width,
    }


def process_string(text, output_dir):
    print("\nProcessing: " + text)

    # Generate the LaTeX code
    tex, char_info, total_width = generate_latex_tex(text)
    print("  Canvas width: " + str(total_width) + "pt")

    # Build a filesystem-safe file name
    safe_name = "".join(
        c if c.isalnum() else "_u" + format(ord(c), "04x") + "_" for c in text
    )
    prefix = os.path.join(output_dir, safe_name)

    # Compile the LaTeX source
    png_path = compile_latex(tex, prefix)

    # Precise bboxes via OpenCV connected components
    img, results = detect_char_bboxes(png_path, char_info)
    print("  Image size: " + str(img.size))

    for r in results:
        print("  \"" + r["char"] + "\": " + str(r["bbox"]))

    # Save the image
    img.save(prefix + ".png")

    # Generate and save the JSON annotation
    json_data = generate_json(img, text, results)
    with open(prefix + ".json", "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)

    print("  Saved: " + prefix + ".png, " + prefix + ".json")

    # Clean up temporary files
    for ext in [".tex", ".pdf", ".aux", ".log"]:
        tmp = prefix + ext
        if os.path.exists(tmp):
            os.remove(tmp)
    
    return img, results


def main():
    with open("strings.md", "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(str(len(texts)) + " strings to process")
    print("Approach: precise LaTeX typesetting + precise OpenCV connected-component bboxes")

    for text in texts:
        try:
            process_string(text, OUTPUT_DIR)
        except Exception as e:
            print("  Error: " + str(e))
            import traceback
            traceback.print_exc()

    print("\nDone! Output: " + OUTPUT_DIR + "/")


if __name__ == "__main__":
    main()

postprocess_clean.py

#!/usr/bin/env python3
"""
后处理脚本:根据 rotation 标注框去除标注框外的残留黑色/灰色像素

功能:
- 读取 x-anylabeling 格式的 JSON 文件(包含 imageData)
- 解析 rotation 类型的标注框(四边形,但当前实现为轴对齐矩形)
- 将所有不在任何标注框内的非白色像素设为白色(255, 255, 255)
- 更新 imageData 并保存为新的 JSON 文件

用法:
    python postprocess_clean.py <input_json> [output_json]
    python postprocess_clean.py output/轴距_u0028_mm_u0029_.json
    python postprocess_clean.py output/  # 处理整个目录
"""

import argparse
import base64
import io
import json
import os

import numpy as np
from PIL import Image


def decode_image(image_data_b64: str) -> Image.Image:
    """从 Base64 字符串解码为 PIL Image"""
    img_bytes = base64.b64decode(image_data_b64)
    return Image.open(io.BytesIO(img_bytes))


def encode_image(img: Image.Image) -> str:
    """将 PIL Image 编码为 Base64 PNG 字符串"""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def get_bbox_from_rotation_points(points):
    """
    从 rotation 类型的四边形点获取轴对齐 bbox。
    points: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
    返回: (x_min, y_min, x_max, y_max)
    """
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))


def clean_stray_pixels(img: Image.Image, shapes: list, gray_threshold: int = 250) -> Image.Image:
    """
    Remove every non-white pixel that falls outside all annotation boxes.

    Args:
        img: PIL Image (RGB)
        shapes: x-anylabeling shapes list
        gray_threshold: grayscale threshold; pixels below it count as "non-white"
    Returns:
        the cleaned PIL Image
    """
    img_np = np.array(img)
    h, w = img_np.shape[:2]

    # Build a union mask of all annotation boxes
    bbox_mask = np.zeros((h, w), dtype=bool)

    for shape in shapes:
        if shape.get("shape_type") != "rotation":
            continue
        points = shape["points"]
        if len(points) != 4:
            continue
        x1, y1, x2, y2 = get_bbox_from_rotation_points(points)
        # Clip to the image bounds
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(w, x2)
        y2 = min(h, y2)
        if x1 < x2 and y1 < y2:
            bbox_mask[y1:y2, x1:x2] = True

    # Compute the grayscale image
    if len(img_np.shape) == 3:
        gray = np.mean(img_np, axis=2)
    else:
        gray = img_np.astype(float)

    # Mask of non-white pixels (below the threshold)
    non_white = gray < gray_threshold

    # Pixels outside every box that are non-white
    stray_mask = non_white & ~bbox_mask

    stray_count = np.count_nonzero(stray_mask)
    if stray_count == 0:
        print("  Nothing to clean; no stray pixels")
        return img

    print(f"  Found {stray_count} stray pixels; cleaning...")

    # Set the stray pixels to white
    img_clean = img_np.copy()
    img_clean[stray_mask] = [255, 255, 255]

    return Image.fromarray(img_clean)


def process_json(input_path: str, output_path: str = None, gray_threshold: int = 250):
    """Process a single JSON file"""
    print(f"Processing: {input_path}")

    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Decode the image
    img = decode_image(data["imageData"])
    print(f"  Image size: {img.size}, mode: {img.mode}")

    # Clean stray pixels
    img_clean = clean_stray_pixels(img, data.get("shapes", []), gray_threshold)

    # Update imageData
    data["imageData"] = encode_image(img_clean)

    # Determine the output path
    if output_path is None:
        base, ext = os.path.splitext(input_path)
        output_path = base + "_cleaned" + ext

    # Save
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"  Saved: {output_path}")
    return output_path


def process_directory(input_dir: str, output_dir: str = None, gray_threshold: int = 250):
    """Process every JSON file in a directory"""
    json_files = [
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if f.endswith(".json")
    ]

    if not json_files:
        print(f"No JSON files in directory: {input_dir}")
        return

    if output_dir is None:
        output_dir = input_dir
    os.makedirs(output_dir, exist_ok=True)

    print(f"Found {len(json_files)} JSON files")
    for json_path in sorted(json_files):
        basename = os.path.basename(json_path)
        name, ext = os.path.splitext(basename)
        out_path = os.path.join(output_dir, name + "_cleaned" + ext)
        try:
            process_json(json_path, out_path, gray_threshold)
        except Exception as e:
            print(f"  Error: {e}")
            import traceback
            traceback.print_exc()


def verify_cleaned(json_path: str, gray_threshold: int = 250):
    """Verify that a cleaned JSON has no remaining stray pixels"""
    print(f"\nVerifying: {json_path}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    img = decode_image(data["imageData"])
    img_np = np.array(img)
    h, w = img_np.shape[:2]

    # Build the bbox mask
    bbox_mask = np.zeros((h, w), dtype=bool)
    for shape in data.get("shapes", []):
        if shape.get("shape_type") != "rotation":
            continue
        points = shape["points"]
        if len(points) != 4:
            continue
        x1, y1, x2, y2 = get_bbox_from_rotation_points(points)
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(w, x2)
        y2 = min(h, y2)
        if x1 < x2 and y1 < y2:
            bbox_mask[y1:y2, x1:x2] = True

    gray = np.mean(img_np, axis=2)
    non_white = gray < gray_threshold
    stray_mask = non_white & ~bbox_mask
    stray_count = np.count_nonzero(stray_mask)

    if stray_count == 0:
        print("  ✓ Verified: no stray pixels")
        return True
    else:
        print(f"  ✗ Still {stray_count} stray pixels!")
        ys, xs = np.where(stray_mask)
        for i in range(min(10, len(xs))):
            print(f"    ({xs[i]}, {ys[i]}): gray={gray[ys[i], xs[i]]:.0f}")
        return False


def main():
    parser = argparse.ArgumentParser(
        description="Remove stray black/gray pixels outside the rotation annotation boxes"
    )
    parser.add_argument("input", help="input JSON file or directory")
    parser.add_argument("-o", "--output", help="output path (file or directory)")
    parser.add_argument(
        "-t", "--threshold", type=int, default=250,
        help="grayscale threshold; pixels below it count as non-white (default: 250)"
    )
    parser.add_argument(
        "-v", "--verify", action="store_true",
        help="verify that no stray pixels remain after processing"
    )
    args = parser.parse_args()

    if os.path.isdir(args.input):
        process_directory(args.input, args.output, args.threshold)
        if args.verify:
            output_dir = args.output or args.input
            for f in sorted(os.listdir(output_dir)):
                if f.endswith("_cleaned.json"):
                    verify_cleaned(os.path.join(output_dir, f), args.threshold)
    else:
        output_path = process_json(args.input, args.output, args.threshold)
        if args.verify:
            verify_cleaned(output_path, args.threshold)


if __name__ == "__main__":
    main()

postprocess_augment.py

#!/usr/bin/env python3
"""
后处理扰动脚本:对 cleaned.json 进行亮度扰动和高斯模糊扰动

功能:
- 读取 x-anylabeling 格式的 cleaned JSON 文件(包含 imageData)
- 应用亮度扰动(Brightness):随机或指定增强系数
- 应用高斯模糊扰动(GaussianBlur):随机或指定模糊半径
- 保持标注框不变,仅修改 imageData
- 分别保存为新的 JSON 文件

用法:
    # 单文件处理(随机扰动参数)
    python postprocess_augment.py output/轴距_u0028_mm_u0029__cleaned.json

    # 批量目录处理
    python postprocess_augment.py output/ --suffix _cleaned

    # 指定亮度范围和高斯模糊半径
    python postprocess_augment.py output/ --suffix _cleaned \
        --brightness 0.8 1.2 --blur 0.5 1.5

    # 固定参数(非随机)
    python postprocess_augment.py output/ --suffix _cleaned \
        --brightness 1.1 1.1 --blur 1.0 1.0
"""

import argparse
import base64
import io
import json
import os
import random

import numpy as np
from PIL import Image, ImageEnhance, ImageFilter


def decode_image(image_data_b64: str) -> Image.Image:
    """从 Base64 字符串解码为 PIL Image"""
    img_bytes = base64.b64decode(image_data_b64)
    return Image.open(io.BytesIO(img_bytes))


def encode_image(img: Image.Image) -> str:
    """将 PIL Image 编码为 Base64 PNG 字符串"""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def apply_brightness(img: Image.Image, factor: float) -> Image.Image:
    """
    应用亮度扰动

    参数:
        img: PIL Image
        factor: 亮度系数,>1 变亮,<1 变暗
    返回:
        扰动后的 PIL Image
    """
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)


def apply_gaussian_blur(img: Image.Image, radius: float) -> Image.Image:
    """
    应用高斯模糊扰动

    参数:
        img: PIL Image
        radius: 高斯模糊半径(像素)
    返回:
        扰动后的 PIL Image
    """
    return img.filter(ImageFilter.GaussianBlur(radius=radius))


def process_single_augmentation(
    data: dict,
    aug_type: str,
    param: float,
) -> dict:
    """
    对单个 JSON 数据应用一种扰动

    参数:
        data: 原始 JSON 数据
        aug_type: 扰动类型,"brightness" 或 "blur"
        param: 扰动参数值
    返回:
        新的 JSON 数据(深拷贝,仅修改 imageData)
    """
    # 深拷贝,避免修改原始数据
    new_data = json.loads(json.dumps(data))

    # 解码图片
    img = decode_image(data["imageData"])

    # 应用扰动
    if aug_type == "brightness":
        img_aug = apply_brightness(img, param)
        aug_label = f"brightness_{param:.2f}"
    elif aug_type == "blur":
        img_aug = apply_gaussian_blur(img, param)
        aug_label = f"blur_{param:.2f}"
    else:
        raise ValueError(f"未知的扰动类型: {aug_type}")

    # 更新 imageData
    new_data["imageData"] = encode_image(img_aug)

    # 在 flags 中记录扰动信息
    if "flags" not in new_data:
        new_data["flags"] = {}
    new_data["flags"]["augmentation"] = aug_label
    new_data["flags"]["augmentation_type"] = aug_type
    new_data["flags"]["augmentation_param"] = param

    return new_data


def generate_augmented_json(
    input_path: str,
    output_dir: str = None,
    brightness_range: tuple = (0.85, 1.15),
    blur_radius_range: tuple = (0.5, 1.5),
    seed: int = None,
) -> list:
    """
    对单个 JSON 文件生成两种扰动版本

    参数:
        input_path: 输入 JSON 文件路径
        output_dir: 输出目录(默认与输入同目录)
        brightness_range: 亮度系数范围 (min, max)
        blur_radius_range: 高斯模糊半径范围 (min, max)
        seed: 随机种子(可选)
    返回:
        生成的输出文件路径列表
    """
    if seed is not None:
        random.seed(seed)

    # 读取原始数据
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    basename = os.path.basename(input_path)
    name, ext = os.path.splitext(basename)
    # 去掉 _cleaned 后缀(如果存在)
    if name.endswith("_cleaned"):
        base_name = name[:-8]
    else:
        base_name = name

    if output_dir is None:
        output_dir = os.path.dirname(input_path) or "."
    os.makedirs(output_dir, exist_ok=True)

    output_paths = []

    # 1. 亮度扰动
    brightness_factor = random.uniform(*brightness_range)
    bright_data = process_single_augmentation(data, "brightness", brightness_factor)
    bright_path = os.path.join(output_dir, f"{base_name}_brightness{ext}")
    with open(bright_path, "w", encoding="utf-8") as f:
        json.dump(bright_data, f, ensure_ascii=False, indent=2)
    output_paths.append(bright_path)
    print(f"  亮度扰动 (factor={brightness_factor:.3f}) -> {bright_path}")

    # 2. 高斯模糊扰动
    blur_radius = random.uniform(*blur_radius_range)
    blur_data = process_single_augmentation(data, "blur", blur_radius)
    blur_path = os.path.join(output_dir, f"{base_name}_blur{ext}")
    with open(blur_path, "w", encoding="utf-8") as f:
        json.dump(blur_data, f, ensure_ascii=False, indent=2)
    output_paths.append(blur_path)
    print(f"  高斯模糊 (radius={blur_radius:.3f}) -> {blur_path}")

    return output_paths


def process_directory(
    input_dir: str,
    suffix: str = "_cleaned",
    output_dir: str = None,
    brightness_range: tuple = (0.85, 1.15),
    blur_radius_range: tuple = (0.5, 1.5),
    seed: int = None,
):
    """
    批量处理目录中的 JSON 文件

    参数:
        input_dir: 输入目录
        suffix: 文件名后缀过滤(如 _cleaned)
        output_dir: 输出目录(默认与输入同目录)
        brightness_range: 亮度系数范围
        blur_radius_range: 高斯模糊半径范围
        seed: 随机种子
    """
    json_files = [
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if f.endswith(".json") and (not suffix or f.endswith(f"{suffix}.json"))
    ]

    if not json_files:
        print(f"目录中没有匹配的 JSON 文件: {input_dir} (后缀: {suffix})")
        return

    print(f"找到 {len(json_files)} 个匹配文件")

    for json_path in sorted(json_files):
        print(f"\n处理: {json_path}")
        try:
            generate_augmented_json(
                json_path,
                output_dir=output_dir,
                brightness_range=brightness_range,
                blur_radius_range=blur_radius_range,
                seed=seed,
            )
        except Exception as e:
            print(f"  错误: {e}")
            import traceback
            traceback.print_exc()


def verify_augmentation(json_path: str, original_path: str = None):
    """
    验证扰动后的 JSON 文件

    参数:
        json_path: 扰动后的 JSON 文件路径
        original_path: 原始 JSON 文件路径(可选,用于对比)
    """
    print(f"\n验证: {json_path}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # 检查 flags
    flags = data.get("flags", {})
    aug_type = flags.get("augmentation_type", "unknown")
    aug_param = flags.get("augmentation_param", "unknown")
    print(f"  扰动类型: {aug_type}, 参数: {aug_param}")

    # 解码图片
    img = decode_image(data["imageData"])
    print(f"  图片尺寸: {img.size}, 模式: {img.mode}")

    # 如果有原始图片,计算差异
    if original_path and os.path.exists(original_path):
        with open(original_path, "r", encoding="utf-8") as f:
            orig_data = json.load(f)
        orig_img = decode_image(orig_data["imageData"])

        orig_np = np.array(orig_img).astype(float)
        aug_np = np.array(img).astype(float)

        diff = np.abs(orig_np - aug_np)
        max_diff = diff.max()
        mean_diff = diff.mean()

        print(f"  与原始图片对比:")
        print(f"    最大像素差异: {max_diff:.2f}")
        print(f"    平均像素差异: {mean_diff:.2f}")

        if max_diff < 1:
            print(f"  ⚠ 警告: 像素差异很小,扰动可能未生效")
        else:
            print(f"  ✓ 扰动已生效")
    else:
        print(f"  ✓ 文件格式正确")


def main():
    parser = argparse.ArgumentParser(
        description="Apply brightness and Gaussian-blur perturbations to cleaned JSON files"
    )
    parser.add_argument("input", help="input JSON file or directory")
    parser.add_argument(
        "-o", "--output", help="output directory (defaults to the input's directory)"
    )
    parser.add_argument(
        "--suffix", default="_cleaned",
        help="filename suffix filter in directory mode (default: _cleaned)"
    )
    parser.add_argument(
        "--brightness", nargs=2, type=float, default=[0.85, 1.15],
        metavar=("MIN", "MAX"),
        help="brightness factor range (default: 0.85 1.15)"
    )
    parser.add_argument(
        "--blur", nargs=2, type=float, default=[0.5, 1.5],
        metavar=("MIN", "MAX"),
        help="Gaussian-blur radius range (default: 0.5 1.5)"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="random seed (for reproducibility)"
    )
    parser.add_argument(
        "-v", "--verify", action="store_true",
        help="verify after processing"
    )
    args = parser.parse_args()

    brightness_range = tuple(args.brightness)
    blur_radius_range = tuple(args.blur)

    if os.path.isdir(args.input):
        process_directory(
            args.input,
            suffix=args.suffix,
            output_dir=args.output,
            brightness_range=brightness_range,
            blur_radius_range=blur_radius_range,
            seed=args.seed,
        )

        if args.verify:
            output_dir = args.output or args.input
            for f in sorted(os.listdir(output_dir)):
                if "_brightness" in f or "_blur" in f:
                    aug_path = os.path.join(output_dir, f)
                    # Infer the original file path
                    base = f.replace("_brightness", "").replace("_blur", "")
                    orig_path = os.path.join(output_dir, base)
                    verify_augmentation(aug_path, orig_path)
    else:
        output_paths = generate_augmented_json(
            args.input,
            output_dir=args.output,
            brightness_range=brightness_range,
            blur_radius_range=blur_radius_range,
            seed=args.seed,
        )

        if args.verify:
            for p in output_paths:
                verify_augmentation(p, args.input)


if __name__ == "__main__":
    main()

Results

Synthesized image with synthesized annotation boxes for "ABS型号_u002f_生产企业" (screenshot from 2026-05-04 21:38 omitted).
Posted @ 2026-05-04 22:02 by qsBye