27. Agent 需要拦截模型调用?用 Middleware 给它加个“拦截器“!
你有没有遇到过这种情况:Agent 跑得好好的,你想加个日志看看它到底在干嘛,或者想加个安全检查防止它搞出危险操作,结果发现不知道往哪儿插?
Middleware 就是来解决这个问题的。说白了,它就是一个"拦截器",让你在模型调用前后插入你自己的逻辑。听起来是不是挺简单的?别急,咱们直接上手写代码,你一看就明白了。
动画视频在《27. Agent 需要拦截模型调用?用 Middleware 给它加个"拦截器"!》。

第一个 Middleware:日志记录
咱们先从一个最简单的需求开始——记录日志。我想知道每次调用模型的时候,当前有多少条消息,模型又回了什么。
怎么做呢?很简单,写一个类,继承 AgentMiddleware,然后实现两个方法就行。
from langchain.agents.middleware import AgentMiddleware from langchain.agents import AgentState from langgraph.runtime import Runtime class LoggingMiddleware(AgentMiddleware): def before_model(self, state: AgentState, runtime: Runtime) -> None: print(f"[日志] 即将调用模型,当前消息数: {len(state['messages'])}") def after_model(self, state: AgentState, runtime: Runtime) -> None: last_msg = state['messages'][-1] print(f"[日志] 模型已响应: {last_msg.content[:50]}...")
你注意看啊,这里有两个方法。before_model 就是在模型调用之前执行的,我打印一下当前消息的数量。after_model 是模型响应之后执行的,我取最后一条消息,截取前 50 个字符看看模型回了啥。
关键点来了——你注意这两个方法的返回类型,都是 None。为啥?因为这个 Middleware 只是记录日志,它不需要干预流程,所以直接返回 None,就是在告诉框架:"我完事了,你继续往下走就行。"
好,日志中间件搞定了。是不是特别简单?接下来咱们上点强度——写一个安全检查中间件。
第二个 Middleware:安全检查
这个中间件的作用是啥呢?拦截危险操作。比如用户说"删除所有文件",你肯定不想让 Agent 傻乎乎地去执行吧?
同样继承 AgentMiddleware,重点在 before_model 里做文章。
from langchain_core.messages import AIMessage class SafetyMiddleware(AgentMiddleware): def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: last_msg = state['messages'][-1].content if "删除" in last_msg or "危险" in last_msg: return { "jump_to": "end", "messages": [AIMessage(content="检测到危险操作,已终止")] } return None
这里有个关键区别,你一定要注意到——返回类型变了,变成了 dict | None,而不是单纯的 None。
为什么?因为这个 Middleware 有可能需要干预流程。
逻辑是这样的:拿到最后一条消息的内容,检查一下里面有没有"删除"或者"危险"这样的关键词。如果命中了,就返回一个字典,jump_to 设为 "end",意思是跳过模型调用直接结束。同时塞一条 AIMessage 告诉用户操作被终止了。这个返回值就是在告诉框架:"别调用模型了,直接结束,并且把这条消息加进去。"
如果没有问题呢?就返回 None,意思是"我没意见,流程正常继续"。
你看,这就是 Middleware 的精髓——你可以选择不管,也可以选择直接接管整个流程。
完整实战代码
好,两个 Middleware 都写好了,接下来咱们把它们组装起来,看看实际效果。
先把需要的依赖都导入进来,然后配置模型、定义工具。
import os import sqlite3 from dotenv import load_dotenv from langchain.agents import create_agent, AgentState from langchain.agents.middleware import AgentMiddleware from langchain.chat_models import init_chat_model from langchain_classic.agents import Agent from langchain_community.tools import WriteFileTool, ReadFileTool, ListDirectoryTool from langchain_core.messages import AIMessage from langchain_core.tools import tool, BaseTool from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.runtime import Runtime from langgraph.store.memory import InMemoryStore from langgraph.store.sqlite import SqliteStore load_dotenv() prefix = "QWEN" model = init_chat_model( model_provider="openai", configurable_fields=["model", "api_key", "base_url"], config_prefix=prefix).with_config({ "configurable": { f"{prefix}_model": os.getenv(f"{prefix}_MODEL"), f"{prefix}_api_key": os.getenv(f"{prefix}_API_KEY"), f"{prefix}_base_url": os.getenv(f"{prefix}_BASE_URL") }}) class CalculateTool(BaseTool): name: str = "calculate" description: str = "计算数学表达式的值" def _run(self, expression: str) -> str: try: return f"计算结果:{eval(expression)}" except Exception as e: return f"计算错误:{str(e)}" async def _arun(self, expression: str) -> str: return self._run(expression) # Middleware 1:记录模型调用日志 class LoggingMiddleware(AgentMiddleware): def before_model(self, state: AgentState, runtime: Runtime) -> None: print(f"[日志] 即将调用模型,当前消息数: {len(state['messages'])}") def after_model(self, state: AgentState, runtime: Runtime) -> None: last_msg = state['messages'][-1] print(f"[日志] 模型已响应: {last_msg.content[:50]}...") # Middleware 2:安全检查,拦截危险操作 class SafetyMiddleware(AgentMiddleware): def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: last_msg = state['messages'][-1].content if "删除" in last_msg or "危险" in last_msg: return { "jump_to": "end", "messages": [AIMessage(content="检测到危险操作,已终止")] } return None calculate = CalculateTool() write_file = WriteFileTool() read_file = ReadFileTool() list_dir = ListDirectoryTool() checkpoint_conn = sqlite3.connect("agent.db", check_same_thread=False, isolation_level=None) checkpointer = SqliteSaver(checkpoint_conn) store_conn = sqlite3.connect("agent.db", check_same_thread=False, isolation_level=None) store = SqliteStore(store_conn) agent = create_agent( model=model, tools=[calculate, write_file, read_file, list_dir], system_prompt="你是一个助手,会用工具计算、读写文件、列出目录。", debug=True, checkpointer=checkpointer, store=store, middleware=[LoggingMiddleware(), SafetyMiddleware()] ) config = {"configurable": {"thread_id": "session-1"}} store.put(("user", "user-1"), "profile", {"name": "张三", "role": "developer", "skills": ["python", "typescript", "java"]}) profile = store.get(("user", "user-1"), "profile") print(f"用户资料:{profile.value}") queries = ["计算 2024*12+500,然后把结果保存到 result.txt", "读取 result.txt的内容", "列出当前目录文件", "刚才计算的结果是多少?", "删除所有文件" ] for q in queries: print(f"\n问:{q}") response = agent.invoke({"messages": [{"role": "user", "content": q}]}, config=config) print(response) print(f"\n答:{response['messages'][-1].content}") checkpoint_conn.close() store_conn.close()
代码比较长,但核心逻辑其实就三块:模型配置、Middleware 定义、Agent 创建。
你重点看 create_agent 这一行——注意 middleware 这个参数,咱们把 LoggingMiddleware 和 SafetyMiddleware 两个实例一起传进去了。就这么一行,Agent 就同时具备了日志记录和安全拦截的能力。
然后下面准备了一些测试问题,挨个发给 Agent 执行。你注意最后一个问题——"删除所有文件",这就是用来测试安全检查中间件的。
运行效果
跑起来之后你会看到什么呢?
首先,每次模型调用前后,控制台都会打印日志信息,告诉你当前消息数和模型响应内容,这就是 LoggingMiddleware 在干活。
然后,当前面几个正常问题执行的时候,一切顺利。但是到了最后一个问题"删除所有文件"的时候,Agent 直接返回"检测到危险操作,已终止"——根本不会去调用模型。这就是 SafetyMiddleware 在拦截。
总结
好了,今天的内容就到这里,咱们快速回顾一下。
Middleware 说白了就是 Agent 的"拦截器",它给你提供了两个钩子:before_model 和 after_model。你可以在模型调用前注入上下文、检查权限,也可以在模型响应后记录日志、审计操作。
最关键的是,你可以通过返回值来控制流程——返回 None 就是"我没意见,继续走",返回一个 dict 就可以直接改变流程走向,比如跳过模型调用。
这个机制非常灵活,你可以用它做很多事情:限流、鉴权、日志、审计、上下文注入……基本上你能想到的横切关注点,都可以用 Middleware 来优雅地实现。
如果觉得有用的话,记得点赞关注,咱们下期再见!

浙公网安备 33010602011771号