BM25_test_jieba

import asyncio
import jieba
import logging
from typing import Any, List
from rank_bm25 import BM25Okapi

# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# 停用词列表(可根据业务扩展)
STOP_WORDS = {
    "的", "了", "在", "是", "我", "你", "他", "她", "它", "们", "就", "都", "而", "及",
    "与", "之", "于", "也", "又", "还", "个", "对", "对于", "关于", "把", "被", "为",
    "着", "过", "这", "那", "此", "彼", "和", "或", "如果", "那么", "因为", "所以"
}

class BM25Reranker:
    """基于 BM25 + jieba 分词的文本重排序类"""
    
    def __init__(self):
        # 初始化 jieba(可选:加载自定义词典)
        # jieba.load_userdict("custom_dict.txt")  # 如有业务专属词汇,可加载
        pass

    async def _cut_text(self, text: str) -> List[str]:
        """
        异步分词方法:jieba 分词 + 停用词过滤
        :param text: 待分词的文本
        :return: 过滤后的分词列表
        """
        if not isinstance(text, str) or text.strip() == "":
            logger.warning("空文本,返回空分词列表")
            return []
        
        # jieba 是同步操作,用 to_thread 封装为异步,避免阻塞事件循环
        words = await asyncio.to_thread(jieba.lcut, text.strip())
        # 过滤停用词、空字符串、纯数字
        filtered_words = [
            word for word in words 
            if word not in STOP_WORDS and word.strip() != "" and not word.isdigit()
        ]
        return filtered_words

    async def rerank(
        self,
        text: str,
        corpus: List[str],
        documents: List[Any],
        topK: int = 10
    ) -> List[Any]:
        """
        基于 BM25 的文本重排序,返回相关性最高的前 topK 个原始文档
        :param text: 查询文本
        :param corpus: 待检索的语料库(与 documents 一一对应)
        :param documents: 原始文档列表(需与 corpus 严格一一对应)
        :param topK: 返回的结果数量
        :return: 排序后的原始文档列表
        """
        # 异常处理:参数校验
        if not text.strip():
            logger.error("查询文本不能为空")
            return []
        if len(corpus) != len(documents):
            logger.error(f"语料库长度({len(corpus)})与文档列表长度({len(documents)})不匹配")
            return []
        if topK <= 0:
            logger.warning("topK 需大于 0,已重置为 10")
            topK = 10

        try:
            # 1. 对查询文本分词
            tokenized_query = await self._cut_text(text)
            if not tokenized_query:
                logger.warning("查询文本分词后为空,返回空列表")
                return []
            
            # 2. 对语料库批量分词
            tokenized_corpus = []
            for cor in corpus:
                words = await self._cut_text(cor)
                tokenized_corpus.append(words)
            
            # 过滤分词后为空的语料(避免 BM25 计算报错)
            # 同时保留 corpus 和 documents 的对应关系
            valid_pairs = [
                (tok_cor, doc) 
                for tok_cor, doc in zip(tokenized_corpus, documents) 
                if tok_cor
            ]
            if not valid_pairs:
                logger.warning("语料库分词后无有效内容,返回空列表")
                return []
            valid_corpus, valid_docs = zip(*valid_pairs)
            
            # 3. 初始化 BM25 模型并计算 TopK
            bm25 = BM25Okapi(valid_corpus)
            top_n_docs = bm25.get_top_n(tokenized_query, valid_docs, n=topK)
            
            logger.info(f"重排序完成,返回前 {len(top_n_docs)} 个结果")
            return top_n_docs

        except Exception as e:
            logger.error(f"重排序过程出错:{str(e)}", exc_info=True)
            return []

# ------------------------------ 测试代码 ------------------------------
async def test_reranker():
    """测试 BM25Reranker 的 rerank 方法"""
    # 1. 初始化重排序器
    reranker = BM25Reranker()
    
    # 2. 测试数据
    # 查询文本
    query_text = "查询工厂组织联机情况"
    # 语料库(工具描述,与下面的 documents 一一对应)
    corpus_list = [
        "根据组织名称查询组织编码",
        "根据组织编码查询组织设备联机信息",
        "查询设备运行状态和故障原因",
        "工厂人员考勤管理系统",
        "联机设备的实时数据采集",
        ""  # 空文本(测试异常场景)
    ]
    # 原始文档(工具完整信息)
    documents_list = [
        {"tool_code": "AG202508220001", "tool_name": "根据组织名称查询组织编码", "tool_desc": "根据组织名称查询组织编码"},
        {"tool_code": "IF20250729000006", "tool_name": "根据组织编码查询组织设备联机信息", "tool_desc": "根据组织编码查询组织设备联机信息"},
        {"tool_code": "ST202508010001", "tool_name": "查询设备运行状态", "tool_desc": "查询设备运行状态和故障原因"},
        {"tool_code": "HR202508020001", "tool_name": "工厂人员管理", "tool_desc": "工厂人员考勤管理系统"},
        {"tool_code": "DC202508030001", "tool_name": "设备数据采集", "tool_desc": "联机设备的实时数据采集"},
        {"tool_code": "EMPTY001", "tool_name": "空文本测试", "tool_desc": ""}  # 对应空语料
    ]
    
    # 3. 调用重排序方法(返回前3个结果)
    top3_docs = await reranker.rerank(query_text, corpus_list, documents_list, topK=3)
    
    # 4. 打印测试结果
    print("\n===== 重排序结果(Top3) =====")
    for idx, doc in enumerate(top3_docs, 1):
        print(f"\n{idx}. 工具编码:{doc['tool_code']}")
        print(f"   工具名称:{doc['tool_name']}")
        print(f"   工具描述:{doc['tool_desc']}")

if __name__ == "__main__":
    # 安装依赖(首次运行需执行)
    # pip install jieba rank_bm25
    
    # 运行测试
    asyncio.run(test_reranker())
    
 
重排序完成,返回前 3 个结果

===== 重排序结果(Top3) =====

1. 工具编码:HR202508020001
   工具名称:工厂人员管理
   工具描述:工厂人员考勤管理系统

2. 工具编码:IF20250729000006
   工具名称:根据组织编码查询组织设备联机信息
   工具描述:根据组织编码查询组织设备联机信息

3. 工具编码:AG202508220001
   工具名称:根据组织名称查询组织编码
   工具描述:根据组织名称查询组织编码

 

posted @ 2026-01-20 16:14  BlogMemory  阅读(1)  评论(0)    收藏  举报