ConversationalRetrievalChain的一些记录

最近在研究AI这个玩意儿,实现了一个简单的RAG问答demo。

用的是langchain + Ollama的本地模型 + bge-small-zh的本地模型来实现的。

chain用的是RetrievalQA来实现的。

然后呢就想搞一个连续对话形式的,发现他上下文不关联,因为第二次提问的时候,他相当于重新检索了后给LLM传递的信息。

搜了一下,说这叫检索的孤立性。如果要携带历史记录,要使用ConversationalRetrievalChain来实现。不用也可以,那就要自己维护一套历史缓存,每次携带到promtp里,然后检索后给LLM,会比较麻烦。

我用的是langchain1.0.0版本。

直接上代码:

    from langchain_classic.chains import ConversationalRetrievalChain
    from langchain_classic.memory import ConversationBufferMemory
    from langchain_core.prompts import PromptTemplate
    from langchain_community.vectorstores import FAISS
    from langchain_community.llms import Ollama
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True
    )
    llm = Ollama(model="deepseek-r1:7b", temperature=0.3)
    vectorstore = FAISS.load_local("./vector_store_bge_zh_002",EMBEDDINGS_MODEL,allow_dangerous_deserialization=True)
    prompt_template = """
            你是一个智能问答助手,根据历史消息,使用检索增强生成(RAG)技术回答问题。以下是与用户问题相关的检索结果和历史消息:
            
            【历史消息】
            {chat_history}
            
            【检索结果】
            {context}

            【任务要求】
            请根据上述检索结果,逐步推理并回答以下问题。你的回答必须包含以下三部分:
            1. 【相关信息摘要】:简要总结检索结果中与问题最相关的内容。
            2. 【推理过程】:基于上述信息,逐步解释你是如何得出结论的。如果信息不足,请明确指出。
            3. 【最终答案】:请不要使用“抱歉,我无法从文档中找到相关信息”来回答问题。
            【用户问题】
            {question}
            """
    PROMPT = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question","chat_history"]
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True,
        verbose=True,
        # 将自定义prompt传递给combine_docs_chain
        combine_docs_chain_kwargs={"prompt": PROMPT},
        memory = memory
    )
    while True:
        try:
            print("当前记录:",memory.chat_memory.messages)
            query = input("输入你的问题 \n")
            response = qa_chain.invoke({"question":query})
            print(response)
            print(response["answer"])
            print(memory.chat_memory.messages)
        except:
            loguru.logger.exception("xxx")

  这里踩了几个坑,记录一下。

  1、当ConversationalRetrievalChain的return_source_documents=True时,ConversationBufferMemory的output_key 必须指定,不然就会报错,指定的值在answer和source_document之间选一个

    因为ConversationalRetrievalChain默认只返回AI的回答,当需要返回文档时,不指定key的情况下,ConversationBufferMemory就不知道该缓存什么到历史消息里去了,就会抛异常  

File "D:\pycharmProject\RAG_Start\.venv\lib\site-packages\langchain_classic\memory\chat_memory.py", line 69, in _get_input_output

ValueError: Got multiple output keys: dict_keys(['answer', 'source_documents']), cannot determine which to store in memory. Please set the 'output_key' explicitly.

  2、ConversationBufferMemory的return_messages 参数,这个参数时bool值,用来控制缓存的历史消息的格式,默认是False,将历史消息拼接成str类型的,设置为True的话,返回的是BaseMessage的列表。

    而ConversationalRetrievalChain的invoke方法里获取history的时候,要的是tuple或BaseMessage,不设置这个参数,传递的就是str,就会报错。更尴尬的是,初次问答,chat_history是空字符串,还不会报错,第二次问答的时候,内层函数遍历chat_histoty的时候,进行类型检查,才会报错。如果不熟悉的话,根本不知道是ConversationBufferMemory这里的问题(比如我,debug了半天)

  

\Lib\site-packages\langchain_classic\chains\conversational_retrieval\base.py
def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            if len(dialogue_turn.content) > 0:
                role_prefix = _ROLE_MAP.get(
                    dialogue_turn.type,
                    f"{dialogue_turn.type}: ",
                )
                buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Human: " + dialogue_turn[0]
            ai = "Assistant: " + dialogue_turn[1]
            buffer += f"\n{human}\n{ai}"
        else:
            msg = (  # type: ignore[unreachable]
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
            raise ValueError(msg)  # noqa: TRY004
    return buffer

  报错信息:

  get_chat_history ValueError: Unsupported chat history format: <class 'str'>.

    

 

posted on 2026-01-15 22:50  超级大懒虫vip  阅读(1)  评论(0)    收藏  举报

导航