24-3-day5-memory-agent_with_Mysql_自动建库
24-3-day5-memory-agent_with_Mysql_自动建库
✅ 程序说明:
- 移除 SQLite 相关代码(如
sqlite3、get_db_connection上下文等)。 - 引入 MySQL 支持:使用
PyMySQL或mysql-connector-python(推荐PyMySQL,因为兼容性好且轻量)。 - 重写
MySQLChatMessageHistory类,使其操作 MySQL 表。 - 从环境变量读取 MySQL 连接参数(IP、端口、账号、密码、数据库名)。
- 保留表结构不变(字段、主键、索引一致)。
- 不包含tools
- 在
init_db()中,先连接到 MySQL 不指定数据库; - 执行
CREATE DATABASE IF NOT EXISTS langchain_chat; - 再切换到该数据库并创建表;
- 保证即使目标数据库不存在,程序也能自举运行。
⚠️ 注意:MySQL 用户必须有
CREATE权限才能自动建库。生产环境中建议仍由 DBA 预先创建,但开发/测试场景下自动建库非常方便。
✅ 程序:agent_with_mysql.py(含自动建库)
# agent_with_mysql.py
import os
from contextlib import contextmanager
from typing import List, Optional
import pymysql
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
# ==============================
# 🔽 自定义 MySQL 聊天历史类(支持自动建库)
# ==============================
load_dotenv()
# 从 .env 读取 MySQL 配置
MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", 3306))
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "langchain_chat")
@contextmanager
def get_db_connection(use_db: bool = True):
"""
MySQL 数据库连接上下文管理器
:param use_db: 是否连接到具体数据库。建库时设为 False。
"""
db = MYSQL_DATABASE if use_db else None
conn = pymysql.connect(
host=MYSQL_HOST,
port=MYSQL_PORT,
user=MYSQL_USER,
password=MYSQL_PASSWORD,
database=db,
charset='utf8mb4',
autocommit=False
)
try:
yield conn
finally:
conn.close()
def init_db():
"""自动创建数据库(如果不存在)并初始化表"""
# 第一步:连接 MySQL(不指定数据库),创建数据库
with get_db_connection(use_db=False) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{MYSQL_DATABASE}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
conn.commit()
# 第二步:连接到目标数据库,创建表
with get_db_connection(use_db=True) as conn:
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS chat_history (
session_id VARCHAR(255) NOT NULL,
message_index INT NOT NULL,
role ENUM('human', 'ai') NOT NULL,
content TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (session_id, message_index)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_session ON chat_history(session_id);")
conn.commit()
class MySQLChatMessageHistory(BaseChatMessageHistory):
"""基于 MySQL 的聊天历史存储,兼容 LangChain"""
def __init__(self, session_id: str):
if not session_id:
raise ValueError("session_id 不能为空")
self.session_id = session_id
init_db() # 确保数据库和表存在
@property
def messages(self) -> List[BaseMessage]:
"""从数据库加载消息(供 LangChain 读取)"""
messages = []
with get_db_connection(use_db=True) as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT role, content
FROM chat_history
WHERE session_id = %s
ORDER BY message_index
""",
(self.session_id,)
)
for row in cursor.fetchall():
role, content = row
if role == "human":
messages.append(HumanMessage(content=content))
elif role == "ai":
messages.append(AIMessage(content=content))
return messages
def add_message(self, message: BaseMessage) -> None:
"""保存单条消息到数据库"""
if isinstance(message, HumanMessage):
role = "human"
elif isinstance(message, AIMessage):
role = "ai"
else:
raise ValueError(f"不支持的消息类型: {type(message)}")
with get_db_connection(use_db=True) as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT COALESCE(MAX(message_index), -1) FROM chat_history WHERE session_id = %s",
(self.session_id,)
)
next_index = cursor.fetchone()[0] + 1
cursor.execute(
"INSERT INTO chat_history (session_id, message_index, role, content) VALUES (%s, %s, %s, %s)",
(self.session_id, next_index, role, message.content)
)
conn.commit()
def clear(self) -> None:
"""清空当前会话历史"""
with get_db_connection(use_db=True) as conn:
with conn.cursor() as cursor:
cursor.execute("DELETE FROM chat_history WHERE session_id = %s", (self.session_id,))
conn.commit()
# ==============================
# 🔽 LangChain 配置(保持不变)
# ==============================
llm = ChatOpenAI(
model="qwen-max",
openai_api_key=os.getenv("DASHSCOPE_API_KEY"),
openai_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
temperature=0.7
)
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个有记忆的 AI 助手。请记住用户之前说过的话。"),
MessagesPlaceholder(variable_name="history"),
("human", "{input}")
])
chain = prompt | llm
with_message_history = RunnableWithMessageHistory(
chain,
lambda session_id: MySQLChatMessageHistory(session_id),
input_messages_key="input",
history_messages_key="history",
)
# ==============================
# 🔽 主程序(保持不变)
# ==============================
def show_memory(session_id: str):
"""显示当前会话记忆"""
history = MySQLChatMessageHistory(session_id)
msgs = history.messages
print("🧠 当前记忆内容:")
if not msgs:
print(" (无记忆)")
else:
for i, msg in enumerate(msgs, 1):
role = "👤 用户" if isinstance(msg, HumanMessage) else "🤖 Agent"
print(f" {i}. {role}: {msg.content}")
print()
if __name__ == "__main__":
print("🤖 多会话记忆型 Agent 启动(MySQL 持久化版,支持自动建库)!")
print("指令:")
print(" - 输入 'quit' 退出")
print(" - 输入 'show_memory' 查看当前会话记忆")
print(" - 输入 'switch <session_id>' 切换会话(如 switch alice)")
print(" - 默认会话 ID: default\n")
current_session = "default"
while True:
try:
user_input = input(f"👤 [{current_session}] 你: ").strip()
except (KeyboardInterrupt, EOFError):
print("\n👋 再见!")
break
if not user_input:
continue
if user_input.lower() == "quit":
break
elif user_input.lower() == "show_memory":
show_memory(current_session)
continue
elif user_input.lower().startswith("switch "):
parts = user_input.split(" ", 1)
new_session = parts[1].strip() if len(parts) > 1 else ""
if not new_session:
print("⚠️ 用法: switch <session_id>")
continue
current_session = new_session
print(f"🔄 已切换到会话: '{current_session}'\n")
continue
response = with_message_history.invoke(
{"input": user_input},
config={"configurable": {"session_id": current_session}}
)
print(f"🤖 Agent: {response.content}\n")
✅ .env 示例
DASHSCOPE_API_KEY=your_dashscope_api_key_here
MYSQL_HOST=192.168.1.100
MYSQL_PORT=3306
MYSQL_USER=myuser
MYSQL_PASSWORD=mypassword
MYSQL_DATABASE=langchain_chat
💡 即使
langchain_chat数据库不存在,程序也会自动创建它!
✅ 安装依赖
pip install langchain-openai python-dotenv pymysql
✅ 安全提示
- 自动建库需要 MySQL 用户有
CREATE权限。 - 在生产环境,建议:
- 由 DBA 创建数据库;
- 应用使用权限受限的账号(仅
SELECT,INSERT,DELETE); - 关闭自动建库逻辑(注释掉
CREATE DATABASE部分)。
浙公网安备 33010602011771号