第4~5章 - Agent 短期记忆和Agent长期记忆

4. Agent短期记忆

在构建能够进行多轮交互的AI代理(Agents)时,记忆系统是其核心组件之一。记忆使Agent能够保留先前交互的信息,从而学习反馈、适应用户偏好,并高效处理复杂任务。Agent记忆可分为短期记忆和长期记忆。本章节重点介绍短期记忆相关内容。

4.1. 短期记忆介绍

短期记忆(Short-term Memory),也称为线程范围记忆(Thread-scoped Memory),是指应用程序在单个线程或会话中记住先前交互的能力,其本质是维护当前会话的完整对话历史,以实现多轮对话的连贯性、个性化响应及上下文感知。

在LangChain生态中,短期记忆被作为Agent状态(State)的一部分进行管理。状态通过线程范围的检查点(checkpointer)持久化到内存/数据库中,确保对话线程可随时恢复与继续。线程(Thread)是组织短期记忆的核心概念。线程(Thread)是组织短期记忆的核心单元,代表一个独立会话上下文,每个线程通过唯一标识符(thread_id)隔离不同对话的记忆。

此外,状态不仅包含对话历史,还可存储其他有状态数据,如用户登录状态、此刻查询订单号、所在城市等临时信息。

在LangChain中,短期记忆使用代码示例如下:

from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver

from init_llm import deepseek_llm

# 创建带有短期记忆的智能体
checkpointer = InMemorySaver()

agent = create_agent(
    model=deepseek_llm,
    tools=[],  # 可在此处添加工具
    checkpointer=checkpointer  # 启用短期记忆
)

# 使用相同thread_id维持对话上下文
config = {"configurable": {"thread_id": "conversation_1"}}

# 第一轮对话
response = agent.invoke({"messages": [{"role": "user", "content": "你好,我叫张三"}]}, config)
print(response["messages"][-1].content)

print("="*20)

# 第二轮对话 - Agent能记住之前的对话
response = agent.invoke({"messages": [{"role": "user", "content": "我叫什么名字?"}]}, config)
print(response["messages"][-1].content)  # 输出: "你叫张三"

代码运行结果如下:

image.png

以上代码注意如下几点:

  1. create_agent中通过指定checkpointer参数来启用短期记忆,InMemorySaver表示将会话存储在内存中,每次对话迭代后自动保存对话状态。
  2. 通过config={"configurable": {"thread_id": "conversation_1"}}指定会话唯一标识,thread_id用于隔离不同用户的对话上下文。相同thread_id保证状态连续性,不同ID则创建新记忆空间。
  3. 每次invoke后自动触发会话状态保存,根据不同的thread_id追加保存会话内容。

4.2. 短期记忆使用方式

Checkpointer是LangChain短期记忆的核心架构,负责Agent状态的持久化管理。其工作原理是在每个执行步骤后保存Agent的状态快照,可以将状态存储在内存或者数据库中。

4.2.1. 使用内存存储短期记忆

在测试环境中,通常使用内存型的Checkpointer,这种方式简单易用但进程重启后数据会丢失。

使用内存存储短期记忆案例代码如下:

from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
from langchain_core.tools import tool

from init_llm import deepseek_llm


# 工具:获取用户信息
@tool
def get_user_info(name: str) -> str:
    """
    根据姓名查询用户信息
    Args:
        name (str): 要查询的用户姓名
    Returns:
        str: 包含用户信息的字符串
    """
    user_db = {
        "张三": {"age": 28, "hobby": "旅游、滑雪、喝茶"},
        "李四": {"age": 32, "hobby": "编程、阅读、电影"}
    }
    info = user_db.get(name, {"age": "未知", "hobby": "未知"})
    return f"姓名: {name}, 年龄: {info['age']}岁, 爱好: {info['hobby']}"

# 创建内存检查点
checkpointer = InMemorySaver()

# 创建带有短期记忆的智能体
agent = create_agent(
    model=deepseek_llm,
    tools=[get_user_info],
    checkpointer=checkpointer  # 启用短期记忆
)

# 使用线程ID维持对话上下文
config = {"configurable": {"thread_id": "user_123"}}

# 多轮对话
print("=== 第一轮对话 ===")
result1 = agent.invoke({"messages": "你好,我叫张三"}, config)
print(f"AI: {result1['messages'][-1].content}")

print("\n=== 第二轮对话 ===")
result2 = agent.invoke({"messages": "你知道我的信息吗?"}, config)
print(f"AI: {result2['messages'][-1].content}")

print("="*20)
# 获取Agent记忆状态
state = agent.get_state(config)
print(type(state))
print(state)

以上代码运行结果如下:

image.png

可以通过agent.get_state(config)来获取当前存储thread_id的状态,获取Agent状态结果是StateSnapshot对象,该对象默认有messages属性存储多轮对话的消息,内容如下:

image.png

4.2.2. 使用数据库存储短期记忆

在生产环境中,推荐使用数据库支持的Checkpointer,以确保数据的持久化和多实例部署的支持。

如下案例中使用mysql数据库来存储短期记忆,完成此案例需要提前安装好mysql数据库(默认已经安装mysql8),并且进行数据库创建和安装必要python依赖,具体如下:

1) 在mysql中创建langchain_db数据库

#进入mysql navicate客户端,创建mysql数据库langchain_db
create database langchain_db;

2) 在当前python环境中安装如下依赖

#安装必要依赖
conda activate langchain_v1.2
python -m pip install langgraph-checkpoint-mysql==3.0.0 pymysql==1.1.2 cryptography==46.0.3

使用mysql数据库存储短期记忆代码如下:

from langchain.agents import create_agent
from langchain_core.tools import tool

from init_llm import deepseek_llm
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver

# 工具:获取用户信息
@tool
def get_user_info(name: str) -> str:
    """
    根据姓名查询用户信息
    Args:
        name (str): 要查询的用户姓名
    Returns:
        str: 包含用户信息的字符串
    """
    user_db = {
        "张三": {"age": 28, "hobby": "旅游、滑雪、喝茶"},
        "李四": {"age": 32, "hobby": "编程、阅读、电影"}
    }
    info = user_db.get(name, {"age": "未知", "hobby": "未知"})
    return f"姓名: {name}, 年龄: {info['age']}岁, 爱好: {info['hobby']}"


# 配置 MySQL 连接
DB_URI = "mysql+pymysql://root:123456@localhost:3306/langchain_db?charset=utf8mb4"

# 生产环境配置
with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:
    # 自动创建数据库表(首次运行)
    checkpointer.setup()

    # 创建生产环境可用的智能体
    agent = create_agent(
        model=deepseek_llm,
        tools=[get_user_info],
        checkpointer=checkpointer
    )

    # 配置会话
    config = {"configurable": {"thread_id": "user_001"}}

    # 先获取状态看看
    print(agent.get_state(config))

    # 模拟用户对话
    agent.invoke({
        "messages": [{"role": "user", "content": "我是用户张三"}]
    }, config)

    # 后续对话中智能体会记住用户信息
    response = agent.invoke({
        "messages": [{"role": "user", "content": "我是谁?"}]
    }, config)

    print(f"AI响应: {response['messages'][-1].content}")

以上代码运行结果如下:

image.png

运行代码后进入到Mysql数据库中可以看到对应的数据库表和数据:

image.png

以上代码需要注意如下几点:

  1. “with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:”通过连接字符串DB_URI创建与MySQL数据库的持久化连接,用于管理对话状态检查点的存储。
  2. 使用mysql存储短期记忆需要提前在数据库中创建对应的数据库,然后代码首次运行执行“checkpointer.setup()”(首次运行需要,首次运行后可以不再执行该代码)会自动在该数据库中创建对应数据库表。
  3. 也可以使用其他数据库进行短期记忆的持久化存储,例如使用postgresql存储,需要安装“pip install langgraph-checkpoint-postgres==3.0.4”,具体代码参考:https://docs.langchain.com/oss/python/langchain/short-term-memory#in-production
  4. 持久化存储支持的数据库可以通过“https://pypi.org/search/?o=&q=langgraph-checkpoint&page=2”查看,搜索“langgraph-checkpoint-*”查看对应需要安装的依赖和使用方式。

image.png

4.3. 自定义记忆状态

默认情况下,Agent底层是通过AgentState对象管理短期记忆(如对话历史),具体是通过messages key来将所有对话的消息进行存储。一些复杂的业务场景中,我们可以通过继承和扩展基础的AgentState类,添加业务特定的记忆字段,实现自定义记忆状态,做到更精细化的状态管理。

自定义记忆状态实现步骤如下:

  1. 扩展AgentState类:通过继承AgentState类定义自定义字段。
  2. 创建支持初始化Agent时,通过state_schema参数注入自定义状态类。
  3. 自定义状态Agent:在调用Agent时传递自定义数据:在invoke方法中,可以根据需要显式传递自定义字段。

如下代码完成了自定义记忆状态:

from langchain.agents import create_agent, AgentState
from langgraph.checkpoint.memory import InMemorySaver

from init_llm import deepseek_llm

# 1. 扩展AgentState 类,自定义状态
class CustomAgentState(AgentState):
    user_id: str  # 用户唯一标识
    hobby: list  # 用户爱好
    other_info: dict  # 用户其他信息


# 2.创建 Agent,通过state_schema参数指定自定义状态
agent = create_agent(
    model=deepseek_llm,
    tools=[],
    state_schema=CustomAgentState,
    checkpointer=InMemorySaver()
)

config = {"configurable": {"thread_id": "user_001"}}

# 3.调用 Agent时,传入自定义状态
result = agent.invoke(
    {
        "messages":[{"role":"user","content":"你好,我是张三"}],
        "user_id": "user_001",
        "hobby": ["旅游、滑雪、喝茶"],
        "other_info": {"age": 28, "gender": "男"},
    },
    config=config
)

print("AI回复:", result["messages"][-1].content)
# 查看保存的状态
print("当前状态:", agent.get_state(config=config))



print("="*20)
# 后续调用自动携带状态
result = agent.invoke(
    {
        "messages": [{"role":"user","content":"使用十个字介绍你自己"}],
    },
    config=config
)
print("AI回复:", result["messages"][-1].content)
# 查看保存的状态
print("当前状态:", agent.get_state(config=config))

以上代码运行结果如下,可以看到通过内存保存了自定义记忆状态,每次对话都会有携带自定义的记忆状态:

image.png

以上代码注意如下几点:

  1. create_agent时,通过state_schema参数指定自定义记忆状态的类即可,类中参数一定要包含所有自定义状态,否则该传入的状态参数不会被作为状态管理。
  2. 自定义记忆状态与默认的“messages”key同级,对应的记忆状态key为继承AgentState类后自定义字段,value为对应invoke传入的值。
  3. create_agent时,需要指定checkpointer将对话记忆进行保存,只需要在首次invoke调用agent时指定自定义记忆状态,后续所有对话都能获取到对应的自定义记忆内容。

4.4. 短期记忆访问和修改

短期记忆的访问和修改主要通过Tools工具接口、@before_model中间件、@after_model中间件三种方式实现,均基于LangChain的AgentState(短期记忆容器)和ToolRuntime(工具运行时上下文)来实现,下面分别介绍。

4.4.1. 通过Tools访问和修改短期记忆

我们可以通过Tool来访问和修改短期记忆内容,Tools 是智能体与外部世界交互的接口(如调用 API、查询数据库),定义tools时可以传入“ToolRuntime”对象,该对象是LangChain中内置对象,是 Tools 的运行时上下文,ToolRuntime对象中包含Agent的状态(State)、上下文(Context)、存储(Store)等信息,工具通过 ToolRuntime参数访问或修改状态,且该参数不会暴露给 LLM(避免模型看到内部状态)。当通过工具进行短期记忆修改时,工具可以通过返回 Command对象(包含 update字段)来修改状态。

4.4.1.1. 通过Tools读取短期记忆

如下案例中,构建一个查询用户信息的Agent,首次调用Agent传入一些自定义状态,然后在多个工具中通过ToolRuntime对象获取该状态。代码如下:

from langchain.agents import AgentState, create_agent
from langgraph.checkpoint.memory import InMemorySaver
from langchain.tools import tool, ToolRuntime

from init_llm import deepseek_llm

@tool
def get_info(runtime: ToolRuntime) -> str:
    """
    查询用户会员等级
    Args:
        runtime (ToolRuntime): 包含当前状态的运行时环境
    Returns:
        str: 用户会员等级信息
    """
    print("runtime:", runtime)
    # 从状态中读取会员等级
    user_level = runtime.state["user_level"]
    if user_level == "VIP":
        return "你的会员等级是VIP,你有免费退换货、专属客服通道等福利"
    else:
        return "你的会员等级是普通会员,你有积分翻倍活动等福利"

@tool
def get_user_id(runtime: ToolRuntime) -> str:
    """
    查询用户唯一标识
    Args:
        runtime (ToolRuntime): 包含当前状态的运行时环境
    Returns:
        str: 用户唯一标识
    """
    print("runtime:", runtime)
    # 从状态中读取用户唯一标识
    user_id = runtime.state["user_id"]
    return f"你的用户唯一标识是:{user_id}"


# 自定义状态
class CustomerState(AgentState):
    user_id: str # 用户唯一标识
    user_level: str  # 会员等级(如"VIP", "Normal")


# 创建 Agent
agent = create_agent(
    model=deepseek_llm,
    tools=[get_info, get_user_id],
    state_schema=CustomerState,
    checkpointer=InMemorySaver()
)

config = {"configurable": {"thread_id": "session_001"}}

# 调用 Agent时,传入自定义状态
response1 = agent.invoke(
    {
        "messages": [{"role": "user", "content": "我的会员等级是什么?"}],
        "user_id": "user_123",
        "user_level": "VIP"
    },
    config=config
)

print(response1)
print(response1["messages"][-1].content)

# 后续调用自动携带状态
response2 = agent.invoke(
    {
        "messages": [{"role": "user", "content": "查看我的用户唯一标识"}],
    },
    config=config
)
print(response2)
print(response2["messages"][-1].content)

以上代码运行结果如下:

image.png

以上代码注意如下几点:

  1. 在工具中通过ToolRuntime获取短期记忆(包括用户自定义记忆状态),该对象内容如下:

    image.png

  2. ToolRuntime是隐藏参数,不会出现在工具的签名中,因此 LLM 不会知道工具如何访问状态,只会看到工具的描述。

  3. 创建Agent时指定checkpointer,只需要首次agent.invoke时传入自定义状态参数,同一thread_id后续多次对话都能携带该状态。

4.4.1.2. 通过Tools修改短期记忆

通过Tools修改短期记忆要求工具返回Command对象,Command对象中包含update字段,该update中使用K,V dict方式指定更新用户自定义状态,同时update中还必须指定“message”字段返回工具执行结果到消息历史中。

如下案例中,通过工具修改用户自定义短期记忆:

from langchain.agents import create_agent, AgentState
from langgraph.checkpoint.memory import InMemorySaver
from langchain.tools import tool, ToolRuntime
from langgraph.types import Command
from langchain.messages import ToolMessage

from init_llm import deepseek_llm


class CustomerState(AgentState):
    """自定义记忆状态"""
    user_name: str = ""
    hobby: list = []

@tool
def update_user_profile(runtime: ToolRuntime, name: str, hobby: list) -> Command:
    """
    更新用户档案信息并持久化到记忆状态
    Args:
        runtime (ToolRuntime): 包含当前状态的运行时环境
        name (str): 用户姓名
        hobby(list): 用户爱好(如"看电影", "听音乐"等)
    Returns:
        Command: 包含更新操作的命令对象
    """
    # 输入验证
    if not name or not hobby:
        return Command(
            update={
                "messages": [
                    ToolMessage(
                        content="错误:姓名和爱好不能为空",
                        tool_call_id=runtime.tool_call_id
                    )
                ]
            }
        )

    # 准备更新内容
    updates = {
        "user_name": name,
        "hobby": hobby,
        "messages": [
            ToolMessage(
                content=f"已更新用户档案:姓名={name}, 爱好={','.join(hobby)}",
                tool_call_id=runtime.tool_call_id
            )
        ]
    }

    return Command(update=updates)

# 创建支持记忆写入的智能体
agent = create_agent(
    model=deepseek_llm,
    tools=[update_user_profile],
    state_schema=CustomerState,
    checkpointer=InMemorySaver(),
)

# 测试记忆写入功能
config = {"configurable": {"thread_id": "session_001"}}

# 初始调用
result1 = agent.invoke({
    "messages": [{"role": "user", "content": "我叫王五,我的爱好是钓鱼和唱歌"}]
}, config)
print("模型回复:", result1['messages'][-1].content)

print("="*20)

result2 = agent.invoke({
    "messages": [{"role": "user", "content": "我也喜欢旅游"}]
}, config)
print("模型回复:", result2['messages'][-1].content)

print("="*20)

print("当前状态:", agent.get_state(config=config))

以上代码运行结果如下:

image.png

短期记忆内容如下,可以看到成功修改了记忆状态:

image.png

以上代码注意如下几点:

  1. 工具 update_user_profile通过返回 Command对象,将用户一些信息保存到状态中(update={"user_name": name...}),同时必须添加工具执行结果到消息历史(messages字段)
  2. messages字段对应的value是一个数组对象,对象中是ToolMessage,需要指定content和tool_call_id字段,方便后续在状态消息查看对应工具返回的消息。

4.4.2. @before_model中间件操作短期记忆

@before_model是 LangChain Agent 执行流程中的前置拦截器,在模型调用前触发。其核心功能包括:

  • 消息预处理:裁剪/删除/总结历史消息,控制上下文长度,可以参考“超出LLM 上下文解决方案”小节内容。
  • 状态管理:读取或修改 Agent 的短期记忆(AgentState)。

image.png

如下案例中演示通过@before_model来读取和修改用户自定义状态:

from langchain.agents import AgentState
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain.agents.middleware import before_model
from langgraph.runtime import Runtime
from typing import Any, Dict
from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver

from init_llm import deepseek_llm

@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气
    Args:
        city (str): 要查询天气的城市名称

    Returns:
        str: 包含城市天气信息的字符串
    """
    return f"{city}的天气是晴朗的,温度是25摄氏度"

class CustomState(AgentState):
    """扩展状态:记录工具调用次数"""
    tool_call_count: int  # 记录工具调用次数

@before_model
def manage_state(state: CustomState, runtime: Runtime) -> Dict[str, Any] | None:
    """中间件逻辑:
    1. 更新对话轮次计数
    2. 存储用户信息
    """
    print("before_model_state:", state)
    print("before_model_runtime:", runtime)

    # 从状态中获取自定义字段 tool_call_count
    tool_call_count = state.get("tool_call_count", 0)
    print("状态中工具调用次数:", tool_call_count)

    # 读取 Messages 中为 ToolMessage 的消息条数
    tool_call_count = len([msg for msg in state["messages"] if isinstance(msg, ToolMessage)])

    # 返回修改后的状态和消息
    return {
        "tool_call_count": tool_call_count
    }

# 初始化 Agent,指定自定义状态和中间件
agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather],
    middleware=[manage_state],
    checkpointer=InMemorySaver(),
    state_schema=CustomState  # 指定扩展状态类
)

config = {"configurable": {"thread_id": "session_001"}}

response1 = agent.invoke({"messages": [{"role": "user", "content": "北京今天天气如何?"}]},config=config)
print(response1["messages"][-1].content)
print("***" * 20)

response2 = agent.invoke({"messages": [{"role": "user", "content": "上海今天天气如何?"}]},config=config)
print(response2["messages"][-1].content)
print("***" * 20)

# 打印状态
print(agent.get_state(config=config))

代码最后输出的agent状态内容如下:

image.png

以上代码注意如下:

  1. @before_model中间件默认有2个参数AgentState和Runtime。AgentState存储Agent状态(messages列表和自定义状态等);Runtime是LangChain中上下文,包含context(用户上下文)、store(长期记忆持久化接口)等内容。
  2. 在创建agent时需要通过state_schema参数指定自定义参数类型,指定的类型中指定变量类型,该变量需要与自定义状态的变量名称一样,这样用户自定义状态才能在内部正确保存。

4.4.3. @after_model中间件操作短期记忆

@after_model是 LangChain Agent 执行流程中的后置拦截器,在模型生成响应后触发。其核心功能包括:

  • 输出校验:验证模型响应的合规性(如敏感词过滤)
  • 消息预处理:多轮对话中,进行裁剪/删除/总结历史消息,控制上下文长度,可以参考“超出LLM 上下文解决方案”小节内容。
  • 状态管理:基于模型输出动态修改 Agent 的短期记忆(AgentState)。

image.png

如下Agent案例可以查询订单信息和商品库存,并结构化方式输出结果。当用户查询某个订单信息时,根据模型返回结构化结果在@after_model中获取该结果的商品名称并更新自定义的状态,这样后续对话中查询商品库存时,可以直接获取状态中的该商品名称,进而查询库存。

from langchain.agents import AgentState, create_agent
from langchain.agents.middleware import after_model
from langchain.agents.structured_output import ToolStrategy
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import ToolRuntime
from langgraph.runtime import Runtime
from pydantic import BaseModel, Field
from typing import Dict, Any, Union

from init_llm import deepseek_llm


# 1. 定义Agent格式化返回结构
class OrderQueryResult(BaseModel):
    """订单查询响应结构"""
    order_id: str  # 订单ID
    product_name: str  # 订单中商品名称
    price: float  # 订单金额
    status: str  # 订单状态


class InventoryQueryResult(BaseModel):
    """库存查询响应结构"""
    product_name: str  # 商品名称
    stock_quantity: int  # 库存数量

# 2. 模拟数据库数据
MOCK_DATABASE = {
    "orders": {
        "order_001": OrderQueryResult(order_id="order_001", product_name="华为手机", price=1999.00, status="已发货"),
        "order_002": OrderQueryResult(order_id="order_002", product_name="苹果电脑", price=2999.00, status="待发货"),
        "order_003": OrderQueryResult(order_id="order_003", product_name="三星显示器", price=3999.00, status="已签收"),

    },
    "inventory": {
        "华为手机": InventoryQueryResult(product_name="华为手机",stock_quantity=50),
        "苹果电脑": InventoryQueryResult(product_name="苹果电脑",stock_quantity=20),
        "三星显示器": InventoryQueryResult(product_name="三星显示器",stock_quantity=30)
    }
}


# 3. 自定义状态类
class OrderState(AgentState):
    """自定义状态"""
    product_name: str  # 订单中商品名称


# 4. 定义工具函数
@tool
def get_order_info(order_id: str) -> OrderQueryResult:
    """获取订单详情
    Args:
        order_id (str): 订单ID
    Returns:
        OrderQueryResult: 订单详情
    """
    order_data = MOCK_DATABASE["orders"].get(order_id)

    if order_data:
        return order_data
    else:
        raise ValueError("订单不存在")


@tool
def get_product_inventory(runtime: ToolRuntime) -> InventoryQueryResult:
    """查询商品库存
    """
    print("runtime:", runtime)
    # 从 runtime中获取商品名称
    product_name = runtime.state["product_name"]

    inventory_data = MOCK_DATABASE["inventory"].get(product_name)
    if inventory_data:
        return inventory_data
    else:
        raise ValueError("商品不存在")


# 5. 中间件,模型调用完成后,从结构化输出中提取商品名称设置到状态中
@after_model
def manage_order_state(state: AgentState,runtime: Runtime) -> Dict[str, Any] | None:
    print("state:", state)

    # 如果state中没有结构化响应,直接返回None
    if "structured_response" not in state:
        return None

    # 获取AI大模型结构化输出结果
    structured_response = state["structured_response"]

    # 解析模型输出,如果是订单查询结果,返回商品名称,否则返回None
    if isinstance(structured_response, OrderQueryResult):
        product_name = structured_response.product_name
        return {"product_name": product_name}
    else:
        return None


# 6. 创建 Agent
agent = create_agent(
    model=deepseek_llm,
    tools=[get_order_info, get_product_inventory],
    response_format=ToolStrategy(Union[OrderQueryResult, InventoryQueryResult]),
    middleware=[manage_order_state],
    state_schema=OrderState,
    checkpointer=InMemorySaver()
)

# 7. 测试调用
config = {"configurable": {"thread_id": "user_001"}}

# 创建订单测试
response1 = agent.invoke({"messages": [{"role": "user", "content": "查询订单order_001信息"}]},config=config)
print("response1:", response1["structured_response"])
print("***" * 20)

# 查询订单测试
response2 = agent.invoke({"messages": [{"role": "user", "content": "这个订单中商品库存是多少"}]},config=config)
print("response2:", response2["structured_response"])
print("***" * 20)

response3 = agent.invoke({"messages": [{"role": "user", "content": "查询订单order_002信息"}]},config=config)
print("response3:", response3["structured_response"])
print("***" * 20)

response4 = agent.invoke({"messages": [{"role": "user", "content": "商品库存是多少"}]},config=config)
print("response4:", response4["structured_response"])
print("***" * 20)

# 查看最终状态
final_state = agent.get_state(config=config)
print("最终 product_name 状态:", final_state.values["product_name"])

以上代码运行结果如下:

image.png

以上代码注意如下几点:

  1. @after_model中间件默认有2个参数AgentState和Runtime。AgentState存储Agent状态(messages列表和自定义状态等);Runtime是LangChain中上下文,包含context(用户上下文)、store(长期记忆持久化接口)等内容。
  2. 以上代码通过2个工具来完成订单和商品库存查询,商品库存查询直接通过获取状态中的商品信息来进行查询,该状态通过@after_model中间件进行更新,这里模拟在Agent运行过程中可以动态修改状态来满足一些业务操作需要。

4.5. state和context区别

LangChain Agent中进行调用时可以传入初始state或者context(上下文),如下:

... ...
config = {"configurable": {"thread_id": "1"}}
agent.invoke(
    {
        "messages": [{"role": "user", "content": "你的问题"}],
        "state_key1": "state_value1",
        "state_key2": "state_value2",
    },
    config=config,
    context={
        "context_key1": "context_value1",
        "context_key2": "context_value2",
    }
)
... ...

以上传递的state内容和context内容都可以在Agent中工具或者中间件中进行获取使用,那么两者有什么区别?

  • state:state用于管理Agent的短期记忆,决定了一个会话线程(thread)中,Agent可以记住哪些自定义的业务数据,其值可以在单次会话的生命周期内,通过工具(Tools)或中间件(如 @before_model, @after_model)进行读取和更新。
  • context:context用于定义传递Agent的静态或半静态的上下文信息,这些信息更像对话发生的“背景板”或“环境变量”,在会话开始时设定,在会话过程中通常保持不变。

两者核心区别总结如下:

对比项 State Context
设计目的 记录在Agent同一个会话(Thread)中,随着对话进行而不断演变的信息。目的是实现多轮对话的连贯性和复杂的业务状态流转。例如:聊天历史、本次购物车商品、当前查询的订单号。 Agent中,提供单次调用或会话的稳定背景信息与环境参数。目的是为工具执行和模型推理提供“场景设定”。例如:用户的地理位置、设备信息、接入渠道(APP/网页)等。
可变性 高,动态变化。在会话中频繁被读取和更新。例如,每次调用工具后更新call_count,或根据用户问题更新current_city。 低,相对静态。通常在一次会话初始化后只读、不频繁变化。作为工具和模型判断的背景依据。
持久化与生命周期 由Checkpointer管理,自动持久化。当创建Agent时指定了checkpointer(如InMemorySaver())和唯一的thread_id,State会在每次invoke后被自动保存。同一thread_id的后续调用会自动加载之前的State,实现记忆延续。 默认不自动持久化,生命周期常与单次Agent调用绑定。如Agent第二次调用时不显式传入context,中间件和工具将无法获取到context,通常需要每次调用时传入。
会话特点 会话隔离,与thread_id绑定,不同会话拥有独立的状态副本。 invoke执行内共享,可以被当前invoke执行中的的所有工具和中间件访问,作为执行的背景信息。

案例:创建Agent,调用Agent时通过上下文Context传入用户姓名和来源,同时传入call_llm_count状态统计调用大模型次数。

该案例中,第一次调用Agent时,可以在工具/中间件中获取到状态和Context上下文信息,但是第二次调用Agent时,没有传入State和Context时,State状态值会从checkpointer中读取,但是Context上下文的值并不由checkpointer保存,所以这里我们在第一次进行Agent对话时,在after_modol中间件中将Context上下文的值获取到后设置到State中,这样后续Agent调用多次时,都可以获取到最开始传入的State和Context的值。

from langchain.agents import AgentState, create_agent
from langchain_core.tools import tool
from langchain.agents.middleware import after_model
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import ToolRuntime
from langgraph.runtime import Runtime
from typing import Any, Dict
from init_llm import deepseek_llm


# 1. 定义 StateSchema,用于存储在会话中动态变化的信息
class ConversationState(AgentState):
    """状态:会话中的动态记忆"""
    user_name: str  # 用户名称
    channel: str  # 渠道名称
    call_llm_count: int  # 大模型调用次数


# 2. 定义一个工具,演示如何在工具内部获取 Context 和 State
@tool
def get_weather(city: str, runtime: ToolRuntime) -> str:
    """
    获取指定城市的天气,返回天气信息。

    Args:
        city (str): 要查询天气的城市名称。
        runtime (ToolRuntime): LangChain运行时对象,用于访问状态和上下文。
                               这是一个“隐藏参数”,LLM不会直接看到,但工具函数可以使用。

    Returns:
        str: 包含天气信息的字符串。
    """
    # 从 ToolRuntime 获取 Context (上下文信息)
    context = runtime.context
    # 从 ToolRuntime 获取 State (状态信息)
    state = runtime.state

    # 从Context中获取用户名称和渠道,如果Context为空,则从State中获取
    if context:
        user_name = context.get("user_name", "未知用户")
        channel = context.get("channel", "未知渠道")
        print(f"[get_weather] 获取到上下文:用户 {user_name} 来自 {channel}")
    else:
        user_name = state.get("user_name", "未知用户")
        channel = state.get("channel", "未知渠道")
        print(f"[get_weather] 获取到状态:用户 {user_name} 来自 {channel}")

    # 从 State 中获取当前调用次数
    current_call_llm_count = state.get("call_llm_count", 0)
    print(f"[get_weather] 获取到状态:当前大模型调用次数 {current_call_llm_count}")

    return f"{city} 的天气晴朗!"


# 3. 定义 @after_model 中间件,在此处获取和修改State
@after_model
def my_middleware(state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
    """
    中间件:在模型调用后执行。
    1. 读取上下文和状态。
    2. 更新状态(如用户名称、渠道、大模型调用次数)。
    """
    # 获取上下文 (例如,用于逻辑判断,但通常不修改)
    context = runtime.context

    # 从上下文获取用户名称和渠道,如果上下文为空,则从状态中获取
    if context:
        user_name = context.get('user_name', "未知用户")
        channel = context.get('channel', "未知渠道")
        print(f"[after_model中间件] 获取到上下文: 用户 {user_name} 来自 {channel}")
    else:
        user_name = state.get('user_name', "未知用户")
        channel = state.get('channel', "未知渠道")
        print(f"[after_model中间件] 获取到状态: 用户 {user_name} 来自 {channel}")

    # 获取并更新状态
    call_llm_count = state.get('call_llm_count', 0)
    print(f"[after_model中间件] 获取到状态: 当前 LLM 调用次数: {call_llm_count}")

    current_llm_call_count = call_llm_count + 1

    # 更新状态,状态中增加用户姓名、渠道
    return {
        "user_name": user_name,
        "channel": channel,
        "call_llm_count": current_llm_call_count
    }


# 4. 创建Agent
agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather],
    middleware=[my_middleware],
    checkpointer=InMemorySaver(),
    # context_schema= xxx # 如果定义状态使用 Pydantic 模型,这里可以指定上下文数据结构
    state_schema=ConversationState  # 指定状态数据结构
)

# 5. 模拟调用
config = {"configurable": {"thread_id": "session_123"}}

# 第一次调用Agent,传入 Context 和初始 State
response1 = agent.invoke(
    {
        "messages": [{"role": "user", "content": "北京今天天气如何?"}],
        # 可以传入状态的初始值,也可以不传入
        "call_llm_count": 0,
    },
    config=config,
    # 传入上下文,这些值在整个会话中通常保持不变
    context={
        "user_name": "张三",
        "channel": "App",
    }
)

print("response1:", response1["messages"][-1].content)
print("当前完整状态:", agent.get_state(config=config))

# 第二次调用,不传入初始状态和上下文
print("=" * 50)
response2 = agent.invoke(
    {"messages": [{"role": "user", "content": "上海呢?"}]},
    config=config,
    # context 每次调用都要传入,如果连续会话中调用agent不传入可以通过记忆中的状态获取,前提是要设置状态
    # context={
    #     "user_name": "李四",
    #     "channel": "Web",
    # }
)
print("response2:", response2["messages"][-1].content)
print("当前完整状态:", agent.get_state(config=config))

以上代码注意点:

  1. 调用Agent传入的Context上下文为了保证后续调用Agent能获取到,可以在工具或者中间件中获取到上下文后设置在状态中。
  2. 定义state_chema时,一定要将状态中的变量指定全,否则可能一些变量无法正常保存在状态中。
  3. Context上下文如果是Pydantic类型,可以在创建Agent时指定context_schema指定类型

4.6. 超出LLM上下文解决方案

一般大模型对话的上下文窗口存在硬性限制(如GPT-4 Turbo的128k tokens),LangChain Agent开启短期记忆后,当对话轮次过多,上下文token总量会超出模型处理能力,从而导致报错。长对话token超出LLM的上下文窗口常见的处理方案有如下四种:消息截断、消息删除、消息摘要、自定义策略,下面分别介绍。

4.6.1. Trim Message-消息截断

Trim Messages(消息截断)是在调用大模型前可以保留最近N条消息,删除旧的消息内容,例如当对话消息数(HumangMessage+AIMessage+ToolMessage)超过一定条数后,只保留最近3条消息。

Agent中消息截断需要通过@before_model中间件实现,@before_model中间件可以在模型调用前预处理Agent的状态(如消息历史),例如修剪过长的消息历史(避免超出 LLM 的上下文窗口)、过滤无关信息等。

@before_model中间件原理是通过标记一个函数为模型前置处理器,在调用LLM前,执行该函数,该函数接收当前状态(state)和运行时上下文(runtime),并返回修改后的状态(如修剪后的消息列表)。

Trim Message(消息截断)方案适用于实时性要求高、早期对话信息价值低的场景,例如:实时聊天机器人。

如下案例中,创建@before_model中间件 trim_messages,用于保留最近的 3 轮消息(避免消息历史过长,超出 LLM 的上下文窗口)。

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.messages import RemoveMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime

from init_llm import deepseek_llm

@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气
    Args:
        city (str): 要查询天气的城市名称

    Returns:
        str: 包含城市天气信息的字符串
    """
    return f"{city}的天气是晴朗的,温度是25摄氏度"

@before_model
def trim_messages(state: AgentState, runtime: Runtime) -> dict:
    # 打印当前消息
    print("当前state:", state)
    print("当前runtime:", runtime)

    messages = state["messages"]
    if len(messages) > 5:
        # 保留最后3条消息
        retain_msg_count = 3

        # 如果倒数第3条消息是工具调用,那么就保留最后2条消息
        if messages[-retain_msg_count].type == "tool":
            retain_msg_count =2

        #删除的消息
        print("删除的消息:", messages[:-retain_msg_count])

        # 删除消息
        return {"messages": [RemoveMessage(id=msg.id) for msg in messages[:-retain_msg_count]]}
    return None


agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather],
    middleware=[trim_messages],
    checkpointer=InMemorySaver()
)

config = {"configurable": {"thread_id": "session_001"}}

# 模拟对话
response1 = agent.invoke({"messages": [{"role": "user", "content": "你好,我是张三"}]}, config=config)
print(response1["messages"][-1].content)
print("***"*20)
response2 = agent.invoke({"messages": [{"role": "user", "content": "今天北京天气好吗?"}]}, config=config)
print(response2["messages"][-1].content)
print("***"*20)
response3 = agent.invoke({"messages": [{"role": "user", "content": "上海天气怎么样?"}]}, config=config)
print(response3["messages"][-1].content)
print("***"*20)
final_response = agent.invoke({"messages": [{"role": "user", "content": "我的名字叫什么?"}]}, config=config)
print(final_response["messages"][-1].content)

该代码运行结果如下:

image.png

以上代码注意如下几点:

  1. @before_model中间件标记的函数在调用大模型前执行。trim_messages方法返回修改后的消息列表,消息格式为“{"messages": [...]}”
  2. 一次对话中可能会多次调用“trim_messages”方法,因为涉及到调用工具,会调用多次大模型,只要调用大模型就会调用“trim_messages”方法。
  3. trim_messages方法中当消息数超过5个时,只保留最后的3条消息,如果最后的3条消息中包含ToolMessage,那么只保留最后的2条消息,因为最后一条消息一定是HumanMessage,倒数第二条消息一定是AIMessage,这样避免保留ToolMessage,避免截断工具调用链,破坏对话逻辑从而报错。
  4. messages[:-retain_msg_count] 表示从当前消息列表 messages中截取除最后 retain_msg_count条之外的所有消息。例如:总消息数为 10 条且 retain_msg_count=3,则messages[:-retain_msg_count] 表示截取前 7 条(messages[0:7])
  5. [RemoveMessage(id=msg.id) for msg in ...]中RemoveMessage是 LangChain提供的特殊消息类型,用于标记需删除的消息。这里是对截取到的每条消息,生成一个 RemoveMessage对象,记录需要删除的消息ID。

4.6.2. Delete Message-消息删除

Delete Message(消息删除)是永久性地从状态中移除特定消息,适用于敏感信息清理(如医疗咨询后清除记录)或精确记忆管理。

Delete Message(消息删除)需要在与AI大模型对话完后决定是否删除一些消息,这里需要用到@after_model中间件,该中间件标记一个函数为模型后置处理器,该函数接收当前状态(state)和运行时上下文(runtime),并返回修改后的状态(如过滤后的消息列表)。

如下案例构建的Agent,只要用户提出“删除历史聊天记录”就会自动清空该用户所有聊天历史记录。实现原理:用户进行对话时传入自定义状态“delete_history:False”,当后续用户对话中输入“删除历史聊天记录”后,调用对应工具来更新该自定义状态,然后通过@after_modle中间件来获取该自定义状态的值从而决定是否清空历史聊天记录。

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import after_model
from langchain.messages import RemoveMessage
from langchain_core.messages import ToolMessage, AIMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.prebuilt import ToolRuntime
from langgraph.runtime import Runtime
from langgraph.types import Command

from init_llm import deepseek_llm

@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气
    Args:
        city (str): 要查询天气的城市名称

    Returns:
        str: 包含城市天气信息的字符串
    """
    return f"{city}的天气是晴朗的,温度是25摄氏度"

@tool
def update_delete_history_state(runtime:ToolRuntime,delete_history:bool) ->Command:
    """是否清空聊天历史记录
    Args:
        delete_history (bool): 是否清空聊天历史记录

    Returns:
        Command: 包含更新状态的命令
    """
    # 准备更新内容
    updates = {
        "delete_history": delete_history,
        "messages": [
            ToolMessage(
                content=f"已更新删除聊天历史记录状态:{delete_history}",
                tool_call_id=runtime.tool_call_id
            )
        ]
    }

    return Command(update=updates)


@after_model
def delete_messages(state: AgentState, runtime: Runtime) -> dict:
    print("当前state:", state)

    # 获取删除聊天历史记录状态,决定是否清空聊天历史记录,只在最后一条消息是AIMessage时清空
    delete_history = state.get("delete_history")
    if delete_history:
        # return {
        #   "delete_history": False,
        #   "messages": [RemoveMessage(id=m.id) for m in state["messages"],AIMessage(content="已清空聊天历史记录")]
        # }
        return {
            "delete_history": False,
            "messages":[RemoveMessage(id=REMOVE_ALL_MESSAGES),
                        AIMessage(content="聊天历史记录已经成功删除,现在我们的对话将从新的状态开始,有什么其他我可以帮助你的吗?")]
            }
    return None

class CustomState(AgentState):
    delete_history: bool

agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather,update_delete_history_state],
    middleware=[delete_messages],
    checkpointer=InMemorySaver(),
    state_schema=CustomState
)

config = {"configurable": {"thread_id": "session_001"}}

# 模拟对话
response1 = agent.invoke({
    "messages": [{"role": "user", "content": "你好,我是张三"}],
    "delete_history": False,
}, config=config)

print(response1["messages"][-1].content)

print("***"*20)
response2 = agent.invoke({"messages": [{"role": "user", "content": "今天北京天气好吗?"}]}, config=config)
print(response2["messages"][-1].content)

print("***"*20)
response3 = agent.invoke({"messages": [{"role": "user", "content": "我的名字叫什么?"}]}, config=config)
print(response3["messages"][-1].content)

print("***"*20)
response4 = agent.invoke({"messages": [{"role": "user", "content": "请给我删除聊天历史记录"}]}, config=config)
print(response4["messages"][-1].content)

print("***"*20)
final_response = agent.invoke({"messages": [{"role": "user", "content": "我的名字叫什么?"}]}, config=config)
print(final_response["messages"][-1].content)

以上代码运行结果如下:

image.png

以上代码注意如下几点:

  1. 只要调用完大模型后就会执行@after_model中间件对应的函数delete_messages。
  2. 清空历史聊天记录时,“{"messages":[RemoveMessage(id=REMOVE_ALL_MESSAGES)]}”效果等效于“{"messages": [RemoveMessage(id=m.id) for m in state["messages"]]}”,都是清空所有消息。“REMOVE_ALL_MESSAGES”是 LangGraph 提供的特殊消息标识符,用于指示删除所有历史消息。其底层实现会触发状态管理模块的清空逻辑,而非逐条删除,彻底清空所有消息时,推荐使用 REMOVE_ALL_MESSAGES,效率高。
  3. after_model中删除所有消息后,为了避免回复错误,最后添加一条AIMessage回复,同时设置“delete_history”为False。

4.6.3. Summarize Message-消息摘要

消息截断和消息删除解决长对话token超出LLM上下文方式会大大减少对话过程中token数量,但会导致历史消息丢失,那有没有一种方式既能减少对话中的token消耗,又尽可能不丢失历史关键消息?Summarize Message(消息摘要)的出现正是为了在“保留历史信息”与“控制 token 消耗”之间找到平衡,消息摘要方式通过 LLM 生成对话历史的浓缩摘要,替代原始的旧消息,既减少了 token 占用,又保留了核心上下文。

Summarize Message(消息摘要)是通过LangChain中的内置SummarizationMiddleware中间件进行完成,该中间件可以自动压缩对话历史,通过生成摘要替代旧消息,在保留关键上下文的同时减少 token 消耗,使用方式如下:

agent = create_agent(
    model="gpt-4.1",
    tools=[...],
    middleware=[
        SummarizationMiddleware(
            model="gpt-4.1-mini",  # 摘要生成模型
            trigger=("tokens", 4000),  # 触发条件:token ≥4000
            keep=("messages", 20)  # 保留最近20条消息
        ),
    ],
)

SummarizationMiddleware支持的参数解释如下,具体可以参考:https://docs.langchain.com/oss/python/langchain/middleware/built-in#summarization

参数 参数解释
model 必填项,生成摘要的LLM(如:gpt-4.1-mini)
trigger 必填项,触发摘要条件(支持tokens、messages、fraction等)
keep 保留消息的策略。默认值“(“messages”,20)”
token_counter 自定义token 计数函数(如统计中英文混合文本的 token)。默认为字符计数。
summary_prompt 自定义摘要提示词模板(需包含{messages}占位符)。
trim_tokens_to_summarize 生成摘要时,摘要中允许的最大token 数(超出部分会被截断),默认4000。

案例:通过内置SummarizationMiddleware来进行消息摘要,每当消息数达到5条时,保留最后2条消息,并对之前消息进行摘要。

该案例中使用@before_model/@after_model在调用模型前/后打印对应的消息信息,方便观察摘要处理结果。

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import after_model, SummarizationMiddleware, before_model
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime

from init_llm import deepseek_llm

@before_model
def print_before_model_state(state: AgentState, runtime: Runtime) -> dict|None:
    # 打印当前消息
    print("before_model_state:", state)

    messages = state["messages"]

    return {"messages": messages}

@after_model
def print_after_model_state(state: AgentState, runtime: Runtime) -> dict|None:
    # 打印当前消息
    print("after_model_state:", state)

    messages = state["messages"]

    return {"messages": messages}

@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气
    Args:
        city (str): 要查询天气的城市名称

    Returns:
        str: 包含城市天气信息的字符串
    """
    return f"{city}的天气是晴朗的,温度是25摄氏度"


agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather],
    middleware=[
        print_before_model_state,
        print_after_model_state,
        SummarizationMiddleware(
            model=deepseek_llm,
            trigger=('messages',5), # 当消息数量超过5条时触发总结
            keep=('messages', 2),   # 保留最后2条消息
            summary_prompt="请总结以下对话内容:{messages}"
        )
    ],
    checkpointer=InMemorySaver()
)

config = {"configurable": {"thread_id": "session_001"}}

# 模拟对话
response1 = agent.invoke({"messages": [{"role": "user", "content": "你好,我是张三"}],}, config=config)
print(response1["messages"][-1].content)

print("***"*20)
response2 = agent.invoke({"messages": [{"role": "user", "content": "今天北京天气好吗?"}]}, config=config)
print(response2["messages"][-1].content)

print("***"*20)
response3 = agent.invoke({"messages": [{"role": "user", "content": "我的名字叫什么?"}]}, config=config)
print(response3["messages"][-1].content)

以上代码运行结果如下:

第一次对话输出结果:历史消息数没有达到5条,不进行摘要总结。

image.png

第二次对话输出结果:涉及到调用工具,会多次与大模型交互,会输出多次“before_model_state”/“after_model_state”,到消息达到5条时,保留最后2条消息,对之前消息进行摘要总结。

image.png

image.png

第三次对话输出结果:消息达到5条时,保留最后2条消息,对之前消息进行摘要总结。

image.png

以上代码注意如下几点:

  1. SummarizationMiddleware会在模型调用后回复前进行总结,如果保留N条最新message,AI不会对最近这N条数据进行总结,只是对该N条之前的消息进行总结摘要。
  2. @after_model和@before_model中间件对应方法中最终可以返回None,表示不修改当前状态,继续执行后续流程。也可以返回“return {"messages": messages}”表示将messages内容追加到现有消息列表末尾,追加消息时会根据每个消息的id进行合并,即返回“return {"messages": messages}”也相当于什么都没有改变。

4.6.4. Custom Strategies-自定义策略

Custom Strategies 自定义策略可以针对特定业务需求来解决对话超出LLM上下文长度问题,该方式可以解决标准方案(如消息截断、删除、摘要)无法覆盖的复杂场景(如一些场景中对话历史中包含文本、图片、文件多种信息时,需要根据需要动态保留哪些内容)。

Custom Strategies 自定义策略可以通过中间件@before_model/@after_model/@wrap_model_call 来实现,三种中间件都可以对对话历史消息messages进行处理修改,从而达到自定义保留消息效果。

如下案例中,通过@after_model中间件实现自定义摘要策略,保留最近2条消息,并对最近2条消息之前的所有消息进行摘要,摘要时使用“通义千问”模型。特别注意的是保留的最近2条消息中需要保证完整工具链调用,避免工具调用链截断导致逻辑错误。

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import after_model, before_model
from langchain_core.messages import RemoveMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from langchain_core.messages import SystemMessage

from init_llm import deepseek_llm, tongyi_llm


@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气
    Args:
        city (str): 要查询天气的城市名称

    Returns:
        str: 包含城市天气信息的字符串
    """
    return f"{city}的天气是晴朗的,温度是25摄氏度"


@before_model
def print_before_model_state(state: AgentState, runtime: Runtime) -> dict | None:
    # 打印当前消息
    print("before_model_state:", state)

    messages = state["messages"]

    return {"messages": messages}


@after_model
def custom_summarizer(state: AgentState, runtime) -> dict | None:
    """自定义摘要逻辑,保留最近N条消息,并对N条消息前的所有消息进行摘要
        最后保留的最近N条消息中要保证包含完整调用工具链
    """
    # 打印当前消息
    print("after_model_state:", state)

    messages = state["messages"]
    # 触发摘要的阈值
    threshold = 5
    # 保留最近N条消息
    max_retain = 2

    # 如果 messages 条数达到触发摘要的阈值,那么就进行摘要处理,否则直接返回 None
    if len(messages) <= threshold:
        return None

    # 最近保留N条消息的集合
    recent_messages = messages[-max_retain:]

    # 遍历最近保留N条消息,判断第1条消息是否是 ToolMessage, 如果是那么多往前保留1条消息, 直到第1条消息不是 ToolMessage 为止
    while True:
        # 获取最近保留消息中的第1条消息,判断是不是ToolMessage
        if isinstance(recent_messages[0], ToolMessage):
            # 如果是那么多往前保留1条消息
            max_retain += 1
            recent_messages = messages[-max_retain:]
        else:
            # 如果第一条消息不是 ToolMessage 那么就跳出循环
            break

    early_messages = messages[:-max_retain]  # 保留最近N条消息前的所有消息


    # 准备摘要提示
    summary_prompt = f"""
        请将以下对话内容总结成一段简洁的摘要,保留重要信息和细节:
  
        对话历史:
        {"".join([f"{msg.type}: {msg.content}" for msg in early_messages])}
  
        摘要要求:
        1. 保留人物、地点、关键事件等重要信息
        2. 保持第三人称叙述
        3. 长度不超过200字
        4. 使用中文总结
  
        """

    try:
        # 调用模型生成摘要
        summary_response = tongyi_llm.invoke(summary_prompt)
        summary_content = f"对话摘要: {summary_response.content}"

        # 创建摘要消息
        summary_message = SystemMessage(content=summary_content)

        # 组合新消息列表:摘要 + 最近消息
        new_messages = [summary_message] + recent_messages

        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages
            ]
        }

    except Exception as e:
        print(f"摘要生成失败: {e}")
        return None


agent = create_agent(
    model=deepseek_llm,
    tools=[get_weather],
    middleware=[print_before_model_state, custom_summarizer],
    checkpointer=InMemorySaver(),
)

config = {"configurable": {"thread_id": "session_001"}}

# 模拟对话
response1 = agent.invoke({"messages": [{"role": "user", "content": "你好,我是张三"}], }, config=config)
print(response1["messages"][-1].content)

print("***" * 20)
response2 = agent.invoke({"messages": [{"role": "user", "content": "今天北京天气好吗?"}]}, config=config)
print(response2["messages"][-1].content)

print("***" * 20)
response3 = agent.invoke({"messages": [{"role": "user", "content": "我的名字叫什么?"}]}, config=config)
print(response3["messages"][-1].content)

以上代码运行结果如下:

image.png

以上代码注意点如下:

  1. 在@after_model中间件中,保留最近N条消息时,特别要注意保留完整的“AIMessage→ToolMessage...→AIMessage”工具链调用,避免截断导致逻辑断裂。例如,将“ToolMessage→AIMessage”传递给大模型时,由于工具链逻辑断裂,会导致报错。
  2. 对话摘要组织成SystemMessage传递给大模型使用,也可以组织成HumanMessage。
  3. @after_model中间件方法最终返回最终的messages消息中,REMOVE_ALL_MESSAGES表示清空整个对话历史,*new_messages展开为新的消息列表,即最终返回了“摘要消息+保留的最近消息”。

5. Agent长期记忆

本章节中,我们重点介绍Agent长期记忆,短期记忆使Agent能够在单次会话中维持对话的连贯性,而长期记忆则赋予了Agent跨会话学习和积累知识的能力。这样Agent能够构建持久的用户画像,形成经验库,并持续优化自身行为,从而提供真正个性化、智能化的服务。

5.1. 长期记忆介绍

长期记忆是一种用于存储用户特定信息或应用级数据的系统,其核心特点是跨越会话和线程共享。与短期记忆局限于单一线程(thread_id)不同,长期记忆中的数据可以被任何时间、在任何线程中召回。其存储范围(作用域)被定义在自定义的命名空间中,而非单个线程ID内。

image.png

在LangChain框架中,长期记忆通过存储(Store)组件来实现,Store允许你将记忆保存为JSON文档,并通过命名空间(namespace)和键(Key)进行组织和管理,便于后续的检索、更新与删除。

  • 命名空间(namespace):类似于文件夹,用于对记忆进行逻辑分组(例如,按应用场景划分)。
  • 键(key):命名空间内每个文档的唯一标识符。

Store存储支持基本的put(写入)、get(读取)、delete(删除)、search(搜索)操作,如下案例中演示基于内存store存储的基本操作。

from langgraph.store.memory import InMemoryStore

# 初始化一个内存存储store
store = InMemoryStore()

# 定义命名空间:通常包含用户ID和上下文
namespace = ("user1", "preferences")

# 写入一个记忆
store.put(
    namespace,
    "fruit", # 键
    {"likes": ["苹果", "香蕉"], "dislikes": ["橙子"]} # 值 (JSON文档)
)

print("=== 读取 fruit 记忆 ===")
# 读取记忆
memory = store.get(namespace, "fruit")
print(memory)

# 写入另一个记忆
store.put(
    namespace,
    "chat", # 键
    {"language": "中文", "emotion": "高兴"} # 值 (JSON文档)
)
print("=== 读取 chat 记忆 ===")
# 读取记忆
memory = store.get(namespace, "chat")
print(memory)

print("=== 搜索所有记忆 ===")
# 搜索所有记忆
memories = store.search(namespace)
print(memories)

# 更新记忆
store.put(
    namespace,
    "fruit", # 键
    {"likes": ["橘子", "葡萄"], "dislikes": ["草莓"]} # 值 (JSON文档)
)

print("=== 读取更新后的 fruit 记忆 ===")
# 读取更新后的记忆
memory = store.get(namespace, "fruit")
print(memory)

以上代码运行结果如下:

image.png

以上代码注意点如下:

  1. 命名空间(namespace)格式为元组格式,建议包含用户ID和上下文特征,这样的层级化命名空间可以更好隔离数据、避免命名空间冲突。
  2. 键(key):在命名空间中唯一的一个key,store中put对应key的值后,后续再put相同的key的值,会覆盖当前key的值。
  3. 键(key)对应的value的值是一个字典结构,该字典中 k,v 内容由用户定义。
  4. store.put 方法可以传入命名空间(元组类型)、key(字符串类型)、value(字典类型)三个参数。
  5. store.get方法可以指定namespace和key参数获取对应的value值;store.search方法可以指定namespace参数获取该命名空间中所有的value。

5.2. Agent中使用长期记忆方式

长期记忆的核心作用是实现跨会话、跨节点的数据共享。在 LangChain 中,单个 Agent 通过将 Store注入到不同会话(thread_id)中,实现跨会话状态持久化。Agent中设置长期记忆的方式和短期记忆类似,长期记忆可以存在内存或者数据库中,下面重点介绍在Agent中如何使用长期记忆。

此外,长期记忆更多使用在 LangGraph 构建的工作流中,流中不同节点通过共享的 Store访问同一份长期记忆。关于LangGraph中长期记忆在后续进行介绍。

5.2.1. 使用内存存储长期记忆

在测试环境中,通常使用内存型的InMemoryStore对象,这种方式简单易用但程序重启后数据会丢失。

如下代码中使用内存存储长期记忆,该案例中创建好InMemoryStore后,首先向该对象中存储一些数据作为长期记忆,后续在工具中获取上下文Context中用户信息,根据不同用户查询不同的长期记忆信息,实现在相同/不同的thread_id中读取相同长期记忆内容。

from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from langchain.tools import tool, ToolRuntime

from pydantic import BaseModel

from init_llm import deepseek_llm


class UserContext(BaseModel):
    user_id: str

# 1. 初始化存储并预先存入一些长期记忆
store = InMemoryStore()
checkpointer = InMemorySaver()

# 预先在长期记忆中存储一些用户信息
store.put(
    ("users",),  # 命名空间:用户数据
    "user_123",  # 键:用户ID
    {"name": "张三", "age": 28, "city": "北京", "hobby": "编程、阅读"}  # 值:用户信息
)
store.put(
    ("users",),
    "user_456",
    {"name": "李四", "age": 32, "city": "上海", "hobby": "旅游、摄影"}
)


# 2. 定义读取长期记忆的工具
@tool
def get_user_info(runtime: ToolRuntime[UserContext]) -> str:
    """
    从长期记忆中获取当前用户的信息。

    Returns:
        str: 用户的详细信息
    """
    # 从runtime中获取store和context
    store = runtime.store
    user_id = runtime.context.user_id

    # 从长期记忆中读取用户信息
    user_data = store.get(("users",), user_id)

    if user_data:
        # 将字典格式化为字符串
        info = user_data.value
        return f"用户信息:姓名-{info['name']}, 年龄-{info['age']}, 城市-{info['city']}, 爱好-{info['hobby']}"
    else:
        return "未找到该用户的信息。"


# 3. 创建带有长期记忆读取工具的Agent
agent = create_agent(
    model=deepseek_llm,
    tools=[get_user_info],
    checkpointer=checkpointer,
    store=store,
    context_schema=UserContext
)

# 4. 演示:在不同线程中读取相同的长期记忆
print("=== 长期记忆的跨线程读取 ===")

# 线程1:用户123询问自己的信息
print("线程1 - 用户123询问信息:")
result1 = agent.invoke(
    {"messages": [{"role": "user", "content": "我的个人信息是什么?"}]},
    config={"configurable": {"thread_id": "thread_1"}},
    context=UserContext(user_id="user_123")  # 用户ID决定访问哪个长期记忆
)
print(f"Agent回复: {result1['messages'][-1].content}\n")

# 线程2:另一个线程,相同的用户
print("线程2 - 相同的用户,不同的线程:")
result2 = agent.invoke(
    {"messages": [{"role": "user", "content": "再告诉我一次我的信息"}]},
    config={"configurable": {"thread_id": "thread_2"}},  # 线程ID不同
    context=UserContext(user_id="user_123")  # 但用户ID相同
)
print(f"Agent回复: {result2['messages'][-1].content}\n")

# 线程3:另一个用户
print("线程3 - 用户456询问信息:")
result3 = agent.invoke(
    {"messages": [{"role": "user", "content": "我的信息是什么?"}]},
    config={"configurable": {"thread_id": "thread_3"}},
    context={"user_id": "user_456"}  # 用户ID不同,访问不同的长期记忆
)
print(f"Agent回复: {result3['messages'][-1].content}")

以上代码运行结果如下:

image.png

以上代码注意:Agent创建中通过Store参数指定长期记忆对象,使用长期记忆时,可以同时使用短期记忆(checkpointer指定),也可以不使用短期记忆,但一般两者会同时使用。

5.2.2. 使用数据库存储长期记忆

在生产环境中,推荐使用数据库支持的Store,以确保数据的持久化和多实例部署的支持。

如下案例中使用mysql数据库来存储长期记忆,完成此案例需要提前安装好mysql数据库(默认已经安装mysql8),并且进行数据库创建和安装必要python依赖,具体如下:

1) 在mysql中创建langchain_db数据库

#进入mysql navicate客户端,创建mysql数据库langchain_db
create database langchain_db;

2) 在当前python环境中安装如下依赖

#安装必要依赖
conda activate langchain_v1.2
python -m pip install langgraph-checkpoint-mysql==3.0.0 pymysql==1.1.2 cryptography==46.0.3
python -m pip install aiomysql==0.3.2 asyncmy==0.2.11

使用mysql数据库存储长期记忆代码如下:

from langchain.agents import create_agent
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langchain.tools import tool, ToolRuntime
from langgraph.store.mysql import PyMySQLStore
from pydantic import BaseModel
from init_llm import deepseek_llm


class UserContext(BaseModel):
    user_id: str

# 1. 定义读取长期记忆的工具
@tool
def get_user_info(runtime: ToolRuntime[UserContext]) -> str:
    """
    从长期记忆中获取当前用户的信息。

    Returns:
        str: 用户的详细信息
    """
    # 从runtime中获取store和context
    store = runtime.store
    user_id = runtime.context.user_id

    # 从长期记忆中读取用户信息
    user_data = store.get(("users",), user_id)

    if user_data:
        # 将字典格式化为字符串
        info = user_data.value
        return f"用户信息:姓名-{info['name']}, 年龄-{info['age']}, 城市-{info['city']}, 爱好-{info['hobby']}"
    else:
        return "未找到该用户的信息。"

# 配置 MySQL 连接
DB_URI = "mysql+pymysql://root:123456@localhost:3306/langchain_db?charset=utf8mb4"

with (
    PyMySQLSaver.from_conn_string(DB_URI) as checkpointer,
    PyMySQLStore.from_conn_string(DB_URI) as store
):
    # 自动创建checkpointer数据库表(首次运行)
    checkpointer.setup()
    # 2. 自动创建store数据库表(首次运行)
    store.setup()

    # 3. 预先在长期记忆中存储一些用户信息
    store.put(
        ("users",),  # 命名空间:用户数据
        "user_123",  # 键:用户ID
        {"name": "张三", "age": 28, "city": "北京", "hobby": "编程、阅读"},  # 值:用户信息
    )
    store.put(
        ("users",),
        "user_456",
        {"name": "李四", "age": 32, "city": "上海", "hobby": "旅游、摄影"},
    )

    # 4. 创建带有长期记忆读取工具的Agent
    agent = create_agent(
        model=deepseek_llm,
        tools=[get_user_info],
        checkpointer=checkpointer,
        store=store,
        context_schema=UserContext
    )

    # 5. 演示:在不同线程中读取相同的长期记忆
    print("=== 长期记忆的跨线程读取 ===")

    # 线程1:用户123询问自己的信息
    print("线程1 - 用户123询问信息:")
    result1 = agent.invoke(
        {"messages": [{"role": "user", "content": "我的个人信息是什么?"}]},
        config={"configurable": {"thread_id": "thread_1"}},
        context=UserContext(user_id="user_123")  # 用户ID决定访问哪个长期记忆
    )
    print(f"Agent回复: {result1['messages'][-1].content}\n")

    # 线程2:另一个线程,相同的用户
    print("线程2 - 相同的用户,不同的线程:")
    result2 = agent.invoke(
        {"messages": [{"role": "user", "content": "再告诉我一次我的信息"}]},
        config={"configurable": {"thread_id": "thread_2"}},  # 线程ID不同
        context=UserContext(user_id="user_123")  # 但用户ID相同
    )
    print(f"Agent回复: {result2['messages'][-1].content}\n")

    # 线程3:另一个用户
    print("线程3 - 用户456询问信息:")
    result3 = agent.invoke(
        {"messages": [{"role": "user", "content": "我的信息是什么?"}]},
        config={"configurable": {"thread_id": "thread_3"}},
        context=UserContext(user_id="user_456")  # 用户ID不同,访问不同的长期记忆
    )
    print(f"Agent回复: {result3['messages'][-1].content}")

以上代码运行结果如下:

image.png

运行代码后进入到Mysql数据库中可以看到对应的数据库表和数据:

image.png

以上代码需要注意如下几点:

  1. “with PyMySQLStore.from_conn_string(DB_URI) as store:”通过连接字符串DB_URI创建与MySQL数据库的持久化连接,用于管理长期记忆的存储。
  2. 使用mysql存储长期记忆需要提前在数据库中创建对应的数据库,然后代码首次运行执行“store.setup()”(首次运行需要,首次运行后可以不再执行该代码)会自动在该数据库中创建对应数据库表。
  3. 也可以使用其他数据库进行短期记忆的持久化存储,例如使用postgresql存储,需要安装“pip install langgraph-checkpoint-postgres==3.0.4”,具体代码参考:https://docs.langchain.com/oss/python/langchain/short-term-memory#in-production
  4. 持久化存储支持的数据库可以通过“https://pypi.org/search/?o=&q=langgraph-checkpoint&page=2”查看,搜索“langgraph-checkpoint-*”查看对应需要安装的依赖和使用方式。

image.png

5.3. 工具中实现长期记忆读写

长期记忆的核心使用场景是在工具中进行读写操作。LangChain通过ToolRuntime将store和context注入到工具函数中,使工具能够访问和修改长期记忆。如下案例中通过工具使用长期记忆记录用户偏好。

from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from langchain.tools import tool, ToolRuntime
from typing import TypedDict, Literal

from pydantic import BaseModel, Field

from init_llm import deepseek_llm
import uuid


class UserContext(BaseModel):
    user_id: str

# 定义工具输入的类型
class UserPreference(BaseModel):
    category: Literal["color", "food", "music"] = Field(description=" 用户偏好类别,必须是 'color', 'food', 'music' 中的一个")
    preference: str = Field(description="具体偏好内容,如'红色'、'中国美食'等")


# 1. 初始化存储
store = InMemoryStore()
checkpointer = InMemorySaver()


# 2. 定义写入长期记忆的工具
@tool(args_schema=UserPreference)
def save_user_preference(category: str, preference: str, runtime: ToolRuntime) -> str:
    """
    将用户偏好保存到长期记忆中。

    Args:
        category: 用户偏好类别,必须是 "color", "food", "music" 中的一个
        preference: 具体偏好内容,如'红色'、'中国美食'等
        runtime: ToolRuntime  # 包含长期记忆存储和上下文
    Returns:
        str: 操作结果描述
    """
    user_id = runtime.context.user_id

    # 创建命名空间:(user_id, "preferences")
    namespace = (user_id, "preferences")

    # 生成唯一记忆ID
    memory_id = str(uuid.uuid4())

    # 准备要保存的数据
    memory_value = {
        "category": category,
        "preference": preference,
    }

    # 保存到长期记忆
    runtime.store.put(namespace, memory_id, memory_value)

    return f"已成功保存你的{category}偏好:{preference}"


# 3. 定义读取用户偏好的工具
@tool
def get_user_preferences(runtime: ToolRuntime) -> str:
    """
    从长期记忆中获取用户特定类别的所有偏好。

    Returns:
        str: 用户的偏好列表
    """
    user_id = runtime.context.user_id
    namespace = (user_id, "preferences")

    # 搜索该命名空间下的所有记忆
    memories = runtime.store.search(namespace)

    if not memories:
        return f"您还没有保存过偏好"

    print("memories:", memories)

    # 格式化所有偏好为字符串列表
    preferences_list = []
    for mem in memories:
        pref = mem.value
        preferences_list.append(f"- 种类:{pref['category']},偏好:{pref['preference']}")

    return f"你的偏好有:\n" + "\n".join(preferences_list)


# 4. 创建带有长期记忆读写工具的Agent
memory_agent = create_agent(
    model=deepseek_llm,
    tools=[save_user_preference, get_user_preferences],
    checkpointer=checkpointer,
    store=store,
    context_schema=UserContext
)

# 5. 演示:完整的长时期记忆读写流程
print("=== 完整演示:长期记忆的写入与跨线程读取 ===")

# 第一轮:用户保存颜色偏好(线程1)
print("第一轮(线程1):用户保存颜色偏好")
result1 = memory_agent.invoke(
    {"messages": [{"role": "user", "content": "请记住我喜欢的颜色是蓝色"}]},
    config={"configurable": {"thread_id": "thread1"}},
    context=UserContext(user_id="current_user")
)
print(f"Agent回复: {result1['messages'][-1].content}")

# 第二轮:用户保存食物偏好(同一线程)
print("第二轮(同一线程):用户保存食物偏好")
result2 = memory_agent.invoke(
    {"messages": [{"role": "user", "content": "我还喜欢的食物是意大利面"}]},
    config={"configurable": {"thread_id": "thread1"}},  # 同一线程
    context=UserContext(user_id="current_user")
)
print(f"Agent回复: {result2['messages'][-1].content}")

# 第三轮:在新线程中查询所有偏好
print("第三轮(新线程):查询我的所有偏好")
result3 = memory_agent.invoke(
    {"messages": [{"role": "user", "content": "告诉我我都喜欢什么颜色和食物"}]},
    config={"configurable": {"thread_id": "thread2"}},  # 新线程
    context=UserContext(user_id="current_user")  # 相同用户
)
print(f"Agent回复: {result3['messages'][-1].content}")


# 直接验证:从store中读取数据
print("=== 直接验证:从长期记忆存储中读取数据 ===")
# 读取颜色偏好
color_memories = store.search(("current_user", "preferences"))
print(f"长期记忆中存储的颜色偏好: {[m.value for m in color_memories if m.value['category'] == 'color']}")

# 读取食物偏好
food_memories = store.search(("current_user", "preferences"))
print(f"长期记忆中存储的食物偏好: {[m.value for m in food_memories if m.value['category'] == 'food']}")

代码运行结果如下:

image.png

5.4. 短期记忆和长期记忆区别总结

如下是短期记忆和长期记忆区别:

对比维度 短期记忆(Short-term Memory) 长期记忆(Long-term Memory)
作用域 线程/会话范围记忆 (Thread-scoped)。与单个会话线程 (thread_id) 绑定。 跨线程/会话记忆。存储在自定义命名空间 (namespace) 中,可被多个线程共享。
核心目的 保证单次对话的连贯性和上下文感知。 实现跨对话的个性化、知识积累和持续学习。
主要存储内容 对话的原始历史 (messages列表) 以及当前会话的状态数据。 从交互中提炼的结构化知识(如用户事实、行为经验、优化规则)。
管理组件 检查点 (Checkpointer),如 InMemorySaver、PyMySQLSaver。 存储 (Store),如 InMemoryStore、PostgresStore。
生命周期 随线程的创建而开始,随线程的销毁(或超时清理)而结束。 独立于任何特定线程,除非被显式删除,否则永久或长期存在。
访问方式 自动管理。Agent的状态在每个步骤后自动持久化到检查点,并在下次恢复。 手动控制。必须在工具(Tool) 或自定义逻辑中,通过代码显式地调用 store.put()或 store.get()。
典型应用场景 维持聊天上下文,让Agent记得用户在当前对话中刚说过的话。 记住用户的身份、偏好、历史行为。

5.5. 记忆综合案例-电商客服助手

如下案例是一个智能电商客服助手案例,该助手具备如下能力:

  • 记住当前会话状态(短期记忆):用户正在查询的订单号。
  • 了解用户长期偏好(长期记忆):用户偏好的商品类型和商品名称。
  • 处理多轮复杂对话:通过消息摘要管理长对话上下文。
  • 工具调用:查询用户信息、查询订单信息、更新用户偏好、基于用户偏好给用户推荐商品。
  • 工具调用错误捕获:当工具调用错误时,通过中间件进行错误捕获,返回友好提示。

该案例中长期和短期记忆使用MySQL进行存储,需要在MySQL中创建对应的库:

#进入mysql navicate客户端,创建mysql数据库langchain_db
drop database langchain_db;
create database langchain_db;

在项目中安装必要依赖:

#安装必要依赖
conda activate langchain_v1.2
python -m pip install langgraph-checkpoint-mysql==3.0.0 pymysql==1.1.2 cryptography==46.0.3
python - m pip install aiomysql==0.3.2 asyncmy==0.2.11

案例代码如下:

"""
智能电商客服助手
功能:结合短期记忆、长期记忆、消息摘要,提供流式、个性化的客服服务。
前置准备:
1. 安装依赖: pip install langchain langgraph langgraph-checkpoint-mysql pymysql
2. 创建MySQL数据库: CREATE DATABASE langchain_memory_db;
"""
import uuid
import warnings
from typing import List, Optional
from pydantic import BaseModel, Field
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import SummarizationMiddleware, wrap_tool_call
from langchain_core.tools import tool
from langchain_core.messages import ToolMessage
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langgraph.store.mysql.pymysql import PyMySQLStore
from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command
from init_llm import deepseek_llm

# 禁用Pydantic序列化警告
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")

# ========== 1. 定义Context 上下文 Schema==========
class UserContext(BaseModel):
    """定义调用Agent时传入的静态上下文信息"""
    user_id: str = Field(description="用户的唯一标识符")
    channel: str = Field(description="用户咨询渠道,如: APP, Web, 小程序")


# ========== 2. 定义自定义短期记忆状态 (继承AgentState) ==========
class CustomerSessionState(AgentState):
    """自定义短期记忆状态,用于管理单次会话中的动态信息"""
    current_order_id: str  # 用户当前正在查询的订单号


# ========== 3. 模拟订单数据库 数据 ==========
MOCK_DATABASE = {
    "orders": {
        "order001": {"order_id": "order001", "status": "已发货", "product": "智能手机",
                     "preference_context": "华为手机P70"},
        "order002": {"order_id": "order002", "status": "待支付", "product": "智能手表",
                     "preference_context": "Apple Watch Series 8"},
    }

}


# ========== 4. 定义工具 ==========
@tool
def get_user_info(runtime: ToolRuntime) -> str:
    """
    获取用户当前用户信息
    Args:
        runtime (ToolRuntime): 包含上下文信息的运行时环境
    Returns:
        str: 用户当前用户信息
    """
    print("get_user_info 中 runtime:", runtime)
    # 从上下文中获取当前用户ID
    current_user_id = runtime.context.user_id

    # 从上下文中获取用户咨询渠道
    user_channel = runtime.context.channel

    # 从状态中获取用户当前正在查询的订单号
    state = runtime.state
    if "current_order_id" in state:
        current_order_id = state["current_order_id"]
    else:
        current_order_id = "无"

    # 获取当前用户信息
    return f"用户ID: {current_user_id}, 咨询渠道: {user_channel}, 当前查询订单号: {current_order_id}"


@tool
def query_order_status(order_id: str, runtime: ToolRuntime) -> Command:
    """
    查询用户订单状态
    Args:
        order_id (str): 用户订单号
        runtime (ToolRuntime): 包含上下文信息的运行时环境
    Returns:
        Command: 包含更新操作的命令对象:状态中更新当前订单ID,并返回订单信息(状态、商品、用户偏好)
    """

    # 查询订单状态
    order_info = MOCK_DATABASE["orders"].get(order_id)

    if not order_info:
        return Command(
            update={
                "messages": [
                    ToolMessage(
                        content=f"错误:订单 [{order_id}] 不存在",
                        tool_call_id=runtime.tool_call_id
                    )
                ]
            }
        )

    updates = {
        "current_order_id": order_id,
        "messages": [
            ToolMessage(
                content=f"订单 [{order_id}] 状态: {order_info['status']}, 商品: {order_info['product']}。"
                        f"需要进行用户偏好更新,用户偏好: {order_info['preference_context']}",
                tool_call_id=runtime.tool_call_id
            )
        ]
    }

    return Command(update=updates)


@tool
def update_user_preference(category: str, liked_item: str, runtime: ToolRuntime) -> str:
    """
    更新用户长期偏好
    Args:
        category (str): 商品类别,如: 手机、配件
        liked_item (str): 用户喜欢的具体商品
        runtime (ToolRuntime): 包含上下文信息的运行时环境
    Returns:
        str: 确认更新结果
    """
    user_id = runtime.context.user_id
    namespace = (f"user_{user_id}", "preferences")

    key = str(uuid.uuid4())

    value_to_store = {
        "category": category,
        "liked_item": liked_item,
    }

    # 写入到长期记忆
    runtime.store.put(namespace, key, value_to_store)
    return f"已成功将您的偏好记录到长期记忆: 喜欢 {category} 类的 {liked_item}。"


@tool
def get_recommendation(runtime: ToolRuntime) -> str:
    """
    获取用户推荐商品
    Args:
        runtime (ToolRuntime): 包含上下文信息的运行时环境
    Returns:
        str: 包含推荐商品信息的字符串
    """
    user_id = runtime.context.user_id
    current_order = runtime.state.get("current_order_id", "未知订单")
    namespace = (f"user_{user_id}", "preferences")
    prefs = runtime.store.search(namespace)

    pref_list = []
    if prefs:
        for p in prefs[-3:]:  # 仅取最近3条偏好记录,[-3:] 表示取最后3条记录
            pref_list.append(f"{p.value.get('category')}({p.value.get('liked_item')})")

    return f"基于用户当前的订单 [{current_order}] 和长期偏好 {pref_list if pref_list else '无'},为用户推荐相关配件或类似风格商品。"


@wrap_tool_call
def handle_tool_errors(request, handler):
    """使用自定义消息处理工具执行错误"""
    try:
        return handler(request)
    except Exception as e:
        # 向模型返回自定义错误消息
        return ToolMessage(
            content=f"调用工具错误:请稍后重试,错误信息:({str(e)})",
            tool_call_id=request.tool_call["id"]
        )


# ========== 5. 创建Agent,控制台交互循环 ==========
DB_URI = "mysql+pymysql://root:123456@localhost:3306/langchain_db?charset=utf8mb4"

# 初始化MySQL存储 (短期记忆Checkpointer 和 长期记忆Store)
with (
    PyMySQLSaver.from_conn_string(DB_URI) as checkpointer,
    PyMySQLStore.from_conn_string(DB_URI) as store
):
    # 首次运行时自动建表
    checkpointer.setup()
    store.setup()

    # 创建Agent
    agent = create_agent(
        model=deepseek_llm,
        tools=[get_user_info, query_order_status, update_user_preference, get_recommendation],
        system_prompt="""
                        你是一个智能电商客服助手,具备回答用户咨询、获取用户信息、查询订单状态、更新用户偏好和推荐商品功能。"
                        获取用户信息请调用 get_user_info 工具。
                        查询订单状态请调用 query_order_status 工具,查询到订单状态后,还需要调用 update_user_preference 工具更新用户偏好。
                        更新用户偏好请调用 update_user_preference 工具。
                        获取推荐商品请调用 get_recommendation 工具。
                      """,
        checkpointer=checkpointer,
        store=store,
        state_schema=CustomerSessionState,
        context_schema=UserContext,
        middleware=[
            SummarizationMiddleware(
                model=deepseek_llm,
                summary_prompt="请总结以下对话内容:{messages}",
                trigger=("messages", 10),  # 每10条消息触发一次摘要
                keep=("messages", 5),  # 保留最后5条消息
            ),
            handle_tool_errors
        ],
    )

    # 控制台交互循环 (流式调用)
    print("=" * 50)
    print("智能电商客服助手")
    print("功能: 查询订单、更新偏好、获取推荐。")
    print("输入 'quit' 或 '退出' 结束对话。")
    print("=" * 50)

    # 初始化用户上下文
    user_context = UserContext(user_id="customer_001", channel="Web")
    # 会话线程ID
    config = {"configurable": {"thread_id": "session_01"}}

    # 对话循环
    while True:
        try:
            user_input = input("[你]: ").strip()

            if user_input.lower() in ['quit', 'exit', '退出', 'q']:
                print("客服助手: 感谢你的咨询,再见!")
                break

            # 过滤空输入
            if not user_input:
                continue

            # 准备输入消息
            input_data = {"messages": {"role": "user", "content": user_input}}

            print("[客服助手]: ")
            # 流式调用Agent
            for chunk in agent.stream(input_data, config=config, context=user_context):
                # print("chunk:", chunk)
                for step, data in chunk.items():  # 遍历dict的key-value对
                    # print("step:", step)
                    # print("data:", data)

                    # 只有当 step为model或者tools时,才打印消息
                    if step in ["model", "tools"]:
                        message = data["messages"][-1]
                        message.pretty_print()

        except Exception as e:
            print(f"调用过程中出现错误: {e}")

代码运行后,对于thread_id为session_01 时,进行如下对话:

给我查询订单 order001 信息
给我推荐一些商品
我喜欢 苹果电脑 尤其是 mac pro
查询我的信息
我还喜欢 索尼的xm5耳机
给我推荐一些商品

停止程序后,切换thread_id为session_02后,进行如下对话:

查询我的信息
给我推荐一些商品

可以看到长期记忆信息在不同的会话之间是共享的,而短期记忆只针对每个thread_id生效。

此外,以上代码运行中,会出现如下警告:

UserWarning: Pydantic serializer warnings:
  PydanticSerializationUnexpectedValue(Expected `none` - serialized value may not be as expected [field_name='context', input_value=UserContext(user_id='customer_001', channel='Web'), input_type=UserContext])
  return self.__pydantic_serializer__.to_python(

该警告是将Pydantic模型对象直接传递给了LangChain Agent的context参数,而LangChain内部在处理这个对象时,期望的是基本的Python数据类型(如字典),而不是Pydantic模型对象,导致了序列化警告。去除该警告,可以在项目开始导入:“warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")”

posted @ 2026-04-05 00:10  rbcd  阅读(0)  评论(0)    收藏  举报