LangChain学习2 完成一个mysql数据库管理的示例agent
前言
LangChain中的模块都过了一遍了,来做个总结。
大概有以下知识点:提示词模版,Chain, 记忆,工具,agent, RAG
RAG没看。这官方文档看的我脑袋疼,尚硅谷的视频比黑马的详细些。但是都没有langgraph的相关内容。
根据学的东西做了一个小agent,可以用来参考学习,个人感觉写的还是比较规范的。
项目
数据库操作助手Agent ,基于LangChain框架实现,专门用于帮助用户管理和查询MySQL数据库。
完成的功能
1. 数据库连接管理 :
- 连接到MySQL数据库(支持指定host、port、user、password、db参数)
- 获取数据库的完整schema信息(包括表结构、字段信息、主外键关系和索引等)
- 关闭数据库连接
2. 智能交互 :
- 基于ChatOpenAI模型(使用deepseek-chat)实现自然语言交互
- 支持流式输出,提供实时反馈
- 能够理解用户的数据库操作请求并执行相应的工具调用
3. 长期记忆功能 :
- 基于文件存储会话记录,以session_id为文件名
- 实现了会话记录的持久化存储,程序重启后仍保留历史对话
- 支持手动清除会话记录(输入"清除记忆"命令)
技术栈
- 编程语言 :Python
- 核心框架 :LangChain(langchain_classic、langchain_community、langchain_core)
- AI模型 :ChatOpenAI(使用deepseek-chat模型)
- 数据库驱动 :aiomysql(异步MySQL连接)
- 存储方式 :文件存储(JSON格式)
项目结构
├── database_agent.py # 主文件,实现数据库操作助手Agent
├── database_tools.py # 数据库工具函数
├── message_history.py # 基于文件的会话记录存储实现
└── chat_history/ # 存储会话记录的文件夹
└── database_agent_session.json # 会话记录文件
代码
database_agent.py
from langchain_classic.agents import create_openai_tools_agent, AgentExecutor
from langchain_classic.memory import ConversationBufferMemory
from langchain_community.chat_models import ChatOpenAI
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import asyncio
from database_tools import connect_to_mysql, get_database_schema, close_connection
from message_history import FileChatMessageHistory
# 创建模型实例(支持流式输出)
model = ChatOpenAI(
api_key="",
model="deepseek-chat",
base_url="https://api.deepseek.com",
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()]
)
# 创建工具列表 - 直接使用装饰器修饰的函数
tools = [
connect_to_mysql,
get_database_schema,
close_connection
]
# 系统提示词
system_prompt = """你是一个数据库操作助手,专门帮助用户管理MySQL数据库。
你可以使用以下工具:
1. connect_to_mysql: 连接到MySQL数据库
参数:
- host: 数据库主机地址
- port: 数据库端口
- user: 数据库用户名
- password: 数据库密码
- db: 数据库名称
2. get_database_schema: 获取数据库的完整schema信息,包括表结构、字段信息、主外键关系和索引等
需要先连接数据库
3. close_connection: 关闭数据库连接
重要说明:
1. 当用户提供连接参数时(如:host='localhost', port=3306, user='root', password='123456', db='test'),你应该提取这些参数并调用connect_to_mysql工具
2. 连接成功后,你可以调用get_database_schema获取数据库schema信息
3. 基于schema信息回答用户的问题,比如有哪些表、表结构是什么
4. 如果已经获取过schema信息,可以直接基于已有信息回答,不需要重复获取
5. 对于不需要工具调用的问题,直接回答
现在开始,请根据用户的问题选择合适的工具并提供帮助。"""
# 创建提示词模板
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad"),
])
# 创建文件存储的消息历史
file_history = FileChatMessageHistory(session_id="database_agent_session")
# 创建记忆,使用文件存储的消息历史
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key="output",
chat_memory=file_history
)
# 创建Agent
agent = create_openai_tools_agent(
llm=model,
tools=tools,
prompt=prompt
)
# 创建Agent执行器
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
memory=memory, # 记忆
verbose=False, # 是否显示详细日志
max_iterations=5, # 最大迭代次数
handle_parsing_errors=True, # 自动捕获错误并修复
early_stopping_method="generate" # 当达到最大迭代次数时,生成最终回答
)
async def chat_with_agent():
"""
与数据库Agent进行对话
"""
print("=== 数据库Agent ===")
print("输入'退出'结束对话")
print("输入'清除记忆'清除会话记录")
print("\n示例用法:")
print("1. 连接数据库: 请连接到MySQL数据库,host='localhost', port=3306, user='root', password='123456', db='test'")
print("2. 查看表结构: 这个数据库中有哪些表呀")
print("3. 查看schema: 显示数据库的表结构")
print("4. 关闭连接: 关闭数据库连接")
print()
while True:
try:
# 获取用户输入
user_input = input("\n用户: ")
if user_input.lower() in ["退出", "exit", "quit"]:
print("再见!")
break
if user_input.lower() == "清除记忆":
file_history.clear()
print("会话记录已清除!")
continue
print("\nAgent: ", end="", flush=True)
# 执行Agent
await agent_executor.ainvoke({
"input": user_input
})
print() # 换行
except KeyboardInterrupt:
print("\n\n再见!")
break
except Exception as e:
print(f"\n处理请求时出错: {str(e)}")
def chat_with_agent_sync():
"""
同步方式调用异步的chat_with_agent函数
"""
asyncio.run(chat_with_agent())
if __name__ == "__main__":
chat_with_agent_sync()
database_tools.py
from typing import Dict, Any
import aiomysql
from langchain_core.tools import tool
# 全局数据库连接池
_connection_pool: aiomysql.Pool = None
@tool
async def connect_to_mysql(host: str, port: int, user: str, password: str, db: str) -> Dict[str, str]:
"""
连接到MySQL数据库
Args:
host: 数据库主机地址
port: 数据库端口
user: 数据库用户名
password: 数据库密码
db: 数据库名称
Returns:
Dict[str, str]: 连接结果,包含状态和消息
"""
global _connection_pool
try:
# 关闭现有连接(如果有)
if _connection_pool:
_connection_pool.close()
await _connection_pool.wait_closed()
_connection_pool = None
# 创建连接池
_connection_pool = await aiomysql.create_pool(
host=host,
port=port,
user=user,
password=password,
db=db,
minsize=1,
maxsize=10,
autocommit=True
)
# 测试连接
async with _connection_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
result = await cur.fetchone()
if result[0] == 1:
return {"status": "success", "message": "数据库连接成功"}
else:
return {"status": "error", "message": "连接测试失败"}
except Exception as e:
return {"status": "error", "message": f"数据库连接失败: {str(e)}"}
@tool
async def get_database_schema() -> Dict[str, Any]:
"""
获取数据库的完整schema信息,包括表结构、字段信息、主外键关系和索引等
Returns:
Dict[str, Any]: JSON格式的数据库schema信息
"""
global _connection_pool
try:
# 检查连接池是否存在
if not _connection_pool:
return {"status": "error", "message": "请先连接数据库", "schema": {}}
async with _connection_pool.acquire() as conn:
async with conn.cursor() as cur:
# 获取所有表名
await cur.execute("SHOW TABLES")
tables = await cur.fetchall()
table_names = [table[0] for table in tables]
schema = {}
for table_name in table_names:
table_info = {}
# 获取表基本信息
await cur.execute("""
SELECT TABLE_NAME, ENGINE, TABLE_COLLATION, TABLE_COMMENT
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
""", (table_name,))
table_meta = await cur.fetchone()
if table_meta:
table_info["table_name"] = table_meta[0]
table_info["engine"] = table_meta[1]
table_info["collation"] = table_meta[2]
table_info["comment"] = table_meta[3]
# 获取字段信息
await cur.execute("""
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT,
CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION
""", (table_name,))
columns = await cur.fetchall()
table_info["columns"] = []
for col in columns:
column_info = {
"column_name": col[0],
"data_type": col[1],
"is_nullable": col[2] == "YES",
"default_value": col[3],
"comment": col[4],
"max_length": col[5],
"numeric_precision": col[6],
"numeric_scale": col[7]
}
table_info["columns"].append(column_info)
# 获取主键信息
await cur.execute("""
SELECT COLUMN_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND CONSTRAINT_NAME = 'PRIMARY'
ORDER BY ORDINAL_POSITION
""", (table_name,))
primary_keys = await cur.fetchall()
table_info["primary_keys"] = [pk[0] for pk in primary_keys]
# 获取外键信息
await cur.execute("""
SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND REFERENCED_TABLE_NAME IS NOT NULL
ORDER BY CONSTRAINT_NAME, ORDINAL_POSITION
""", (table_name,))
foreign_keys = await cur.fetchall()
table_info["foreign_keys"] = []
for fk in foreign_keys:
fk_info = {
"constraint_name": fk[0],
"column_name": fk[1],
"referenced_table": fk[2],
"referenced_column": fk[3]
}
table_info["foreign_keys"].append(fk_info)
# 获取索引信息
await cur.execute("""
SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
ORDER BY INDEX_NAME, SEQ_IN_INDEX
""", (table_name,))
indexes = await cur.fetchall()
table_info["indexes"] = []
for idx in indexes:
index_info = {
"index_name": idx[0],
"column_name": idx[1],
"is_unique": idx[2] == 0
}
table_info["indexes"].append(index_info)
schema[table_name] = table_info
return {"status": "success", "schema": schema}
except Exception as e:
return {"status": "error", "message": f"获取数据库schema失败: {str(e)}", "schema": {}}
@tool
async def close_connection() -> Dict[str, str]:
"""
关闭数据库连接
Returns:
Dict[str, str]: 关闭结果
"""
global _connection_pool
try:
if _connection_pool:
_connection_pool.close()
await _connection_pool.wait_closed()
_connection_pool = None
return {"status": "success", "message": "数据库连接已关闭"}
else:
return {"status": "success", "message": "数据库连接未建立或已关闭"}
except Exception as e:
return {"status": "error", "message": f"关闭连接失败: {str(e)}"}
message_history.py
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict
import json
import os
class FileChatMessageHistory(BaseChatMessageHistory):
"""
基于文件存储的会话记录管理类
以session_id为文件名,不同session_id有不同文件存储消息
"""
def __init__(self, session_id: str, directory: str = "./chat_history"):
"""
初始化FileChatMessageHistory
Args:
session_id: 会话ID,作为文件名
directory: 存储会话记录的目录,默认在当前目录下的chat_history文件夹
"""
self.session_id = session_id
self.directory = directory
self.file_path = os.path.join(directory, f"{session_id}.json")
# 确保目录存在
os.makedirs(directory, exist_ok=True)
def add_messages(self, messages: list[BaseMessage]) -> None:
"""
同步模式,添加消息到会话记录
Args:
messages: 要添加的消息列表
"""
# 读取现有消息
existing_messages = self.messages
# 合并消息
all_messages = existing_messages + messages
# 转换为可序列化的格式
messages_dict = messages_to_dict(all_messages)
# 写入文件
with open(self.file_path, "w", encoding="utf-8") as f:
json.dump(messages_dict, f, ensure_ascii=False, indent=2)
@property
def messages(self) -> list[BaseMessage]:
"""
同步模式,获取会话记录中的所有消息
Returns:
消息列表
"""
# 检查文件是否存在
if not os.path.exists(self.file_path):
return []
# 读取文件
try:
with open(self.file_path, "r", encoding="utf-8") as f:
messages_dict = json.load(f)
# 转换为消息对象
return messages_from_dict(messages_dict)
except (json.JSONDecodeError, FileNotFoundError):
return []
def clear(self) -> None:
"""
同步模式,清除会话记录
"""
# 检查文件是否存在
if os.path.exists(self.file_path):
# 删除文件
os.remove(self.file_path)
浙公网安备 33010602011771号