LangChain学习2 完成一个mysql数据库管理的示例agent

前言

LangChain中的模块都过了一遍了,来做个总结。

大概有以下知识点:提示词模版,Chain, 记忆,工具,agent, RAG

RAG没看。这官方文档看的我脑袋疼,尚硅谷的视频比黑马的详细些。但是都没有langgraph的相关内容。

根据学的东西做了一个小agent,可以用来参考学习,个人感觉写的还是比较规范的。

项目

数据库操作助手Agent ,基于LangChain框架实现,专门用于帮助用户管理和查询MySQL数据库。

完成的功能

1. 数据库连接管理 :
   
   - 连接到MySQL数据库(支持指定host、port、user、password、db参数)
   - 获取数据库的完整schema信息(包括表结构、字段信息、主外键关系和索引等)
   - 关闭数据库连接
2. 智能交互 :
   
   - 基于ChatOpenAI模型(使用deepseek-chat)实现自然语言交互
   - 支持流式输出,提供实时反馈
   - 能够理解用户的数据库操作请求并执行相应的工具调用
3. 长期记忆功能 :
   
   - 基于文件存储会话记录,以session_id为文件名
   - 实现了会话记录的持久化存储,程序重启后仍保留历史对话
   - 支持手动清除会话记录(输入"清除记忆"命令)

技术栈

- 编程语言 :Python
- 核心框架 :LangChain(langchain_classic、langchain_community、langchain_core)
- AI模型 :ChatOpenAI(使用deepseek-chat模型)
- 数据库驱动 :aiomysql(异步MySQL连接)
- 存储方式 :文件存储(JSON格式)

 

项目结构

├── database_agent.py      # 主文件,实现数据库操作助手Agent
├── database_tools.py      # 数据库工具函数
├── message_history.py     # 基于文件的会话记录存储实现
└── chat_history/          # 存储会话记录的文件夹
    └── database_agent_session.json  # 会话记录文件

代码

database_agent.py
from langchain_classic.agents import create_openai_tools_agent, AgentExecutor
from langchain_classic.memory import ConversationBufferMemory
from langchain_community.chat_models import ChatOpenAI
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import asyncio

from database_tools import connect_to_mysql, get_database_schema, close_connection
from message_history import FileChatMessageHistory

# 创建模型实例(支持流式输出)
model = ChatOpenAI(
    api_key="",
    model="deepseek-chat",
    base_url="https://api.deepseek.com",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()]
)

# 创建工具列表 - 直接使用装饰器修饰的函数
tools = [
    connect_to_mysql,
    get_database_schema,
    close_connection
]

# 系统提示词
system_prompt = """你是一个数据库操作助手,专门帮助用户管理MySQL数据库。

你可以使用以下工具:
1. connect_to_mysql: 连接到MySQL数据库
   参数:
   - host: 数据库主机地址
   - port: 数据库端口
   - user: 数据库用户名
   - password: 数据库密码
   - db: 数据库名称

2. get_database_schema: 获取数据库的完整schema信息,包括表结构、字段信息、主外键关系和索引等
   需要先连接数据库

3. close_connection: 关闭数据库连接

重要说明:
1. 当用户提供连接参数时(如:host='localhost', port=3306, user='root', password='123456', db='test'),你应该提取这些参数并调用connect_to_mysql工具
2. 连接成功后,你可以调用get_database_schema获取数据库schema信息
3. 基于schema信息回答用户的问题,比如有哪些表、表结构是什么
4. 如果已经获取过schema信息,可以直接基于已有信息回答,不需要重复获取
5. 对于不需要工具调用的问题,直接回答

现在开始,请根据用户的问题选择合适的工具并提供帮助。"""

# 创建提示词模板
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
    MessagesPlaceholder("agent_scratchpad"),
])

# 创建文件存储的消息历史
file_history = FileChatMessageHistory(session_id="database_agent_session")

# 创建记忆,使用文件存储的消息历史
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="output",
    chat_memory=file_history
)

# 创建Agent
agent = create_openai_tools_agent(
    llm=model,
    tools=tools,
    prompt=prompt
)

# 创建Agent执行器
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,  # 记忆
    verbose=False,  # 是否显示详细日志
    max_iterations=5,  # 最大迭代次数
    handle_parsing_errors=True, # 自动捕获错误并修复
    early_stopping_method="generate"  # 当达到最大迭代次数时,生成最终回答
)

async def chat_with_agent():
    """
    与数据库Agent进行对话
    """
    print("=== 数据库Agent ===")
    print("输入'退出'结束对话")
    print("输入'清除记忆'清除会话记录")
    print("\n示例用法:")
    print("1. 连接数据库: 请连接到MySQL数据库,host='localhost', port=3306, user='root', password='123456', db='test'")
    print("2. 查看表结构: 这个数据库中有哪些表呀")
    print("3. 查看schema: 显示数据库的表结构")
    print("4. 关闭连接: 关闭数据库连接")
    print()

    while True:
        try:
            # 获取用户输入
            user_input = input("\n用户: ")

            if user_input.lower() in ["退出", "exit", "quit"]:
                print("再见!")
                break
            
            if user_input.lower() == "清除记忆":
                file_history.clear()
                print("会话记录已清除!")
                continue

            print("\nAgent: ", end="", flush=True)

            # 执行Agent
            await agent_executor.ainvoke({
                "input": user_input
            })

            print()  # 换行

        except KeyboardInterrupt:
            print("\n\n再见!")
            break
        except Exception as e:
            print(f"\n处理请求时出错: {str(e)}")


def chat_with_agent_sync():
    """
    同步方式调用异步的chat_with_agent函数
    """
    asyncio.run(chat_with_agent())


if __name__ == "__main__":
    chat_with_agent_sync()

 

database_tools.py
from typing import Dict, Any
import aiomysql
from langchain_core.tools import tool

# 全局数据库连接池
_connection_pool: aiomysql.Pool = None


@tool
async def connect_to_mysql(host: str, port: int, user: str, password: str, db: str) -> Dict[str, str]:
    """
    连接到MySQL数据库

    Args:
        host: 数据库主机地址
        port: 数据库端口
        user: 数据库用户名
        password: 数据库密码
        db: 数据库名称

    Returns:
        Dict[str, str]: 连接结果,包含状态和消息
    """
    global _connection_pool

    try:
        # 关闭现有连接(如果有)
        if _connection_pool:
            _connection_pool.close()
            await _connection_pool.wait_closed()
            _connection_pool = None

        # 创建连接池
        _connection_pool = await aiomysql.create_pool(
            host=host,
            port=port,
            user=user,
            password=password,
            db=db,
            minsize=1,
            maxsize=10,
            autocommit=True
        )

        # 测试连接
        async with _connection_pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute("SELECT 1")
                result = await cur.fetchone()
                if result[0] == 1:
                    return {"status": "success", "message": "数据库连接成功"}
                else:
                    return {"status": "error", "message": "连接测试失败"}
    except Exception as e:
        return {"status": "error", "message": f"数据库连接失败: {str(e)}"}


@tool
async def get_database_schema() -> Dict[str, Any]:
    """
    获取数据库的完整schema信息,包括表结构、字段信息、主外键关系和索引等

    Returns:
        Dict[str, Any]: JSON格式的数据库schema信息
    """
    global _connection_pool

    try:
        # 检查连接池是否存在
        if not _connection_pool:
            return {"status": "error", "message": "请先连接数据库", "schema": {}}

        async with _connection_pool.acquire() as conn:
            async with conn.cursor() as cur:
                # 获取所有表名
                await cur.execute("SHOW TABLES")
                tables = await cur.fetchall()
                table_names = [table[0] for table in tables]

                schema = {}

                for table_name in table_names:
                    table_info = {}

                    # 获取表基本信息
                    await cur.execute("""
                        SELECT TABLE_NAME, ENGINE, TABLE_COLLATION, TABLE_COMMENT
                        FROM information_schema.TABLES
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
                    """, (table_name,))
                    table_meta = await cur.fetchone()
                    if table_meta:
                        table_info["table_name"] = table_meta[0]
                        table_info["engine"] = table_meta[1]
                        table_info["collation"] = table_meta[2]
                        table_info["comment"] = table_meta[3]

                    # 获取字段信息
                    await cur.execute("""
                        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT,
                               CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE
                        FROM information_schema.COLUMNS
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
                        ORDER BY ORDINAL_POSITION
                    """, (table_name,))
                    columns = await cur.fetchall()
                    table_info["columns"] = []
                    for col in columns:
                        column_info = {
                            "column_name": col[0],
                            "data_type": col[1],
                            "is_nullable": col[2] == "YES",
                            "default_value": col[3],
                            "comment": col[4],
                            "max_length": col[5],
                            "numeric_precision": col[6],
                            "numeric_scale": col[7]
                        }
                        table_info["columns"].append(column_info)

                    # 获取主键信息
                    await cur.execute("""
                        SELECT COLUMN_NAME
                        FROM information_schema.KEY_COLUMN_USAGE
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND CONSTRAINT_NAME = 'PRIMARY'
                        ORDER BY ORDINAL_POSITION
                    """, (table_name,))
                    primary_keys = await cur.fetchall()
                    table_info["primary_keys"] = [pk[0] for pk in primary_keys]

                    # 获取外键信息
                    await cur.execute("""
                        SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
                        FROM information_schema.KEY_COLUMN_USAGE
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND REFERENCED_TABLE_NAME IS NOT NULL
                        ORDER BY CONSTRAINT_NAME, ORDINAL_POSITION
                    """, (table_name,))
                    foreign_keys = await cur.fetchall()
                    table_info["foreign_keys"] = []
                    for fk in foreign_keys:
                        fk_info = {
                            "constraint_name": fk[0],
                            "column_name": fk[1],
                            "referenced_table": fk[2],
                            "referenced_column": fk[3]
                        }
                        table_info["foreign_keys"].append(fk_info)

                    # 获取索引信息
                    await cur.execute("""
                        SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE
                        FROM information_schema.STATISTICS
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s
                        ORDER BY INDEX_NAME, SEQ_IN_INDEX
                    """, (table_name,))
                    indexes = await cur.fetchall()
                    table_info["indexes"] = []
                    for idx in indexes:
                        index_info = {
                            "index_name": idx[0],
                            "column_name": idx[1],
                            "is_unique": idx[2] == 0
                        }
                        table_info["indexes"].append(index_info)

                    schema[table_name] = table_info

                return {"status": "success", "schema": schema}
    except Exception as e:
        return {"status": "error", "message": f"获取数据库schema失败: {str(e)}", "schema": {}}


@tool
async def close_connection() -> Dict[str, str]:
    """
    关闭数据库连接

    Returns:
        Dict[str, str]: 关闭结果
    """
    global _connection_pool

    try:
        if _connection_pool:
            _connection_pool.close()
            await _connection_pool.wait_closed()
            _connection_pool = None
            return {"status": "success", "message": "数据库连接已关闭"}
        else:
            return {"status": "success", "message": "数据库连接未建立或已关闭"}
    except Exception as e:
        return {"status": "error", "message": f"关闭连接失败: {str(e)}"}

 

message_history.py
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict
import json
import os

class FileChatMessageHistory(BaseChatMessageHistory):
    """
    基于文件存储的会话记录管理类
    以session_id为文件名,不同session_id有不同文件存储消息
    """
    
    def __init__(self, session_id: str, directory: str = "./chat_history"):
        """
        初始化FileChatMessageHistory
        
        Args:
            session_id: 会话ID,作为文件名
            directory: 存储会话记录的目录,默认在当前目录下的chat_history文件夹
        """
        self.session_id = session_id
        self.directory = directory
        self.file_path = os.path.join(directory, f"{session_id}.json")
        
        # 确保目录存在
        os.makedirs(directory, exist_ok=True)
    
    def add_messages(self, messages: list[BaseMessage]) -> None:
        """
        同步模式,添加消息到会话记录
        
        Args:
            messages: 要添加的消息列表
        """
        # 读取现有消息
        existing_messages = self.messages
        # 合并消息
        all_messages = existing_messages + messages
        # 转换为可序列化的格式
        messages_dict = messages_to_dict(all_messages)
        # 写入文件
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(messages_dict, f, ensure_ascii=False, indent=2)
    
    @property
    def messages(self) -> list[BaseMessage]:
        """
        同步模式,获取会话记录中的所有消息
        
        Returns:
            消息列表
        """
        # 检查文件是否存在
        if not os.path.exists(self.file_path):
            return []
        # 读取文件
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                messages_dict = json.load(f)
                # 转换为消息对象
                return messages_from_dict(messages_dict)
        except (json.JSONDecodeError, FileNotFoundError):
            return []
    
    def clear(self) -> None:
        """
        同步模式,清除会话记录
        """
        # 检查文件是否存在
        if os.path.exists(self.file_path):
            # 删除文件
            os.remove(self.file_path)
posted @ 2026-02-02 17:20  雨花阁  阅读(5)  评论(0)    收藏  举报