Python FastAPI WebSocket 连接示例

1. 服务端示例:

# -*- coding: utf-8 -*-
import asyncio
import traceback
import json
import uuid
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from typing import Dict, List, Optional
import uvicorn
from redis.asyncio import from_url, Redis

# ============= Basic configuration =============
REDIS_URL = "redis://localhost:6379/0"
# Redis pub/sub channel that every server instance subscribes and publishes to.
CHANNEL_NAME = "websocket_messages"
# Random ID generated once per process; attached to published messages so a
# server can recognise and skip its own messages when Redis echoes them back.
SERVER_ID = str(uuid.uuid4())

print(f"当前服务器ID: {SERVER_ID}")

app = FastAPI(title="分布式WebSocket服务")


# ============= Utility helpers =============
def log(msg):
    """Print *msg* to stdout, prefixed with the first 8 characters of SERVER_ID."""
    server_tag = SERVER_ID[:8]
    print("[Server {}] {}".format(server_tag, msg))


# ============= Core connection manager =============
class ConnectionManager:
    """Track local WebSocket connections and relay messages between servers.

    Each server process keeps its own in-memory map of connected users and
    room membership. Messages that may concern users connected to another
    server are published to the Redis channel ``CHANNEL_NAME``; every server
    subscribes to that channel and delivers incoming messages to its own
    local connections, skipping messages it published itself.
    """

    def __init__(self):
        # user_id -> WebSocket accepted by this server process.
        self.active_connections: Dict[str, WebSocket] = {}
        # room_id -> user_ids that joined the room through this server.
        self.rooms: Dict[str, List[str]] = {}
        self.redis: Optional[Redis] = None
        self.pubsub = None
        # Background task running _subscribe_to_channel().
        self.subscription_task: Optional[asyncio.Task] = None

    async def initialize(self):
        """Create the Redis client and start the background subscription task."""
        self.redis = from_url(REDIS_URL, decode_responses=False)
        self.pubsub = self.redis.pubsub()
        self.subscription_task = asyncio.create_task(self._subscribe_to_channel())
        log("Redis 初始化完成")

    async def _subscribe_to_channel(self):
        """Subscribe to the Redis channel and dispatch every published message.

        Runs until cancelled. Any other exception is logged and ends the loop.
        """
        try:
            await self.pubsub.subscribe(CHANNEL_NAME)
            log(f"已订阅 Redis 频道: {CHANNEL_NAME}")

            async for message in self.pubsub.listen():
                # listen() also yields subscribe confirmations; only real
                # payloads carry the "message" type.
                if message["type"] == "message":
                    await self._handle_external_message(message["data"])
        except Exception as e:
            log(f"Redis订阅异常: {e}\n{traceback.format_exc()}")

    async def _handle_external_message(self, data: bytes):
        """Deliver a message received from Redis to the matching local clients.

        Messages tagged with this server's own ``SERVER_ID`` are ignored
        because they were already delivered locally when they were sent.
        Errors are logged rather than raised so that one bad payload does not
        stop the subscription loop.
        """
        try:
            # json.loads accepts bytes as well as str, so no explicit decode
            # is needed regardless of the client's decode_responses setting.
            msg = json.loads(data)
            if msg.get("server_id") == SERVER_ID:
                return  # skip messages published by this server

            msg_type = msg["type"]
            if msg_type == "broadcast":
                await self._broadcast_local(msg["message"])
            elif msg_type == "personal_message":
                await self.send_personal_message(msg["message"], msg["user_id"], remote=True)
            elif msg_type == "room_message":
                await self._send_room_local(msg["message"], msg["room_id"])
        except Exception as e:
            log(f"处理外部消息异常: {e}\n{traceback.format_exc()}")

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept *websocket* and register it under *user_id*.

        If the Redis subscription has not been set up yet (for example the
        startup hook did not run), it is initialised here through
        ``initialize()`` so that the listener task is started as well.
        """
        await websocket.accept()
        self.active_connections[user_id] = websocket
        log(f"用户 {user_id} 已连接")
        if not self.pubsub:
            # Reuse initialize() instead of subscribing by hand: a bare
            # subscribe() without the listener task would never deliver
            # messages coming from other servers.
            await self.initialize()

    async def disconnect(self, user_id: str):
        """Forget *user_id*, drop it from all rooms and announce the departure."""
        self.active_connections.pop(user_id, None)
        # Iterate over a copy of the keys because empty rooms are deleted.
        for room_id in list(self.rooms):
            members = self.rooms[room_id]
            if user_id in members:
                members.remove(user_id)
                if not members:
                    del self.rooms[room_id]
        # broadcast() delivers locally AND publishes to Redis. Publishing
        # alone would never reach clients on this server, because messages
        # carrying our own SERVER_ID are skipped on receipt.
        await self.broadcast(f"用户 {user_id} 离开了聊天室")
        log(f"用户 {user_id} 已断开")

    async def send_personal_message(self, message: str, user_id: str, remote: bool = False):
        """Send *message* to a single user.

        If the user is connected to this server the message is sent directly.
        Otherwise, unless *remote* is true, it is published to Redis so the
        server holding the connection can deliver it. *remote* is set when the
        message itself came from Redis (or is local-only), which prevents it
        from being re-published.
        """
        if user_id in self.active_connections:
            await self.active_connections[user_id].send_text(message)
        elif not remote:
            await self._publish_message({
                "type": "personal_message",
                "message": message,
                "user_id": user_id,
                "server_id": SERVER_ID,
            })
            log(f"发送私聊消息给用户 {user_id}: {message}")

    async def broadcast(self, message: str):
        """Send *message* to every local client and publish it to the other servers."""
        await self._broadcast_local(message)
        await self._publish_message({
            "type": "broadcast",
            "message": message,
            "server_id": SERVER_ID,
        })

    async def _safe_send(self, user_id: str, websocket: WebSocket, message: str) -> bool:
        """Send *message* on *websocket*, logging instead of raising on failure.

        Returns True when the send succeeded. Used for fan-out so that one
        broken connection does not abort delivery to the remaining clients.
        """
        try:
            await websocket.send_text(message)
        except Exception as e:
            log(f"向用户 {user_id} 发送消息失败: {e}")
            return False
        return True

    async def _broadcast_local(self, message: str):
        """Send *message* to every client connected to this server only."""
        # Snapshot the mapping: other coroutines may connect or disconnect
        # while we are suspended in send_text(), and mutating a dict during
        # iteration raises RuntimeError.
        for user_id, ws in list(self.active_connections.items()):
            await self._safe_send(user_id, ws, message)

    async def join_room(self, user_id: str, room_id: str):
        """Add *user_id* to *room_id* and notify the room on all servers."""
        members = self.rooms.setdefault(room_id, [])
        if user_id not in members:
            members.append(user_id)
        notice = f"用户 {user_id} 加入了房间 {room_id}"
        await self._send_room_local(notice, room_id)
        await self._publish_message({
            "type": "room_message",
            "message": notice,
            "room_id": room_id,
            "server_id": SERVER_ID,
        })

    async def _send_room_local(self, message: str, room_id: str):
        """Send *message* to the members of *room_id* connected to this server."""
        # Snapshot the member list: disconnect() may remove entries while we
        # are suspended in send_text(), which would make the loop skip users.
        for uid in list(self.rooms.get(room_id, [])):
            ws = self.active_connections.get(uid)
            if ws is not None:
                await self._safe_send(uid, ws, message)

    async def _publish_message(self, msg: dict):
        """Publish *msg* as JSON to the Redis channel; a no-op without a client."""
        if self.redis:
            log(f"发布消息:{msg}")
            await self.redis.publish(CHANNEL_NAME, json.dumps(msg).encode())

    async def shutdown(self):
        """Stop the subscription task, unsubscribe and close the Redis client."""
        if self.subscription_task:
            self.subscription_task.cancel()
            # Wait for the listener to actually stop before touching the
            # pubsub connection it was reading from.
            try:
                await self.subscription_task
            except asyncio.CancelledError:
                pass
            self.subscription_task = None
        if self.pubsub:
            await self.pubsub.unsubscribe(CHANNEL_NAME)
        if self.redis:
            await self.redis.close()


# Module-level manager instance used by the route handlers in this file.
manager = ConnectionManager()


# ============= FastAPI routes =============
@app.on_event("startup")
async def startup():
    """Application startup hook: initialise the shared connection manager."""
    await manager.initialize()


@app.on_event("shutdown")
async def shutdown():
    """Application shutdown hook: shut down the shared connection manager."""
    await manager.shutdown()


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(ws: WebSocket, user_id: str):
    """Handle one client WebSocket connection identified by *user_id*.

    Text protocol understood by the receive loop:

    * ``join:<room_id>``            - join a room
    * ``room:<room_id>:<content>``  - send ``content`` to a room
    * anything else                 - broadcast to all connected users

    A malformed ``room:`` message is answered with an error text instead of
    terminating the connection. On disconnect or on any unexpected error the
    user is removed from the manager.
    """
    await manager.connect(ws, user_id)
    try:
        await ws.send_text(f"欢迎,用户 {user_id} 已连接!")

        while True:
            msg = await ws.receive_text()
            if msg.startswith("join:"):
                await manager.join_room(user_id, msg.split(":", 1)[1])
            elif msg.startswith("room:"):
                parts = msg.split(":", 2)
                if len(parts) != 3:
                    # e.g. "room:abc" has no content part; unpacking it would
                    # raise ValueError and drop the client's connection.
                    await ws.send_text("消息格式错误, 应为 room:<房间ID>:<内容>")
                    continue
                _, rid, content = parts
                room_text = f"[{rid}] {user_id}: {content}"
                await manager._send_room_local(room_text, rid)
                await manager._publish_message({
                    "type": "room_message",
                    "message": room_text,
                    "room_id": rid,
                    "server_id": SERVER_ID,
                })
            else:
                await manager.broadcast(f"{user_id}: {msg}")
    except WebSocketDisconnect:
        await manager.disconnect(user_id)
    except Exception as e:
        log(f"WebSocket异常: {e}\n{traceback.format_exc()}")
        await manager.disconnect(user_id)


@app.get("/")
async def get_page():
    """Serve a minimal HTML status page showing the WebSocket URL pattern."""
    page_html = (
        "<h3>WebSocket服务器运行中</h3>"
        "<p>使用 ws://localhost:8000/ws/{user_id} 连接</p>"
    )
    return HTMLResponse(page_html)


@app.get("/push/{user_id}/{msg}")
async def push_message(user_id: str, msg: str):
    """Push *msg* to *user_id* through the connection manager and echo the request."""
    await manager.send_personal_message(message=msg, user_id=user_id)
    response = {"status": "sent"}
    response["user_id"] = user_id
    response["message"] = msg
    return response


# Run the app with uvicorn when this file is executed directly,
# listening on all interfaces on port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

 

2. 客户端通过 Postman 进行连接

image

 

3. 通过接口向客户端推送消息,或在服务端代码中直接调用 send_personal_message

http://localhost:8000/push/oneday/hello_abc

await manager.send_personal_message(msg, user_id)

 

posted on 2025-10-11 18:08  星河赵  阅读(1)  评论(0)    收藏  举报

导航