AI Learning - Diagnostic Conclusion Information Extraction - Converting Data to the BERT Training Format
The data exported from Label Studio has to be converted into the token-level BIO format that BERT training expects.
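For reference, here is a minimal, illustrative input/output pair (the entity label 指标名称 is just an example; the actual labels come from your Label Studio labeling config). A Label Studio NER export item looks roughly like this:

[
  {
    "data": {"text": "白细胞计数 10.5"},
    "annotations": [{
      "result": [{
        "type": "labels",
        "value": {"start": 0, "end": 5, "text": "白细胞计数", "labels": ["指标名称"]}
      }]
    }]
  }
]

Each converted training sample is a dict of equal-length, token-aligned sequences, e.g.:

{"input_ids": [101, ...], "attention_mask": [1, 1, ...], "labels": [0, 1, 2, 2, 2, 2, 0, ...]}

where labels holds BIO tag IDs (0 = O, 1 = B-指标名称, 2 = I-指标名称 in this example) aligned one-to-one with input_ids.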
Code
# Before running from PowerShell, first execute: $env:HF_ENDPOINT = "https://hf-mirror.com"
import json
from transformers import AutoTokenizer
from collections import Counter
def convert_label_studio_to_bert(label_studio_json, tokenizer, label2id, max_length=256):
    """
    Convert Label Studio format to BERT training format (for Chinese medical text).
    Args:
        label_studio_json: JSON data exported from Label Studio
        tokenizer: BERT tokenizer instance (a Chinese BERT is recommended)
        label2id: dict mapping label strings to IDs
        max_length: maximum sequence length
    Returns:
        Tuple[List[dict], Counter]: BERT-format training samples and label usage counts
    """
    bert_samples = []
    stats = Counter()  # track how often each label is used
    for item_idx, item in enumerate(label_studio_json):
        text = item["data"]["text"]
        # Keep entity annotations only; skip relation annotations
        entity_annotations = []
        for ann in item["annotations"][0]["result"]:
            if "value" in ann and "labels" in ann["value"]:
                entity_annotations.append(ann)
        print(f"\nProcessing sample {item_idx + 1}, text length: {len(text)}")
        print(f"Entity annotations: {len(entity_annotations)}")
        # 1. Tokenize
        encoding = tokenizer(
            text,
            return_offsets_mapping=True,  # character-to-token alignment
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )
        tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
        offset_mapping = encoding["offset_mapping"]
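        # Illustrative alignment (hypothetical text; exact splits depend on the vocab):
        #   text           = "白细胞 10.5"
        #   tokens         = ["[CLS]", "白", "细", "胞", "10", ".", "5", "[SEP]"]
        #   offset_mapping = [(0,0), (0,1), (1,2), (2,3), (4,6), (6,7), (7,8), (0,0)]
        # Special tokens get the zero-width span (0, 0), which is why the
        # alignment loop below skips tokens with token_start == token_end.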
        # 2. Initialize every label to "O"
        labels = ["O"] * len(tokens)
        # 3. Sort annotations by start offset so they are processed in order
        sorted_annotations = sorted(entity_annotations, key=lambda x: x["value"]["start"])
        # 4. Convert character-level annotations to token-level labels
        for ann in sorted_annotations:
            start_char = ann["value"]["start"]
            end_char = ann["value"]["end"]
            label_type = ann["value"]["labels"][0]
            ann_text = ann["value"]["text"]
            # Collect the indices of every token this entity covers
            entity_token_indices = []
            for token_idx, (token_start, token_end) in enumerate(offset_mapping):
                # Skip special tokens (CLS, SEP, PAD)
                if token_start == token_end:
                    continue
                # A token belongs to the entity if its span overlaps the
                # character span (full containment is a special case of overlap)
                if token_start < end_char and token_end > start_char:
                    entity_token_indices.append(token_idx)
            if not entity_token_indices:
                print(f"  Warning: annotation '{ann_text}' (span {start_char}-{end_char}) matched no token")
                continue
            # Sort the entity's token indices
            entity_token_indices.sort()
            # Assign BIO labels
            for idx, token_idx in enumerate(entity_token_indices):
                # Skip tokens already claimed by another entity (avoid overlaps)
                if labels[token_idx] != "O":
                    continue
                if idx == 0:  # first token of the entity gets the B- prefix
                    bio_label = f"B-{label_type}"
                else:  # remaining tokens get the I- prefix
                    bio_label = f"I-{label_type}"
                labels[token_idx] = bio_label
                stats[bio_label] += 1
        # 5. Force "O" on special tokens (CLS, SEP, PAD)
        for i in range(len(tokens)):
            if tokens[i] in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
                labels[i] = "O"
        # 6. Convert labels to IDs
        label_ids = []
        unknown_labels = []
        for l in labels:
            if l in label2id:
                label_ids.append(label2id[l])
            else:
                # Unknown label: report it (unless it is O) and fall back to O
                if l != "O":
                    unknown_labels.append(l)
                label_ids.append(label2id["O"])
        if unknown_labels:
            print(f"  Warning: found {len(set(unknown_labels))} unknown labels: {set(unknown_labels)}")
        # 7. Build the sample
        bert_samples.append({
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": label_ids,
            "tokens": tokens,  # kept for debugging
            "text": text[:100] + "..." if len(text) > 100 else text  # truncated for display
        })
    # Print summary statistics
    print("\n=== Conversion statistics ===")
    print(f"Total samples: {len(bert_samples)}")
    print("Label usage:")
    for label, count in sorted(stats.items(), key=lambda x: x[1], reverse=True):
        print(f"  {label}: {count}")
    return bert_samples, stats
def extract_entity_types(label_studio_json):
    """
    Extract every entity type from the Label Studio data, to auto-generate label2id.
    """
    all_labels = set()
    for item in label_studio_json:
        for ann in item["annotations"][0]["result"]:
            if "value" in ann and "labels" in ann["value"]:
                label = ann["value"]["labels"][0]
                all_labels.add(label)
    print("Entity types found:")
    for label in sorted(all_labels):
        print(f"  - {label}")
    return sorted(all_labels)
def create_label2id(entity_types):
    """
    Build the BIO label mapping from the entity types.
    """
    label2id = {"O": 0}
    # Create a B- and an I- label for each entity type, e.g. for the types
    # 指标名称 (indicator name) and 数值 (numeric value):
    #   "B-指标名称": 1, "I-指标名称": 2,
    #   "B-数值": 3,   "I-数值": 4
    for i, entity_type in enumerate(entity_types, 1):
        label2id[f"B-{entity_type}"] = i * 2 - 1
        label2id[f"I-{entity_type}"] = i * 2
    return label2id
def validate_alignment(sample, tokenizer, id2label):
    """
    Spot-check that tokens and labels are aligned correctly.
    """
    tokens = sample["tokens"]
    labels = [id2label[l] for l in sample["labels"]]
    print("\n=== Alignment check ===")
    print("Labels of the first 50 tokens:")
    # Print in rows of 5 tokens each
    for i in range(0, min(50, len(tokens)), 5):
        line_tokens = tokens[i:i + 5]
        line_labels = labels[i:i + 5]
        # token row
        token_line = " | ".join([f"{t:^10}" for t in line_tokens])
        print(f"Token: {token_line}")
        # label row
        label_line = " | ".join([f"{l:^10}" for l in line_labels])
        print(f"Label: {label_line}")
        print("-" * 60)
def save_data_without_text(bert_data, output_path):
    """
    Save the converted data (without tokens and text, to keep the file small).
    """
    data_to_save = []
    for sample in bert_data:
        data_to_save.append({
            "input_ids": sample["input_ids"],
            "attention_mask": sample["attention_mask"],
            "labels": sample["labels"]
        })
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data_to_save, f, indent=2, ensure_ascii=False)
def main():
    # 1. Load the Label Studio export
    print(f"Loading Label Studio data from: {input_file}")
    with open(input_file, "r", encoding="utf-8") as f:
        label_studio_data = json.load(f)
    print(f"Loaded {len(label_studio_data)} samples")
    # 2. Extract all entity types and build the label mapping
    entity_types = extract_entity_types(label_studio_data)
    label2id = create_label2id(entity_types)
    id2label = {v: k for k, v in label2id.items()}
    print(f"\nGenerated label2id mapping ({len(label2id)} labels):")
    for label, label_id in sorted(label2id.items(), key=lambda x: x[1]):
        print(f"  {label}: {label_id}")
    # 3. Load a Chinese BERT tokenizer
    print("\nLoading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
        print("Tokenizer loaded successfully")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        # Fall back to an alternative model
        print("Trying hfl/chinese-roberta-wwm-ext...")
        tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
    # 4. Convert the data
    print("\nStarting conversion...")
    bert_data, stats = convert_label_studio_to_bert(
        label_studio_data,
        tokenizer,
        label2id,
        max_length=256
    )
    # 5. Save the converted data
    print(f"\nSaving converted data to: {output_file}")
    save_data_without_text(bert_data, output_file)
    # 6. Check the token-label alignment of the first sample
    if bert_data:
        print("\nChecking token-label alignment of the first sample:")
        validate_alignment(bert_data[0], tokenizer, id2label)
        # Show how the annotations map back onto the original text
        print("\n=== Original text vs. annotations ===")
        text = label_studio_data[0]["data"]["text"]
        tokens = bert_data[0]["tokens"]
        labels = [id2label[l] for l in bert_data[0]["labels"]]
        # Collect all non-O spans
        entity_spans = []
        current_entity = None
        start_idx = -1
        for i, (token, label) in enumerate(zip(tokens, labels)):
            if label.startswith("B-"):
                # A new entity starts
                if current_entity:
                    entity_spans.append((start_idx, i - 1, current_entity))
                current_entity = label[2:]  # strip the B- prefix
                start_idx = i
            elif label == "O" and current_entity:
                # The current entity ends
                entity_spans.append((start_idx, i - 1, current_entity))
                current_entity = None
        # Close the last entity, if any
        if current_entity:
            entity_spans.append((start_idx, len(tokens) - 1, current_entity))
        print(f"Text length: {len(text)} characters")
        print(f"Token count: {len(tokens)}")
        print(f"Entities found: {len(entity_spans)}")
        # Print each entity
        for start, end, entity_type in entity_spans:
            entity_tokens = tokens[start:end + 1]
            entity_text = "".join([t.replace("##", "") for t in entity_tokens if
                                   t not in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]])
            print(f"  [{start:3}-{end:3}] {entity_type:10}: '{entity_text}'")
    # 7. Write the label description file
    labels_info_dict = {
        "label2id": label2id,
        "id2label": {str(k): v for k, v in id2label.items()},  # JSON keys must be strings
        "entity_types": entity_types,
        "total_samples": len(bert_data),
        "model_used": tokenizer.name_or_path,  # records the fallback model too
        "label_stats": dict(stats)  # label usage statistics
    }
    with open(labels_info_file, "w", encoding="utf-8") as f:
        json.dump(labels_info_dict, f, indent=2, ensure_ascii=False, default=str)
    print(f"\nLabel info saved to: {labels_info_file}")
    # 8. Save detailed debug info (optional)
    debug_info = {
        "samples": []
    }
    for idx, sample in enumerate(bert_data):
        if idx < 3:  # keep detailed info for the first 3 samples only
            debug_sample = {
                "text": label_studio_data[idx]["data"]["text"][:200] + "..." if len(label_studio_data[idx]["data"]["text"]) > 200 else label_studio_data[idx]["data"]["text"],
                "tokens": sample["tokens"],
                "labels": [id2label[l] for l in sample["labels"]],
                "input_ids_length": len(sample["input_ids"]),
                "non_o_labels": sum(1 for l in sample["labels"] if id2label[l] != "O")
            }
            debug_info["samples"].append(debug_sample)
    with open(debug_file, "w", encoding="utf-8") as f:
        json.dump(debug_info, f, indent=2, ensure_ascii=False)
    print(f"Debug info saved to: {debug_file}")
    print("\nConversion finished!")
    print(f"Original samples: {len(label_studio_data)}")
    print(f"Converted samples: {len(bert_data)}")
    print(f"Labels: {len(label2id)}")
input_file = r"D:\OpenSource\Python\VipPython\information_extraction\data\label_studio_export.json"
output_file = r"D:\OpenSource\Python\VipPython\information_extraction\data\bert_training_data.json"
labels_info_file = r"D:\OpenSource\Python\VipPython\information_extraction\data\labels_info.json"
debug_file = r"D:\OpenSource\Python\VipPython\information_extraction\data\datadebug_info.json"
if __name__ == "__main__":
    main()
Run
(vippython) PS D:\OpenSource\Python\VipPython> $env:HF_ENDPOINT = "https://hf-mirror.com"
(vippython) PS D:\OpenSource\Python\VipPython> uv run .\information_extraction\label_studio_to_bert.py
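Once the script has run, the saved bert_training_data.json can be fed straight into a token-classification training loop. Below is a minimal sketch of wrapping it as a PyTorch Dataset; this is an assumption about the downstream setup, not part of the original script, and the class name NERDataset is made up:

import json
import torch
from torch.utils.data import Dataset

class NERDataset(Dataset):
    """Wraps the converted JSON so a DataLoader/Trainer can consume it."""
    def __init__(self, path):
        with open(path, "r", encoding="utf-8") as f:
            self.samples = json.load(f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        # Every sample was padded to max_length, so the tensors stack cleanly.
        return {
            "input_ids": torch.tensor(s["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(s["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(s["labels"], dtype=torch.long),
        }

dataset = NERDataset(r"D:\OpenSource\Python\VipPython\information_extraction\data\bert_training_data.json")
print(len(dataset), dataset[0]["input_ids"].shape)

One caveat worth noting: the converter labels PAD and other special tokens as "O" (id 0) rather than -100, so a cross-entropy loss would also be computed over padding; if that is undesirable, remap positions where attention_mask is 0 to -100 before training.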

This post is from 博客园 (cnblogs). Author: VipSoft. Please credit the original link when reposting: https://www.cnblogs.com/vipsoft/p/19506768