import asyncio
import jieba
import logging
from typing import Any, List
from rank_bm25 import BM25Okapi
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# 停用词列表(可根据业务扩展)
STOP_WORDS = {
"的", "了", "在", "是", "我", "你", "他", "她", "它", "们", "就", "都", "而", "及",
"与", "之", "于", "也", "又", "还", "个", "对", "对于", "关于", "把", "被", "为",
"着", "过", "这", "那", "此", "彼", "和", "或", "如果", "那么", "因为", "所以"
}
class BM25Reranker:
"""基于 BM25 + jieba 分词的文本重排序类"""
def __init__(self):
# 初始化 jieba(可选:加载自定义词典)
# jieba.load_userdict("custom_dict.txt") # 如有业务专属词汇,可加载
pass
async def _cut_text(self, text: str) -> List[str]:
"""
异步分词方法:jieba 分词 + 停用词过滤
:param text: 待分词的文本
:return: 过滤后的分词列表
"""
if not isinstance(text, str) or text.strip() == "":
logger.warning("空文本,返回空分词列表")
return []
# jieba 是同步操作,用 to_thread 封装为异步,避免阻塞事件循环
words = await asyncio.to_thread(jieba.lcut, text.strip())
# 过滤停用词、空字符串、纯数字
filtered_words = [
word for word in words
if word not in STOP_WORDS and word.strip() != "" and not word.isdigit()
]
return filtered_words
async def rerank(
self,
text: str,
corpus: List[str],
documents: List[Any],
topK: int = 10
) -> List[Any]:
"""
基于 BM25 的文本重排序,返回相关性最高的前 topK 个原始文档
:param text: 查询文本
:param corpus: 待检索的语料库(与 documents 一一对应)
:param documents: 原始文档列表(需与 corpus 严格一一对应)
:param topK: 返回的结果数量
:return: 排序后的原始文档列表
"""
# 异常处理:参数校验
if not text.strip():
logger.error("查询文本不能为空")
return []
if len(corpus) != len(documents):
logger.error(f"语料库长度({len(corpus)})与文档列表长度({len(documents)})不匹配")
return []
if topK <= 0:
logger.warning("topK 需大于 0,已重置为 10")
topK = 10
try:
# 1. 对查询文本分词
tokenized_query = await self._cut_text(text)
if not tokenized_query:
logger.warning("查询文本分词后为空,返回空列表")
return []
# 2. 对语料库批量分词
tokenized_corpus = []
for cor in corpus:
words = await self._cut_text(cor)
tokenized_corpus.append(words)
# 过滤分词后为空的语料(避免 BM25 计算报错)
# 同时保留 corpus 和 documents 的对应关系
valid_pairs = [
(tok_cor, doc)
for tok_cor, doc in zip(tokenized_corpus, documents)
if tok_cor
]
if not valid_pairs:
logger.warning("语料库分词后无有效内容,返回空列表")
return []
valid_corpus, valid_docs = zip(*valid_pairs)
# 3. 初始化 BM25 模型并计算 TopK
bm25 = BM25Okapi(valid_corpus)
top_n_docs = bm25.get_top_n(tokenized_query, valid_docs, n=topK)
logger.info(f"重排序完成,返回前 {len(top_n_docs)} 个结果")
return top_n_docs
except Exception as e:
logger.error(f"重排序过程出错:{str(e)}", exc_info=True)
return []
# ------------------------------ 测试代码 ------------------------------
async def test_reranker():
"""测试 BM25Reranker 的 rerank 方法"""
# 1. 初始化重排序器
reranker = BM25Reranker()
# 2. 测试数据
# 查询文本
query_text = "查询工厂组织联机情况"
# 语料库(工具描述,与下面的 documents 一一对应)
corpus_list = [
"根据组织名称查询组织编码",
"根据组织编码查询组织设备联机信息",
"查询设备运行状态和故障原因",
"工厂人员考勤管理系统",
"联机设备的实时数据采集",
"" # 空文本(测试异常场景)
]
# 原始文档(工具完整信息)
documents_list = [
{"tool_code": "AG202508220001", "tool_name": "根据组织名称查询组织编码", "tool_desc": "根据组织名称查询组织编码"},
{"tool_code": "IF20250729000006", "tool_name": "根据组织编码查询组织设备联机信息", "tool_desc": "根据组织编码查询组织设备联机信息"},
{"tool_code": "ST202508010001", "tool_name": "查询设备运行状态", "tool_desc": "查询设备运行状态和故障原因"},
{"tool_code": "HR202508020001", "tool_name": "工厂人员管理", "tool_desc": "工厂人员考勤管理系统"},
{"tool_code": "DC202508030001", "tool_name": "设备数据采集", "tool_desc": "联机设备的实时数据采集"},
{"tool_code": "EMPTY001", "tool_name": "空文本测试", "tool_desc": ""} # 对应空语料
]
# 3. 调用重排序方法(返回前3个结果)
top3_docs = await reranker.rerank(query_text, corpus_list, documents_list, topK=3)
# 4. 打印测试结果
print("\n===== 重排序结果(Top3) =====")
for idx, doc in enumerate(top3_docs, 1):
print(f"\n{idx}. 工具编码:{doc['tool_code']}")
print(f" 工具名称:{doc['tool_name']}")
print(f" 工具描述:{doc['tool_desc']}")
if __name__ == "__main__":
# 安装依赖(首次运行需执行)
# pip install jieba rank_bm25
# 运行测试
asyncio.run(test_reranker())
重排序完成,返回前 3 个结果
===== 重排序结果(Top3) =====
1. 工具编码:HR202508020001
工具名称:工厂人员管理
工具描述:工厂人员考勤管理系统
2. 工具编码:IF20250729000006
工具名称:根据组织编码查询组织设备联机信息
工具描述:根据组织编码查询组织设备联机信息
3. 工具编码:AG202508220001
工具名称:根据组织名称查询组织编码
工具描述:根据组织名称查询组织编码