• 博客园logo
  • 会员
  • 周边
  • 新闻
  • 博问
  • 闪存
  • 众包
  • 赞助商
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
记得承诺过
博客园    首页    新随笔    联系   管理    订阅  订阅

MCP server client交互,MCP tools鉴权

####################################################Client#################################################################################
# -*- coding: utf-8 -*-
"""
MCP用户认证最佳实践 - 客户端实现
为工具调用自动注入用户上下文
"""

import asyncio
import json
import logging
import os
from typing import Dict, List, Any, Optional

from dotenv import load_dotenv
from langgraph.prebuilt import create_react_agent
from langchain_deepseek import ChatDeepSeek
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import InMemorySaver
 
# 设置记忆存储
checkpointer = InMemorySaver()

# 读取提示词
with open("agent_prompts.txt", "r", encoding="utf-8") as f:
    prompt = f.read()


# ============ 用户会话管理 ============

class UserSession:
    """用户会话管理器"""
   
    def __init__(self, user_id: str, user_info: Dict[str, Any]):
        self.user_id = user_id
        self.user_info = user_info
        self.thread_id = f"thread_{user_id}"
       
        # 创建该用户专属的config
        self.config = {
            "configurable": {
                "thread_id": self.thread_id
            }
        }
   
    def __str__(self):
        return f"UserSession(id={self.user_id}, name={self.user_info['name']})"


class AuthManager:
    """认证管理器"""
   
    # 模拟用户数据库
    USERS = {
        "user_001": {"name": "管理员", "role": "admin", "token": "admin_token_123"},
        "user_002": {"name": "张三", "role": "user", "token": "user_token_456"},
        "user_003": {"name": "李四", "role": "guest", "token": "guest_token_789"}
    }
   
    @classmethod
    def authenticate(cls, username: str) -> Optional[UserSession]:
        """通过用户名创建会话(实际应用中应验证token)"""
        user_info = cls.USERS.get(username)
        if not user_info:
            print(f"❌ 用户 '{username}' 不存在")
            return None
       
        return UserSession(username, user_info)
   
    @classmethod
    def list_users(cls):
        """列出所有可用用户"""
        return list(cls.USERS.keys())


# ============ 工具上下文包装器 ============

class ToolContextWrapper:
    """
    工具上下文包装器
    为每个工具调用自动注入用户ID
    """

    @staticmethod
    async def wrap_tools(tools: List[Any], user_id: str) -> List[Any]:
        """
        为工具列表添加用户上下文

        Args:
            tools: 原始工具列表
            user_id: 用户ID

        Returns:
            包装后的工具列表
        """
        from langchain_core.tools import StructuredTool
        from pydantic import BaseModel, create_model

        wrapped_tools = []

        for tool in tools:
            # 保存原始的 ainvoke 方法
            original_ainvoke = tool.ainvoke
            original_name = tool.name
            original_description = tool.description

            # 获取原始 schema
            original_schema = getattr(tool, 'args_schema', None)

            # 创建包装函数 - 直接注入 user_id,不暴露给大模型
            async def wrapped_ainvoke(*args, **kwargs):
                # 提取 config 参数(如果有)
                config = kwargs.pop('config', None)

                # 确保 kwargs 包含 user_id
                kwargs['user_id'] = user_id

                # 调用原始的 ainvoke
                return await original_ainvoke(kwargs, config)

            # 使用 StructuredTool.from_function 创建新工具
            # 不修改 schema,让大模型按原样调用,但在内部注入 user_id
            new_tool = StructuredTool.from_function(
                func=wrapped_ainvoke,
                coroutine=wrapped_ainvoke,  # 显式指定为协程
                name=original_name,
                description=original_description,
                args_schema=original_schema,  # 保持原始 schema
                return_direct=getattr(tool, 'return_direct', False),
                verbose=getattr(tool, 'verbose', False)
            )

            wrapped_tools.append(new_tool)

        return wrapped_tools

# ============ 环境配置 ============

class Configuration:
    """配置管理器"""
   
    def __init__(self) -> None:
        load_dotenv()
        self.api_key: str = os.getenv("OPENAI_API_KEY", "")
        self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
       
        if not self.base_url:
            raise ValueError("❌ 未找到 BASE_URL,请在 .env 中配置")
        if not self.api_key:
            raise ValueError("❌ 未找到 API_KEY,请在 .env 中配置")
   
    @staticmethod
    def load_servers(file_path: str = "servers_config_secure.json") -> Dict[str, Any]:
        """加载MCP服务器配置"""
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f).get("mcpServers", {})


# ============ Agent管理器 ============

class SecureAgent:
    """带认证的Agent管理器"""
   
    def __init__(self):
        self.mcp_client: Optional[MultiServerMCPClient] = None
        self.model = None
        self.current_session: Optional[UserSession] = None
        self.tools: List[Any] = []
   
    async def initialize(self, config_file: str = "servers_config_secure.json"):
        """初始化Agent"""
        cfg = Configuration()
       
        # 设置环境变量
        os.environ["DEEPSEEK_API_KEY"] = cfg.api_key
        if cfg.base_url:
            os.environ["DEEPSEEK_API_BASE"] = cfg.base_url
       
        # 加载MCP服务器配置
        servers_cfg = Configuration.load_servers(config_file)
       
        # 创建MCP客户端
        self.mcp_client = MultiServerMCPClient(servers_cfg)
       
        # 初始化模型
        #self.model = ChatDeepSeek(model="deepseek-v3.2")
        self.model = ChatDeepSeek(model="qwen3-max-preview")
       
       
        print(f"✅ Agent初始化完成")
   
    async def create_session(self, username: str) -> bool:
        """创建用户会话"""
        session = AuthManager.authenticate(username)
        if not session:
            return False

        self.current_session = session

        # 获取工具并注入用户上下文
        self.tools = await self.mcp_client.get_tools()
        self.tools = await ToolContextWrapper.wrap_tools(self.tools, username)

        print(f"✅ 用户会话创建成功: {session}")
        print(f"🔧 已加载 {len(self.tools)} 个工具")
        for tool in self.tools:
            print(f"   - {tool.name}: {tool.description[:50] if tool.description else ''}...")

        return True
   
    async def chat(self, user_input: str) -> str:
        """执行对话"""
        if not self.current_session:
            raise RuntimeError("未创建用户会话")
       
        try:
            # 创建Agent
            agent = create_react_agent(
                model=self.model,
                tools=self.tools,
                prompt=prompt,
                checkpointer=checkpointer,
                debug=False
            )
           
            # 调用Agent
            result = await agent.ainvoke(
                {"messages": [{"role": "user", "content": user_input}]},
                self.current_session.config
            )
           
            # 返回最终回答
            return result['messages'][-1].content
           
        except PermissionError as e:
            return f"🚫 权限拒绝: {e}"
        except Exception as e:
            logging.error(f"对话错误: {e}")
            return f"⚠️  执行出错: {e}"
   
    def get_context_info(self) -> Dict[str, Any]:
        """获取当前上下文信息"""
        if not self.current_session:
            return {}
       
        return {
            "user_id": self.current_session.user_id,
            "user_name": self.current_session.user_info["name"],
            "user_role": self.current_session.user_info["role"],
            "thread_id": self.current_session.thread_id,
            "tools_count": len(self.tools)
        }
   
    async def cleanup(self):
        """清理资源"""
        if self.mcp_client:
            await self.mcp_client.cleanup()
        print("🧹 资源已清理")


# ============ 主程序 ============

async def interactive_chat():
    """交互式对话主程序"""
    agent = SecureAgent()
   
    print("\n" + "=" * 70)
    print("🔐 带认证的 MCP Agent")
    print("=" * 70)
   
    # 初始化Agent
    print("\n📡 正在连接MCP服务器...")
    await agent.initialize("servers_config_secure.json")
   
    # 用户登录
    print("\n📋 可用用户:")
    for user_id in AuthManager.list_users():
        user_info = AuthManager.USERS[user_id]
        print(f"  - {user_id}: {user_info['name']} ({user_info['role']})")
   
    username = input("\n请输入用户名: ").strip()
    if not await agent.create_session(username):
        print("❌ 登录失败")
        await agent.cleanup()
        return
   
    # 显示上下文信息
    ctx = agent.get_context_info()
    print(f"\n✅ 登录成功!")
    print(f"   用户: {ctx['user_name']} ({ctx['user_role']})")
    print(f"   会话ID: {ctx['thread_id']}")
    print(f"   可用工具: {ctx['tools_count']} 个")
   
    # 对话循环
    print(f"\n{'=' * 70}")
    print("💬 对话开始(输入 'quit' 退出,'info' 查看上下文)")
    print(f"{'=' * 70}\n")
   
    while True:
        try:
            user_input = input(f"[{ctx['user_name']}] > ").strip()
           
            if not user_input:
                continue
           
            if user_input.lower() == "quit":
                break
           
            if user_input.lower() == "info":
                info = agent.get_context_info()
                print(f"\n📊 当前上下文:")
                print(f"  用户ID: {info['user_id']}")
                print(f"  用户名: {info['user_name']}")
                print(f"  角色: {info['user_role']}")
                print(f"  会话ID: {info['thread_id']}")
                print(f"  工具数: {info['tools_count']}")
                continue
           
            # 显示调用前状态
            checkpoint = checkpointer.get(agent.current_session.config)
            if checkpoint and checkpoint.get("channel_values"):
                pre_messages = checkpoint["channel_values"].get("messages", [])
                print(f"📚 上下文消息数: {len(pre_messages)}")
           
            # 执行对话
            print("\n🤖 AI:", end=" ")
            response = await agent.chat(user_input)
            print(response)
           
            # 显示调用后状态
            checkpoint_after = checkpointer.get(agent.current_session.config)
            if checkpoint_after and checkpoint_after.get("channel_values"):
                post_messages = checkpoint_after["channel_values"].get("messages", [])
                if 'pre_messages' in locals():
                    print(f"📈 新增消息数: {len(post_messages) - len(pre_messages)}")
                else:
                    print(f"📈 总消息数: {len(post_messages)}")
           
            print()
           
        except KeyboardInterrupt:
            print("\n\n👋 再见!")
            break
        except Exception as e:
            print(f"\n❌ 错误: {e}")
            import traceback
            traceback.print_exc()
   
    # 清理
    await agent.cleanup()
    print("\n🧹 程序退出")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
   
    asyncio.run(interactive_chat())
 
 
 
################################################MCPServer#################################################################################################
# -*- coding: utf-8 -*-
"""
MCP用户认证最佳实践 - 服务端实现
基于请求上下文的轻量级认证方案
"""

from fastmcp import FastMCP
from typing import Dict, List, Optional, Any
from functools import wraps
import logging
import json
import time

# 创建MCP服务器实例
mcp = FastMCP("Secure Flight Service", port=3003)

# ============ 认证和权限系统 ============

class AuthManager:
    """轻量级认证管理器"""
   
    # 模拟用户数据库(实际项目中应替换为数据库)
    USERS = {
        "user_001": {
            "name": "管理员",
            "role": "admin",
            "permissions": ["*"]  # 所有权限
        },
        "user_002": {
            "name": "张三",
            "role": "user",
            "permissions": [
                "flight:read",
              #  "flight:book",
                "hotel:read"
            ]
        },
        "user_003": {
            "name": "李四",
            "role": "guest",
            "permissions": [
               # "flight:read"
            ]
        }
    }
   
    # 数据隔离:每个用户只能访问自己的数据
    USER_DATA = {
        "user_001": {
            "bookings": [
                {"id": "B001", "flight": "CA1234", "status": "confirmed"},
                {"id": "B002", "flight": "MU5678", "status": "pending"}
            ]
        },
        "user_002": {
            "bookings": [
                {"id": "B101", "flight": "CA4321", "status": "confirmed"}
            ]
        },
        "user_003": {
            "bookings": []
        }
    }
   
    @classmethod
    def get_user(cls, user_id: str) -> Optional[Dict]:
        """获取用户信息"""
        return cls.USERS.get(user_id)
   
    @classmethod
    def check_permission(cls, user_id: str, required_permission: str) -> bool:
        """检查用户权限"""
        user = cls.get_user(user_id)
        if not user:
            return False
       
        # 管理员拥有所有权限
        if "*" in user["permissions"]:
            return True
       
        # 检查具体权限
        return required_permission in user["permissions"]
   
    @classmethod
    def get_user_data(cls, user_id: str, data_type: str = "bookings") -> List[Dict]:
        """获取用户数据(自动过滤)"""
        user_data = cls.USER_DATA.get(user_id, {})
        return user_data.get(data_type, [])


# ============ 请求上下文管理 ============

class RequestContext:
    """
    请求上下文管理器
    在每次工具调用时设置和获取当前用户
    """
   
    # 使用contextvars实现线程安全的上下文(Python 3.7+)
    from contextvars import ContextVar
    _current_user: ContextVar[Optional[str]] = ContextVar('current_user', default=None)
    _request_id: ContextVar[Optional[str]] = ContextVar('request_id', default=None)
   
    @classmethod
    def set_user(cls, user_id: str):
        """设置当前请求的用户ID"""
        cls._current_user.set(user_id)
        cls._request_id.set(f"{user_id}_{int(time.time())}")
   
    @classmethod
    def get_user(cls) -> Optional[str]:
        """获取当前请求的用户ID"""
        return cls._current_user.get()
   
    @classmethod
    def get_request_id(cls) -> str:
        """获取请求ID用于日志追踪"""
        return cls._request_id.get()
   
    @classmethod
    def clear(cls):
        """清除上下文"""
        cls._current_user.set(None)
        cls._request_id.set(None)


# ============ 权限装饰器 ============

def require_permission(permission: str):
    """
    权限检查装饰器
   
    使用示例:
    @require_permission("flight:read")
    def search_flights(...):
        pass
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 从kwargs中提取user_id(由客户端通过MCP协议传递)
            user_id = kwargs.get('user_id', None)

            if not user_id:
                raise PermissionError(
                    "未提供用户身份信息,请确保在调用工具时传递user_id参数"
                )

            # 设置请求上下文(这样后续可以通过RequestContext.get_user()获取)
            RequestContext.set_user(user_id)

            # 检查权限
            if not AuthManager.check_permission(user_id, permission):
                user_info = AuthManager.get_user(user_id)
                raise PermissionError(
                    f"用户 '{user_info['name']}' (ID: {user_id}) "
                    f"没有权限执行此操作。需要权限: {permission}"
                )

            # 记录操作日志
            request_id = RequestContext.get_request_id()
            logging.info(
                f"[{request_id}] 用户 {user_id} 调用工具 {func.__name__} "
                f"(权限: {permission})"
            )

            # 执行函数(user_id 仍然在 kwargs 中)
            result = func(*args, **kwargs)

            # 清理上下文
            RequestContext.clear()

            return result

        return wrapper
    return decorator


# ============ 数据访问控制 ============

def require_user_data_access(func):
    """
    数据访问控制装饰器
    自动为查询添加用户过滤条件
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = kwargs.get('user_id')
        if user_id:
            kwargs['filter_by_user'] = user_id
        return func(*args, **kwargs)
    return wrapper


# ============ MCP工具定义 ============

@mcp.tool()
@require_permission("flight:read")
def search_flights(
    user_id : str ="",
    departure_list: List[str] = ["北京"],
    arrival_list: List[str] = ["上海"],
   
   
):
    """
    当用户查询北京到上海的航班时调用此方法
    搜索航班信息(需要flight:read权限)
   
    Args:
        user_id:用户ID 掉接口必须传入
        departure_list: 出发城市列表
        arrival_list: 到达城市列表
       
       
    Returns:
        str: 航班查询结果
    """
    #user_id = RequestContext.get_user()
    #user_info = AuthManager.get_user(user_id)
    user_info=""
    # 模拟查询(实际项目中连接数据库)
    flights_data = [
        {"id": "CA1234", "departure": "北京", "arrival": "上海",
         "duration": 2.1, "price": 1200, "airline": "中国国航"},
        {"id": "MU5678", "departure": "北京", "arrival": "上海",
         "duration": 2.3, "price": 1150, "airline": "东方航空"},
        {"id": "CZ3456", "departure": "北京", "arrival": "上海",
         "duration": 2.5, "price": 1080, "airline": "南方航空"}
    ]
   
    # 根据用户权限返回不同详细程度的数据
    #result = f"\n=== 航班查询结果 (用户: {user_info['name']}) ===\n"
    result = f"\n=== 航班查询结果  ===\n"+user_id
    result += "| 航班号 | 出发 | 到达 | 时长 | 价格 | 航空公司 |\n"
    result += "|--------|------|------|------|------|----------|\n"
   
    for flight in flights_data :
        result += f"| {flight['id']} | {flight['departure']} | {flight['arrival']} | "
        result += f"{flight['duration']}h | ¥{flight['price']} | {flight['airline']} |\n"
   
    return result
 

# ============ 审计日志 ============

class AuditLogger:
    """审计日志记录器"""
   
    @staticmethod
    def log_operation(user_id: str, operation: str, result: str, error: Any = None):
        """记录操作日志"""
        log_entry = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "user_id": user_id,
            "operation": operation,
            "success": error is None,
            "error": str(error) if error else None,
            "request_id": RequestContext.get_request_id()
        }
       
        # 记录到日志文件
        logging.info(f"AUDIT: {json.dumps(log_entry, ensure_ascii=False)}")
       
        # 实际项目中可以写入数据库
        # db.insert("audit_logs", log_entry)


# ============ 启动服务器 ============

if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
   
    print("\n" + "=" * 60)
    print("🚀 启动带认证的MCP服务器")
    print("=" * 60)
    print("\n📋 可用用户:")
    for user_id, info in AuthManager.USERS.items():
        print(f"  - {user_id}: {info['name']} ({info['role']})")
        print(f"    权限: {', '.join(info['permissions'])}")
   
    print("\n📝 说明:")
    print("  - 在调用工具时需要传递 user_id 参数")
    print("  - 系统会自动检查用户权限")
    print("  - 所有操作都会记录审计日志")
    print("\n" + "=" * 60 + "\n")
   
    mcp.run(transport="http", port=3003, host="0.0.0.0", path="/mcp_secure")
 
 
posted @ 2026-01-13 14:28  记得承诺过  阅读(0)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2026
浙公网安备 33010602011771号 浙ICP备2021040463号-3