# aiohttp usage notes — LLM client wrapper edition

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Depends
from fastapi.responses import StreamingResponse
import aiohttp
import asyncio
import random
import json
import logging
from typing import AsyncGenerator, Optional

# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)


class LLMHttpClient:
    """Pooled aiohttp client tuned for calling an LLM HTTP backend.

    Provides:
      * bounded in-flight concurrency via a semaphore,
      * connection pooling with per-host limits and DNS caching,
      * exponential backoff with jitter for transient failures,
      * JSON and SSE-style streaming call helpers,
      * per-request latency logging via aiohttp trace hooks.

    Call ``start()`` before use and ``close()`` on shutdown (see ``lifespan``).
    The client may be restarted: ``start()`` after ``close()`` builds a fresh
    session and connector.
    """

    # Upstream status codes treated as transient and retried with backoff.
    _RETRY_STATUSES = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        base_url: str = "",
        max_connections: int = 500,
        max_connections_per_host: int = 100,
        concurrency: int = 200,
        timeout_total: Optional[float] = None,
        timeout_connect: float = 10,
        timeout_read: float = 30,
        max_retries: int = 4,
    ):
        """Configure the client; no I/O happens until ``start()``.

        Args:
            base_url: Prefix prepended to every request path.
            max_connections: Total connection-pool size.
            max_connections_per_host: Pool size per (host, port) pair.
            concurrency: Maximum simultaneous in-flight calls.
            timeout_total: Whole-request deadline in seconds; ``None`` means
                unlimited, which is needed for long streaming responses.
            timeout_connect: Pool-acquire + TCP-connect timeout in seconds.
            timeout_read: Maximum gap between socket reads in seconds.
            max_retries: Total attempts (first try included).
        """
        self.base_url = base_url
        self.max_retries = max_retries
        self._sem = asyncio.Semaphore(concurrency)
        self.session: aiohttp.ClientSession | None = None

        self._timeout = aiohttp.ClientTimeout(
            total=timeout_total,
            connect=timeout_connect,
            sock_connect=timeout_connect,
            sock_read=timeout_read,
        )
        # Only the connector *settings* are stored here.  The TCPConnector
        # itself is created lazily in start(): ClientSession.close() also
        # closes the connector it owns, so an eagerly-built connector could
        # not survive a close()/start() cycle, and building it outside a
        # running event loop is fragile.
        self._connector_kwargs = dict(
            limit=max_connections,
            limit_per_host=max_connections_per_host,
            ttl_dns_cache=300,
            keepalive_timeout=30,
            enable_cleanup_closed=True,
        )
        self._trace_configs = [self._create_trace()]

    async def start(self) -> None:
        """Create the underlying session. Idempotent; safe after ``close()``."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                timeout=self._timeout,
                # Fresh connector per session; the session owns and closes it.
                connector=aiohttp.TCPConnector(**self._connector_kwargs),
                trace_configs=self._trace_configs,
            )

    async def close(self) -> None:
        """Close the session and its connector. ``start()`` may be called again."""
        if self.session and not self.session.closed:
            await self.session.close()

    def _create_trace(self) -> aiohttp.TraceConfig:
        """Build a TraceConfig that logs method, URL and latency per request."""
        trace = aiohttp.TraceConfig()

        async def on_start(session, ctx, params):
            # Stash the start timestamp on the per-request trace context.
            ctx.started_at = asyncio.get_running_loop().time()

        async def on_end(session, ctx, params):
            cost_ms = (asyncio.get_running_loop().time() - ctx.started_at) * 1000
            logger.info("llm_http %s %s %.1fms", params.method, params.url, cost_ms)

        # Explicit append (the `@list.append` decorator idiom rebinds the
        # handler name to None, which hurts readability/debugging).
        trace.on_request_start.append(on_start)
        trace.on_request_end.append(on_end)
        return trace

    async def _backoff(self, attempt: int) -> None:
        """Sleep with exponential backoff (base 0.5 s, cap 8 s) and ±20% jitter."""
        base = 0.5
        cap = 8.0
        sleep_s = min(cap, base * (2 ** (attempt - 1))) * random.uniform(0.8, 1.2)
        await asyncio.sleep(sleep_s)

    async def _request_with_retry(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Issue one request, retrying transient failures with backoff.

        Returns an *open* ClientResponse; the caller must release/close it.

        Raises:
            RuntimeError: a retryable status persisted through the last attempt.
            aiohttp.ClientError / asyncio.TimeoutError: network failure on the
                last attempt.
        """
        assert self.session is not None, "client not started"

        full_url = self.base_url + url
        last_exc: Exception | None = None

        for attempt in range(1, self.max_retries + 1):
            try:
                resp = await self.session.request(method, full_url, **kwargs)

                if resp.status in self._RETRY_STATUSES:
                    if attempt == self.max_retries:
                        # Capture a body snippet for diagnostics, then give up.
                        try:
                            body = await resp.text()
                        finally:
                            resp.close()
                        raise RuntimeError(f"upstream status={resp.status}, body={body[:500]}")
                    resp.close()
                    await self._backoff(attempt)
                    continue

                return resp

            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                last_exc = e
                if attempt == self.max_retries:
                    raise
                await self._backoff(attempt)

        raise last_exc or RuntimeError("request failed")

    async def call_json(self, method: str, url: str, json_body=None, headers=None) -> dict:
        """Send a request and return the JSON-decoded response body."""
        async with self._sem:
            resp = await self._request_with_retry(method, url, json=json_body, headers=headers)
            try:
                return await resp.json()
            finally:
                # release() (not close()): the body has been fully read, so
                # the connection can go back to the pool for keep-alive reuse.
                resp.release()

    async def call_stream(
        self,
        method: str,
        url: str,
        json_body=None,
        headers=None,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` payloads from a streaming endpoint.

        Stops (without yielding) at the ``[DONE]`` sentinel.  The concurrency
        semaphore is held for the whole lifetime of the stream.
        """
        async with self._sem:
            resp = await self._request_with_retry(method, url, json=json_body, headers=headers)
            try:
                buffer = b""
                async for chunk in resp.content.iter_chunked(1024):
                    buffer += chunk
                    while b"\n" in buffer:
                        line, buffer = buffer.split(b"\n", 1)
                        data = self._parse_sse_line(line)
                        if data is None:
                            continue
                        if data == "[DONE]":
                            return
                        yield data
                # Flush a final data line that arrived without a trailing
                # newline; previously such a tail was silently dropped.
                data = self._parse_sse_line(buffer)
                if data is not None and data != "[DONE]":
                    yield data
            finally:
                resp.close()

    @staticmethod
    def _parse_sse_line(line: bytes) -> Optional[str]:
        """Return the payload of an SSE ``data:`` line, or None for other lines."""
        text = line.decode("utf-8", "ignore").strip()
        if not text.startswith("data:"):
            return None
        return text[5:].strip()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the shared LLM client, tear it down on exit.

    The client is published on ``app.state.llm_client`` for dependencies.
    """
    llm = LLMHttpClient(base_url="https://api.openai.com")
    await llm.start()
    app.state.llm_client = llm
    try:
        yield
    finally:
        # Always close the pool, even if shutdown is triggered by an error.
        await llm.close()


# Single application instance; the lifespan hook owns the shared LLM client.
app = FastAPI(lifespan=lifespan)


def get_llm_client(request: Request) -> LLMHttpClient:
    """FastAPI dependency: return the shared client stored by the lifespan hook."""
    state = request.app.state
    return state.llm_client


@app.post("/chat")
async def chat(request: Request, client: LLMHttpClient = Depends(get_llm_client)):
    """Demo endpoint: proxy a hard-coded chat completion as an SSE stream."""
    body = {
        "model": "gpt-4.1-mini",
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    }
    # NOTE(review): placeholder bearer token; a real deployment must inject
    # the credential from configuration, never hard-code it.
    auth_headers = {"Authorization": "Bearer xxx"}

    async def sse_events():
        upstream = client.call_stream(
            "POST", "/v1/chat/completions", json_body=body, headers=auth_headers
        )
        async for payload in upstream:
            # Stop forwarding as soon as the browser has gone away.
            if await request.is_disconnected():
                break
            yield f"data: {payload}\n\n"

    return StreamingResponse(sse_events(), media_type="text/event-stream")