MCP server client交互,MCP tools鉴权
####################################################Client#################################################################################
# -*- coding: utf-8 -*-
"""
MCP用户认证最佳实践 - 客户端实现
为工具调用自动注入用户上下文
"""
import asyncio
import json
import logging
import os
from typing import Dict, List, Any, Optional
from dotenv import load_dotenv
from langgraph.prebuilt import create_react_agent
from langchain_deepseek import ChatDeepSeek
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import InMemorySaver
# 设置记忆存储
checkpointer = InMemorySaver()
# 读取提示词
with open("agent_prompts.txt", "r", encoding="utf-8") as f:
prompt = f.read()
# ============ 用户会话管理 ============
class UserSession:
"""用户会话管理器"""
def __init__(self, user_id: str, user_info: Dict[str, Any]):
self.user_id = user_id
self.user_info = user_info
self.thread_id = f"thread_{user_id}"
# 创建该用户专属的config
self.config = {
"configurable": {
"thread_id": self.thread_id
}
}
def __str__(self):
return f"UserSession(id={self.user_id}, name={self.user_info['name']})"
class AuthManager:
"""认证管理器"""
# 模拟用户数据库
USERS = {
"user_001": {"name": "管理员", "role": "admin", "token": "admin_token_123"},
"user_002": {"name": "张三", "role": "user", "token": "user_token_456"},
"user_003": {"name": "李四", "role": "guest", "token": "guest_token_789"}
}
@classmethod
def authenticate(cls, username: str) -> Optional[UserSession]:
"""通过用户名创建会话(实际应用中应验证token)"""
user_info = cls.USERS.get(username)
if not user_info:
print(f"❌ 用户 '{username}' 不存在")
return None
return UserSession(username, user_info)
@classmethod
def list_users(cls):
"""列出所有可用用户"""
return list(cls.USERS.keys())
# ============ 工具上下文包装器 ============
class ToolContextWrapper:
"""
工具上下文包装器
为每个工具调用自动注入用户ID
"""
@staticmethod
async def wrap_tools(tools: List[Any], user_id: str) -> List[Any]:
"""
为工具列表添加用户上下文
Args:
tools: 原始工具列表
user_id: 用户ID
Returns:
包装后的工具列表
"""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, create_model
wrapped_tools = []
for tool in tools:
# 保存原始的 ainvoke 方法
original_ainvoke = tool.ainvoke
original_name = tool.name
original_description = tool.description
# 获取原始 schema
original_schema = getattr(tool, 'args_schema', None)
# 创建包装函数 - 直接注入 user_id,不暴露给大模型
async def wrapped_ainvoke(*args, **kwargs):
# 提取 config 参数(如果有)
config = kwargs.pop('config', None)
# 确保 kwargs 包含 user_id
kwargs['user_id'] = user_id
# 调用原始的 ainvoke
return await original_ainvoke(kwargs, config)
# 使用 StructuredTool.from_function 创建新工具
# 不修改 schema,让大模型按原样调用,但在内部注入 user_id
new_tool = StructuredTool.from_function(
func=wrapped_ainvoke,
coroutine=wrapped_ainvoke, # 显式指定为协程
name=original_name,
description=original_description,
args_schema=original_schema, # 保持原始 schema
return_direct=getattr(tool, 'return_direct', False),
verbose=getattr(tool, 'verbose', False)
)
wrapped_tools.append(new_tool)
return wrapped_tools
# ============ 环境配置 ============
class Configuration:
"""配置管理器"""
def __init__(self) -> None:
load_dotenv()
self.api_key: str = os.getenv("OPENAI_API_KEY", "")
self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
if not self.base_url:
raise ValueError("❌ 未找到 BASE_URL,请在 .env 中配置")
if not self.api_key:
raise ValueError("❌ 未找到 API_KEY,请在 .env 中配置")
@staticmethod
def load_servers(file_path: str = "servers_config_secure.json") -> Dict[str, Any]:
"""加载MCP服务器配置"""
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f).get("mcpServers", {})
# ============ Agent管理器 ============
class SecureAgent:
"""带认证的Agent管理器"""
def __init__(self):
self.mcp_client: Optional[MultiServerMCPClient] = None
self.model = None
self.current_session: Optional[UserSession] = None
self.tools: List[Any] = []
async def initialize(self, config_file: str = "servers_config_secure.json"):
"""初始化Agent"""
cfg = Configuration()
# 设置环境变量
os.environ["DEEPSEEK_API_KEY"] = cfg.api_key
if cfg.base_url:
os.environ["DEEPSEEK_API_BASE"] = cfg.base_url
# 加载MCP服务器配置
servers_cfg = Configuration.load_servers(config_file)
# 创建MCP客户端
self.mcp_client = MultiServerMCPClient(servers_cfg)
# 初始化模型
#self.model = ChatDeepSeek(model="deepseek-v3.2")
self.model = ChatDeepSeek(model="qwen3-max-preview")
print(f"✅ Agent初始化完成")
async def create_session(self, username: str) -> bool:
"""创建用户会话"""
session = AuthManager.authenticate(username)
if not session:
return False
self.current_session = session
# 获取工具并注入用户上下文
self.tools = await self.mcp_client.get_tools()
self.tools = await ToolContextWrapper.wrap_tools(self.tools, username)
print(f"✅ 用户会话创建成功: {session}")
print(f"🔧 已加载 {len(self.tools)} 个工具")
for tool in self.tools:
print(f" - {tool.name}: {tool.description[:50] if tool.description else ''}...")
return True
async def chat(self, user_input: str) -> str:
"""执行对话"""
if not self.current_session:
raise RuntimeError("未创建用户会话")
try:
# 创建Agent
agent = create_react_agent(
model=self.model,
tools=self.tools,
prompt=prompt,
checkpointer=checkpointer,
debug=False
)
# 调用Agent
result = await agent.ainvoke(
{"messages": [{"role": "user", "content": user_input}]},
self.current_session.config
)
# 返回最终回答
return result['messages'][-1].content
except PermissionError as e:
return f"🚫 权限拒绝: {e}"
except Exception as e:
logging.error(f"对话错误: {e}")
return f"⚠️ 执行出错: {e}"
def get_context_info(self) -> Dict[str, Any]:
"""获取当前上下文信息"""
if not self.current_session:
return {}
return {
"user_id": self.current_session.user_id,
"user_name": self.current_session.user_info["name"],
"user_role": self.current_session.user_info["role"],
"thread_id": self.current_session.thread_id,
"tools_count": len(self.tools)
}
async def cleanup(self):
"""清理资源"""
if self.mcp_client:
await self.mcp_client.cleanup()
print("🧹 资源已清理")
# ============ 主程序 ============
async def interactive_chat():
"""交互式对话主程序"""
agent = SecureAgent()
print("\n" + "=" * 70)
print("🔐 带认证的 MCP Agent")
print("=" * 70)
# 初始化Agent
print("\n📡 正在连接MCP服务器...")
await agent.initialize("servers_config_secure.json")
# 用户登录
print("\n📋 可用用户:")
for user_id in AuthManager.list_users():
user_info = AuthManager.USERS[user_id]
print(f" - {user_id}: {user_info['name']} ({user_info['role']})")
username = input("\n请输入用户名: ").strip()
if not await agent.create_session(username):
print("❌ 登录失败")
await agent.cleanup()
return
# 显示上下文信息
ctx = agent.get_context_info()
print(f"\n✅ 登录成功!")
print(f" 用户: {ctx['user_name']} ({ctx['user_role']})")
print(f" 会话ID: {ctx['thread_id']}")
print(f" 可用工具: {ctx['tools_count']} 个")
# 对话循环
print(f"\n{'=' * 70}")
print("💬 对话开始(输入 'quit' 退出,'info' 查看上下文)")
print(f"{'=' * 70}\n")
while True:
try:
user_input = input(f"[{ctx['user_name']}] > ").strip()
if not user_input:
continue
if user_input.lower() == "quit":
break
if user_input.lower() == "info":
info = agent.get_context_info()
print(f"\n📊 当前上下文:")
print(f" 用户ID: {info['user_id']}")
print(f" 用户名: {info['user_name']}")
print(f" 角色: {info['user_role']}")
print(f" 会话ID: {info['thread_id']}")
print(f" 工具数: {info['tools_count']}")
continue
# 显示调用前状态
checkpoint = checkpointer.get(agent.current_session.config)
if checkpoint and checkpoint.get("channel_values"):
pre_messages = checkpoint["channel_values"].get("messages", [])
print(f"📚 上下文消息数: {len(pre_messages)}")
# 执行对话
print("\n🤖 AI:", end=" ")
response = await agent.chat(user_input)
print(response)
# 显示调用后状态
checkpoint_after = checkpointer.get(agent.current_session.config)
if checkpoint_after and checkpoint_after.get("channel_values"):
post_messages = checkpoint_after["channel_values"].get("messages", [])
if 'pre_messages' in locals():
print(f"📈 新增消息数: {len(post_messages) - len(pre_messages)}")
else:
print(f"📈 总消息数: {len(post_messages)}")
print()
except KeyboardInterrupt:
print("\n\n👋 再见!")
break
except Exception as e:
print(f"\n❌ 错误: {e}")
import traceback
traceback.print_exc()
# 清理
await agent.cleanup()
print("\n🧹 程序退出")
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
asyncio.run(interactive_chat())
################################################MCPServer#################################################################################################
# -*- coding: utf-8 -*-
"""
MCP用户认证最佳实践 - 服务端实现
基于请求上下文的轻量级认证方案
"""
from fastmcp import FastMCP
from typing import Dict, List, Optional, Any
from functools import wraps
import logging
import json
import time
# 创建MCP服务器实例
mcp = FastMCP("Secure Flight Service", port=3003)
# ============ 认证和权限系统 ============
class AuthManager:
"""轻量级认证管理器"""
# 模拟用户数据库(实际项目中应替换为数据库)
USERS = {
"user_001": {
"name": "管理员",
"role": "admin",
"permissions": ["*"] # 所有权限
},
"user_002": {
"name": "张三",
"role": "user",
"permissions": [
"flight:read",
# "flight:book",
"hotel:read"
]
},
"user_003": {
"name": "李四",
"role": "guest",
"permissions": [
# "flight:read"
]
}
}
# 数据隔离:每个用户只能访问自己的数据
USER_DATA = {
"user_001": {
"bookings": [
{"id": "B001", "flight": "CA1234", "status": "confirmed"},
{"id": "B002", "flight": "MU5678", "status": "pending"}
]
},
"user_002": {
"bookings": [
{"id": "B101", "flight": "CA4321", "status": "confirmed"}
]
},
"user_003": {
"bookings": []
}
}
@classmethod
def get_user(cls, user_id: str) -> Optional[Dict]:
"""获取用户信息"""
return cls.USERS.get(user_id)
@classmethod
def check_permission(cls, user_id: str, required_permission: str) -> bool:
"""检查用户权限"""
user = cls.get_user(user_id)
if not user:
return False
# 管理员拥有所有权限
if "*" in user["permissions"]:
return True
# 检查具体权限
return required_permission in user["permissions"]
@classmethod
def get_user_data(cls, user_id: str, data_type: str = "bookings") -> List[Dict]:
"""获取用户数据(自动过滤)"""
user_data = cls.USER_DATA.get(user_id, {})
return user_data.get(data_type, [])
# ============ 请求上下文管理 ============
class RequestContext:
"""
请求上下文管理器
在每次工具调用时设置和获取当前用户
"""
# 使用contextvars实现线程安全的上下文(Python 3.7+)
from contextvars import ContextVar
_current_user: ContextVar[Optional[str]] = ContextVar('current_user', default=None)
_request_id: ContextVar[Optional[str]] = ContextVar('request_id', default=None)
@classmethod
def set_user(cls, user_id: str):
"""设置当前请求的用户ID"""
cls._current_user.set(user_id)
cls._request_id.set(f"{user_id}_{int(time.time())}")
@classmethod
def get_user(cls) -> Optional[str]:
"""获取当前请求的用户ID"""
return cls._current_user.get()
@classmethod
def get_request_id(cls) -> str:
"""获取请求ID用于日志追踪"""
return cls._request_id.get()
@classmethod
def clear(cls):
"""清除上下文"""
cls._current_user.set(None)
cls._request_id.set(None)
# ============ 权限装饰器 ============
def require_permission(permission: str):
"""
权限检查装饰器
使用示例:
@require_permission("flight:read")
def search_flights(...):
pass
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 从kwargs中提取user_id(由客户端通过MCP协议传递)
user_id = kwargs.get('user_id', None)
if not user_id:
raise PermissionError(
"未提供用户身份信息,请确保在调用工具时传递user_id参数"
)
# 设置请求上下文(这样后续可以通过RequestContext.get_user()获取)
RequestContext.set_user(user_id)
# 检查权限
if not AuthManager.check_permission(user_id, permission):
user_info = AuthManager.get_user(user_id)
raise PermissionError(
f"用户 '{user_info['name']}' (ID: {user_id}) "
f"没有权限执行此操作。需要权限: {permission}"
)
# 记录操作日志
request_id = RequestContext.get_request_id()
logging.info(
f"[{request_id}] 用户 {user_id} 调用工具 {func.__name__} "
f"(权限: {permission})"
)
# 执行函数(user_id 仍然在 kwargs 中)
result = func(*args, **kwargs)
# 清理上下文
RequestContext.clear()
return result
return wrapper
return decorator
# ============ 数据访问控制 ============
def require_user_data_access(func):
"""
数据访问控制装饰器
自动为查询添加用户过滤条件
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = kwargs.get('user_id')
if user_id:
kwargs['filter_by_user'] = user_id
return func(*args, **kwargs)
return wrapper
# ============ MCP工具定义 ============
@mcp.tool()
@require_permission("flight:read")
def search_flights(
user_id : str ="",
departure_list: List[str] = ["北京"],
arrival_list: List[str] = ["上海"],
):
"""
当用户查询北京到上海的航班时调用此方法
搜索航班信息(需要flight:read权限)
Args:
user_id:用户ID 掉接口必须传入
departure_list: 出发城市列表
arrival_list: 到达城市列表
Returns:
str: 航班查询结果
"""
#user_id = RequestContext.get_user()
#user_info = AuthManager.get_user(user_id)
user_info=""
# 模拟查询(实际项目中连接数据库)
flights_data = [
{"id": "CA1234", "departure": "北京", "arrival": "上海",
"duration": 2.1, "price": 1200, "airline": "中国国航"},
{"id": "MU5678", "departure": "北京", "arrival": "上海",
"duration": 2.3, "price": 1150, "airline": "东方航空"},
{"id": "CZ3456", "departure": "北京", "arrival": "上海",
"duration": 2.5, "price": 1080, "airline": "南方航空"}
]
# 根据用户权限返回不同详细程度的数据
#result = f"\n=== 航班查询结果 (用户: {user_info['name']}) ===\n"
result = f"\n=== 航班查询结果 ===\n"+user_id
result += "| 航班号 | 出发 | 到达 | 时长 | 价格 | 航空公司 |\n"
result += "|--------|------|------|------|------|----------|\n"
for flight in flights_data :
result += f"| {flight['id']} | {flight['departure']} | {flight['arrival']} | "
result += f"{flight['duration']}h | ¥{flight['price']} | {flight['airline']} |\n"
return result
# ============ 审计日志 ============
class AuditLogger:
"""审计日志记录器"""
@staticmethod
def log_operation(user_id: str, operation: str, result: str, error: Any = None):
"""记录操作日志"""
log_entry = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"user_id": user_id,
"operation": operation,
"success": error is None,
"error": str(error) if error else None,
"request_id": RequestContext.get_request_id()
}
# 记录到日志文件
logging.info(f"AUDIT: {json.dumps(log_entry, ensure_ascii=False)}")
# 实际项目中可以写入数据库
# db.insert("audit_logs", log_entry)
# ============ 启动服务器 ============
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
print("\n" + "=" * 60)
print("🚀 启动带认证的MCP服务器")
print("=" * 60)
print("\n📋 可用用户:")
for user_id, info in AuthManager.USERS.items():
print(f" - {user_id}: {info['name']} ({info['role']})")
print(f" 权限: {', '.join(info['permissions'])}")
print("\n📝 说明:")
print(" - 在调用工具时需要传递 user_id 参数")
print(" - 系统会自动检查用户权限")
print(" - 所有操作都会记录审计日志")
print("\n" + "=" * 60 + "\n")
mcp.run(transport="http", port=3003, host="0.0.0.0", path="/mcp_secure")
浙公网安备 33010602011771号