langchain检索器比较

LangChain 检索器比较与使用指南

是的,LangChain 提供了多种检索器,每种都有不同的特点和适用场景。下面我为你介绍几种常用的检索器,并展示它们之间的差异。

常用检索器类型及代码示例

# retrieval_comparison.py
import os
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import (
    BM25Retriever,
    EnsembleRetriever
)
from langchain.retrievers.multi_query import MultiQueryRetriever
import numpy as np

class RetrievalComparison:
    def __init__(self):
        # 初始化组件
        self.llm = ChatOpenAI(
            model_name="xop3qwen1b7",
            openai_api_base="https://maas-api.cn-huabei-1.xf-yun.com/v1",
            openai_api_key="sk-kBa9GlpIxpWX2...",
            temperature=0.1,
            max_tokens=1024
        )
        
        self.embeddings = HuggingFaceEmbeddings(
            model_name="GanymedeNil/text2vec-large-chinese",
            model_kwargs={'device': 'cpu'}
        )
        
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100
        )
        
        self.vectorstore = None
        self.documents = None
    
    def load_and_process_documents(self, file_paths):
        """加载和处理文档"""
        documents = []
        for file_path in file_paths:
            loader = TextLoader(file_path, encoding='utf-8')
            docs = loader.load()
            documents.extend(docs)
        
        self.documents = self.text_splitter.split_documents(documents)
        
        # 创建向量存储
        self.vectorstore = Chroma.from_documents(
            documents=self.documents,
            embedding=self.embeddings,
            persist_directory="./chroma_db_comparison"
        )
        
        return True
    
    # 1. 基础相似度检索器 (默认)
    def get_similarity_retriever(self, k=4):
        """相似度检索器 - 基于向量相似度"""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": k}
        )
    
    # 2. MMR (最大边际相关性) 检索器
    def get_mmr_retriever(self, k=4, fetch_k=10):
        """
        MMR 检索器 - 平衡相关性和多样性
        - k: 返回的结果数量
        - fetch_k: 初始获取的结果数量
        """
        return self.vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k, "fetch_k": fetch_k}
        )
    
    # 3. 相似度阈值检索器
    def get_similarity_threshold_retriever(self, score_threshold=0.7):
        """相似度阈值检索器 - 只返回相似度高于阈值的结果"""
        return self.vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": score_threshold}
        )
    
    # 4. BM25 检索器 (基于关键词)
    def get_bm25_retriever(self, k=4):
        """BM25 检索器 - 基于传统的关键词匹配算法"""
        # 提取文本内容
        texts = [doc.page_content for doc in self.documents]
        return BM25Retriever.from_texts(
            texts, 
            metadatas=[doc.metadata for doc in self.documents],
            k=k
        )
    
    # 5. 多查询检索器
    def get_multi_query_retriever(self, k=4):
        """多查询检索器 - 自动生成多个相关查询"""
        base_retriever = self.get_similarity_retriever(k=k)
        return MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=self.llm
        )
    
    # 6. 集成检索器 (混合搜索)
    def get_ensemble_retriever(self, k=4):
        """集成检索器 - 结合多种检索方法"""
        # 创建不同的检索器
        similarity_retriever = self.get_similarity_retriever(k=k*2)
        bm25_retriever = self.get_bm25_retriever(k=k*2)
        
        # 组合检索器 (可以调整权重)
        ensemble_retriever = EnsembleRetriever(
            retrievers=[similarity_retriever, bm25_retriever],
            weights=[0.7, 0.3]  # 向量搜索权重70%,关键词搜索权重30%
        )
        
        return ensemble_retriever
    
    # 7. 上下文压缩检索器
    def get_contextual_compression_retriever(self, k=4):
        """上下文压缩检索器 - 对检索结果进行压缩和去冗余"""
        from langchain.retrievers import ContextualCompressionRetriever
        from langchain.retrievers.document_compressors import LLMChainExtractor
        
        base_retriever = self.get_similarity_retriever(k=k*2)
        
        # 创建压缩器
        compressor = LLMChainExtractor.from_llm(self.llm)
        
        # 创建压缩检索器
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever
        )
        
        return compression_retriever
    
    def test_retrievers(self, question, retriever_types=None):
        """测试不同检索器的效果"""
        if retriever_types is None:
            retriever_types = [
                "similarity", "mmr", "similarity_threshold",
                "bm25", "multi_query", "ensemble", "contextual_compression"
            ]
        
        results = {}
        
        for retriever_type in retriever_types:
            try:
                print(f"\n🔍 测试 {retriever_type} 检索器...")
                
                # 获取相应的检索器
                if retriever_type == "similarity":
                    retriever = self.get_similarity_retriever(k=4)
                elif retriever_type == "mmr":
                    retriever = self.get_mmr_retriever(k=4, fetch_k=10)
                elif retriever_type == "similarity_threshold":
                    retriever = self.get_similarity_threshold_retriever(score_threshold=0.7)
                elif retriever_type == "bm25":
                    retriever = self.get_bm25_retriever(k=4)
                elif retriever_type == "multi_query":
                    retriever = self.get_multi_query_retriever(k=4)
                elif retriever_type == "ensemble":
                    retriever = self.get_ensemble_retriever(k=4)
                elif retriever_type == "contextual_compression":
                    retriever = self.get_contextual_compression_retriever(k=4)
                else:
                    continue
                
                # 执行检索
                docs = retriever.get_relevant_documents(question)
                
                # 存储结果
                results[retriever_type] = {
                    "documents": docs,
                    "count": len(docs),
                    "snippets": [doc.page_content[:100] + "..." for doc in docs]
                }
                
                print(f"  找到 {len(docs)} 个相关文档")
                for i, doc in enumerate(docs):
                    print(f"  {i+1}. {doc.page_content[:80]}...")
                    
            except Exception as e:
                print(f"  ❌ {retriever_type} 检索器出错: {e}")
                results[retriever_type] = {"error": str(e)}
        
        return results

def create_sample_documents():
    """创建包含不同主题的示例文档"""
    documents = [
        {
            "title": "langchain_intro.txt",
            "content": """
            LangChain是一个用于开发由语言模型驱动的应用程序的框架。
            它提供了一套工具、组件和接口,简化了构建基于LLM的应用过程。
            LangChain的核心概念包括模型I/O、检索、链、代理、内存等模块。
            """
        },
        {
            "title": "langchain_installation.txt",
            "content": """
            安装LangChain: pip install langchain。
            如果需要使用OpenAI模型,还需安装pip install langchain-openai。
            LangChain支持Python 3.8及以上版本。建议使用虚拟环境进行安装。
            """
        },
        {
            "title": "ai_technology.txt",
            "content": """
            人工智能技术正在快速发展,深度学习和大语言模型是当前的热点领域。
            机器学习算法可以帮助计算机从数据中学习模式和规律。
            自然语言处理技术使计算机能够理解和生成人类语言。
            """
        },
        {
            "title": "programming_basics.txt",
            "content": """
            Python是一种流行的编程语言,广泛用于数据科学和人工智能领域。
            编程基础包括变量、函数、循环和条件语句等概念。
            良好的代码结构和注释习惯对项目维护非常重要。
            """
        }
    ]
    
    file_paths = []
    for doc in documents:
        file_path = doc["title"]
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(doc["content"])
        file_paths.append(file_path)
    
    return file_paths

def main():
    # 创建示例文档
    print("创建示例文档...")
    file_paths = create_sample_documents()
    
    # 初始化比较器
    comparator = RetrievalComparison()
    comparator.load_and_process_documents(file_paths)
    
    # 测试问题
    test_questions = [
        "如何安装LangChain?",
        "LangChain是什么?",
        "人工智能技术有哪些?"
    ]
    
    for question in test_questions:
        print(f"\n{'='*60}")
        print(f"测试问题: {question}")
        print(f"{'='*60}")
        
        results = comparator.test_retrievers(
            question,
            retriever_types=["similarity", "mmr", "bm25", "ensemble"]
        )
        
        # 简要比较结果
        print(f"\n📊 {question} 的检索结果比较:")
        for retriever_type, result in results.items():
            if "error" in result:
                print(f"  {retriever_type}: 错误 - {result['error']}")
            else:
                print(f"  {retriever_type}: {result['count']} 个结果")

if __name__ == "__main__":
    main()

检索器类型对比分析

以下是不同检索器的特点比较:

1. 相似度检索器 (Similarity)

  • 原理: 基于向量相似度(余弦相似度)
  • 优点: 语义理解能力强,能找到概念相关的文档
  • 缺点: 可能返回过于相似的结果,缺乏多样性
  • 适用场景: 大多数语义搜索场景

2. MMR 检索器 (Maximal Marginal Relevance)

  • 原理: 平衡相关性和多样性,避免重复内容
  • 优点: 结果多样性好,避免信息冗余
  • 缺点: 计算成本稍高
  • 适用场景: 需要多样化结果的场景,如推荐系统

3. 相似度阈值检索器 (Similarity Threshold)

  • 原理: 只返回相似度高于阈值的结果
  • 优点: 结果质量高,过滤掉不相关文档
  • 缺点: 可能返回结果过少
  • 适用场景: 对精度要求极高的场景

4. BM25 检索器

  • 原理: 基于传统的关键词匹配算法
  • 优点: 计算速度快,对精确匹配效果好
  • 缺点: 缺乏语义理解能力
  • 适用场景: 关键词搜索、文档检索

5. 多查询检索器 (Multi-Query)

  • 原理: 自动生成多个相关查询,综合结果
  • 优点: 提高召回率,从不同角度搜索
  • 缺点: 计算成本高,可能引入噪声
  • 适用场景: 复杂查询,需要高召回率的场景

6. 集成检索器 (Ensemble)

  • 原理: 结合多种检索方法的结果
  • 优点: 兼顾语义搜索和关键词搜索的优点
  • 缺点: 配置复杂,需要调整权重
  • 适用场景: 需要最佳综合性能的场景

7. 上下文压缩检索器 (Contextual Compression)

  • 原理: 对检索结果进行压缩和去冗余
  • 优点: 返回更精炼、去重的结果
  • 缺点: 增加计算开销
  • 适用场景: 需要简洁结果的场景

选择指南

检索器类型 适用场景 优点 缺点
相似度 通用语义搜索 语义理解强 结果可能重复
MMR 多样化结果需求 结果多样性好 计算成本稍高
阈值过滤 高精度要求 结果质量高 可能结果过少
BM25 关键词搜索 速度快,精确匹配 缺乏语义理解
多查询 复杂查询 高召回率 计算成本高
集成 综合性能需求 兼顾多种优点 配置复杂
上下文压缩 简洁结果需求 结果精炼 增加计算开销

进阶使用建议

1. 根据数据特性选择

  • 短文本:BM25 + 相似度集成
  • 长文档:MMR 或上下文压缩
  • 专业术语:BM25 或阈值过滤

2. 性能优化

# 调整参数以获得最佳性能
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 6,                 # 返回结果数
        "fetch_k": 20,          # 初始获取数
        "lambda_mult": 0.7      # MMR多样性参数 (0-1)
    }
)

3. 混合策略

# 根据查询类型动态选择检索器
def get_dynamic_retriever(query):
    if is_keyword_query(query):  # 自定义的关键词检测函数
        return get_bm25_retriever()
    else:
        return get_similarity_retriever()
posted @ 2025-09-15 17:56  PyAj  阅读(26)  评论(0)    收藏  举报