24-3-day5-memory-agent_with_Mysql_自动建库

24-3-day5-memory-agent_with_Mysql_自动建库

✅ 程序说明:

  1. 移除 SQLite 相关代码(如 sqlite3get_db_connection 上下文等)。
  2. 引入 MySQL 支持:使用 PyMySQLmysql-connector-python(推荐 PyMySQL,因为兼容性好且轻量)。
  3. 重写 MySQLChatMessageHistory,使其操作 MySQL 表。
  4. 从环境变量读取 MySQL 连接参数(IP、端口、账号、密码、数据库名)。
  5. 保留表结构不变(字段、主键、索引一致)。
  6. 不包含tools
  • init_db() 中,先连接到 MySQL 不指定数据库
  • 执行 CREATE DATABASE IF NOT EXISTS langchain_chat
  • 再切换到该数据库并创建表;
  • 保证即使目标数据库不存在,程序也能自举运行。

⚠️ 注意:MySQL 用户必须有 CREATE 权限才能自动建库。生产环境中建议仍由 DBA 预先创建,但开发/测试场景下自动建库非常方便。


✅ 程序:agent_with_mysql.py(含自动建库)

# agent_with_mysql.py
import os
from contextlib import contextmanager
from typing import List, Optional

import pymysql
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

# ==============================
# 🔽 自定义 MySQL 聊天历史类(支持自动建库)
# ==============================

load_dotenv()

# 从 .env 读取 MySQL 配置
MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", 3306))
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "langchain_chat")


@contextmanager
def get_db_connection(use_db: bool = True):
    """
    MySQL 数据库连接上下文管理器
    :param use_db: 是否连接到具体数据库。建库时设为 False。
    """
    db = MYSQL_DATABASE if use_db else None
    conn = pymysql.connect(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=db,
        charset='utf8mb4',
        autocommit=False
    )
    try:
        yield conn
    finally:
        conn.close()


def init_db():
    """自动创建数据库(如果不存在)并初始化表"""
    # 第一步:连接 MySQL(不指定数据库),创建数据库
    with get_db_connection(use_db=False) as conn:
        with conn.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{MYSQL_DATABASE}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
        conn.commit()

    # 第二步:连接到目标数据库,创建表
    with get_db_connection(use_db=True) as conn:
        with conn.cursor() as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS chat_history (
                    session_id VARCHAR(255) NOT NULL,
                    message_index INT NOT NULL,
                    role ENUM('human', 'ai') NOT NULL,
                    content TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (session_id, message_index)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
            """)
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_session ON chat_history(session_id);")
        conn.commit()


class MySQLChatMessageHistory(BaseChatMessageHistory):
    """基于 MySQL 的聊天历史存储,兼容 LangChain"""

    def __init__(self, session_id: str):
        if not session_id:
            raise ValueError("session_id 不能为空")
        self.session_id = session_id
        init_db()  # 确保数据库和表存在

    @property
    def messages(self) -> List[BaseMessage]:
        """从数据库加载消息(供 LangChain 读取)"""
        messages = []
        with get_db_connection(use_db=True) as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    SELECT role, content
                    FROM chat_history
                    WHERE session_id = %s
                    ORDER BY message_index
                    """,
                    (self.session_id,)
                )
                for row in cursor.fetchall():
                    role, content = row
                    if role == "human":
                        messages.append(HumanMessage(content=content))
                    elif role == "ai":
                        messages.append(AIMessage(content=content))
        return messages

    def add_message(self, message: BaseMessage) -> None:
        """保存单条消息到数据库"""
        if isinstance(message, HumanMessage):
            role = "human"
        elif isinstance(message, AIMessage):
            role = "ai"
        else:
            raise ValueError(f"不支持的消息类型: {type(message)}")

        with get_db_connection(use_db=True) as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT COALESCE(MAX(message_index), -1) FROM chat_history WHERE session_id = %s",
                    (self.session_id,)
                )
                next_index = cursor.fetchone()[0] + 1

                cursor.execute(
                    "INSERT INTO chat_history (session_id, message_index, role, content) VALUES (%s, %s, %s, %s)",
                    (self.session_id, next_index, role, message.content)
                )
            conn.commit()

    def clear(self) -> None:
        """清空当前会话历史"""
        with get_db_connection(use_db=True) as conn:
            with conn.cursor() as cursor:
                cursor.execute("DELETE FROM chat_history WHERE session_id = %s", (self.session_id,))
            conn.commit()


# ==============================
# 🔽 LangChain 配置(保持不变)
# ==============================

llm = ChatOpenAI(
    model="qwen-max",
    openai_api_key=os.getenv("DASHSCOPE_API_KEY"),
    openai_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
    temperature=0.7
)

prompt = ChatPromptTemplate.from_messages([
    ("system", "你是一个有记忆的 AI 助手。请记住用户之前说过的话。"),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}")
])

chain = prompt | llm

with_message_history = RunnableWithMessageHistory(
    chain,
    lambda session_id: MySQLChatMessageHistory(session_id),
    input_messages_key="input",
    history_messages_key="history",
)


# ==============================
# 🔽 主程序(保持不变)
# ==============================

def show_memory(session_id: str):
    """显示当前会话记忆"""
    history = MySQLChatMessageHistory(session_id)
    msgs = history.messages
    print("🧠 当前记忆内容:")
    if not msgs:
        print("  (无记忆)")
    else:
        for i, msg in enumerate(msgs, 1):
            role = "👤 用户" if isinstance(msg, HumanMessage) else "🤖 Agent"
            print(f"  {i}. {role}: {msg.content}")
    print()


if __name__ == "__main__":
    print("🤖 多会话记忆型 Agent 启动(MySQL 持久化版,支持自动建库)!")
    print("指令:")
    print("  - 输入 'quit' 退出")
    print("  - 输入 'show_memory' 查看当前会话记忆")
    print("  - 输入 'switch <session_id>' 切换会话(如 switch alice)")
    print("  - 默认会话 ID: default\n")

    current_session = "default"

    while True:
        try:
            user_input = input(f"👤 [{current_session}] 你: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n👋 再见!")
            break

        if not user_input:
            continue

        if user_input.lower() == "quit":
            break
        elif user_input.lower() == "show_memory":
            show_memory(current_session)
            continue
        elif user_input.lower().startswith("switch "):
            parts = user_input.split(" ", 1)
            new_session = parts[1].strip() if len(parts) > 1 else ""
            if not new_session:
                print("⚠️ 用法: switch <session_id>")
                continue
            current_session = new_session
            print(f"🔄 已切换到会话: '{current_session}'\n")
            continue

        response = with_message_history.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": current_session}}
        )
        print(f"🤖 Agent: {response.content}\n")

.env 示例

DASHSCOPE_API_KEY=your_dashscope_api_key_here

MYSQL_HOST=192.168.1.100
MYSQL_PORT=3306
MYSQL_USER=myuser
MYSQL_PASSWORD=mypassword
MYSQL_DATABASE=langchain_chat

💡 即使 langchain_chat 数据库不存在,程序也会自动创建它!


✅ 安装依赖

pip install langchain-openai python-dotenv pymysql

✅ 安全提示

  • 自动建库需要 MySQL 用户有 CREATE 权限。
  • 在生产环境,建议:
    • 由 DBA 创建数据库;
    • 应用使用权限受限的账号(仅 SELECT, INSERT, DELETE);
    • 关闭自动建库逻辑(注释掉 CREATE DATABASE 部分)。

posted @ 2026-02-02 07:53  船山薪火  阅读(0)  评论(0)    收藏  举报
![image](https://img2024.cnblogs.com/blog/3174785/202601/3174785-20260125205854513-941832118.jpg)