from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Depends
from fastapi.responses import StreamingResponse
import aiohttp
import asyncio
import random
import json
import logging
from typing import AsyncGenerator, Optional
logger = logging.getLogger(__name__)
class LLMHttpClient:
    """Async HTTP client for an LLM backend: pooled connections, bounded
    concurrency, capped exponential-backoff retries, and SSE stream parsing.

    Call ``await start()`` before the first request and ``await close()``
    on shutdown. ``start()``/``close()`` are idempotent and the client can
    be restarted after a full close.
    """

    # Upstream statuses worth retrying: rate limiting plus transient 5xx.
    _RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        base_url: str = "",
        max_connections: int = 500,
        max_connections_per_host: int = 100,
        concurrency: int = 200,
        timeout_total: Optional[float] = None,
        timeout_connect: float = 10,
        timeout_read: float = 30,
        max_retries: int = 4,
    ):
        self.base_url = base_url
        self.max_retries = max_retries
        # Caps in-flight requests application-side, independent of pool size.
        self._sem = asyncio.Semaphore(concurrency)
        self.session: aiohttp.ClientSession | None = None
        self._timeout = aiohttp.ClientTimeout(
            total=timeout_total,
            connect=timeout_connect,
            sock_connect=timeout_connect,
            sock_read=timeout_read,
        )
        # BUGFIX: keep the connector kwargs so start() can rebuild the
        # connector after close(). ClientSession owns (and closes) its
        # connector by default, so the original instance cannot be reused
        # across a close()/start() cycle.
        self._connector_kwargs = dict(
            limit=max_connections,
            limit_per_host=max_connections_per_host,
            ttl_dns_cache=300,
            keepalive_timeout=30,
            enable_cleanup_closed=True,
        )
        self._connector = aiohttp.TCPConnector(**self._connector_kwargs)
        self._trace_configs = [self._create_trace()]

    async def start(self) -> None:
        """Create the shared session. Idempotent; safe to call after close()."""
        if self.session is None or self.session.closed:
            if self._connector.closed:
                # Previous session closed its connector; build a fresh one.
                self._connector = aiohttp.TCPConnector(**self._connector_kwargs)
            self.session = aiohttp.ClientSession(
                timeout=self._timeout,
                connector=self._connector,
                trace_configs=self._trace_configs,
            )

    async def close(self) -> None:
        """Release the session and its connector. Idempotent."""
        if self.session and not self.session.closed:
            # The session owns the connector and closes it as well.
            await self.session.close()
        elif not self._connector.closed:
            # BUGFIX: if the session was never started, close the connector
            # directly so its sockets/file descriptors do not leak.
            await self._connector.close()

    def _create_trace(self) -> aiohttp.TraceConfig:
        """Build a TraceConfig that logs each request's wall time in ms."""
        trace = aiohttp.TraceConfig()

        @trace.on_request_start.append
        async def on_start(session, ctx, params):
            ctx.started_at = asyncio.get_running_loop().time()

        @trace.on_request_end.append
        async def on_end(session, ctx, params):
            cost_ms = (asyncio.get_running_loop().time() - ctx.started_at) * 1000
            logger.info("llm_http %s %s %.1fms", params.method, params.url, cost_ms)

        return trace

    async def _backoff(self, attempt: int) -> None:
        """Sleep with capped exponential backoff plus +/-20% jitter."""
        base = 0.5
        cap = 8.0
        sleep_s = min(cap, base * (2 ** (attempt - 1))) * random.uniform(0.8, 1.2)
        await asyncio.sleep(sleep_s)

    async def _request_with_retry(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Issue a request, retrying retryable statuses and transport errors.

        Returns an OPEN response; the caller must close it.

        Raises:
            RuntimeError: a retryable status persisted through the last attempt.
            aiohttp.ClientError / asyncio.TimeoutError: transport failure on
                the last attempt (re-raised).
        """
        assert self.session is not None, "client not started"
        full_url = self.base_url + url
        last_exc: Exception | None = None
        for attempt in range(1, self.max_retries + 1):
            try:
                resp = await self.session.request(method, full_url, **kwargs)
                if resp.status in self._RETRYABLE_STATUSES:
                    if attempt == self.max_retries:
                        try:
                            body = await resp.text()
                        finally:
                            resp.close()
                        raise RuntimeError(f"upstream status={resp.status}, body={body[:500]}")
                    # Release the connection back to the pool before sleeping.
                    resp.close()
                    await self._backoff(attempt)
                    continue
                return resp
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                last_exc = e
                if attempt == self.max_retries:
                    raise
                await self._backoff(attempt)
        # Defensive: the loop always returns or raises above.
        raise last_exc or RuntimeError("request failed")

    async def call_json(self, method: str, url: str, json_body=None, headers=None) -> dict:
        """Send an optional JSON body and return the decoded JSON response.

        NOTE(review): non-retryable error statuses (e.g. 400/401) are not
        raised here — the upstream error body is decoded and returned as-is.
        """
        async with self._sem:
            resp = await self._request_with_retry(method, url, json=json_body, headers=headers)
            try:
                return await resp.json()
            finally:
                resp.close()

    async def call_stream(
        self,
        method: str,
        url: str,
        json_body=None,
        headers=None,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` payloads from a streaming response.

        Stops at the ``[DONE]`` sentinel. The concurrency slot is held for
        the lifetime of the stream; the response is closed even if the
        consumer abandons the generator early.
        """
        async with self._sem:
            resp = await self._request_with_retry(method, url, json=json_body, headers=headers)
            try:
                buffer = b""
                async for chunk in resp.content.iter_chunked(1024):
                    buffer += chunk
                    while b"\n" in buffer:
                        line, buffer = buffer.split(b"\n", 1)
                        text = line.decode("utf-8", "ignore").strip()
                        if not text:
                            continue
                        if text.startswith("data:"):
                            data = text[5:].strip()
                            if data == "[DONE]":
                                return
                            yield data
                # BUGFIX: flush a final event that arrived without a trailing
                # newline — the original loop silently dropped it.
                text = buffer.decode("utf-8", "ignore").strip()
                if text.startswith("data:"):
                    data = text[5:].strip()
                    if data != "[DONE]":
                        yield data
            finally:
                resp.close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: build the shared LLM client on startup, tear it down on shutdown."""
    llm = LLMHttpClient(base_url="https://api.openai.com")
    await llm.start()
    # Expose the client via app state so request handlers can depend on it.
    app.state.llm_client = llm
    try:
        yield
    finally:
        # Always release pooled connections, even if shutdown is abnormal.
        await llm.close()
# Application instance; the lifespan handler owns the shared LLM client.
app = FastAPI(lifespan=lifespan)
def get_llm_client(request: Request) -> LLMHttpClient:
    """FastAPI dependency: return the shared client stored on app state by the lifespan."""
    client: LLMHttpClient = request.app.state.llm_client
    return client
@app.post("/chat")
async def chat(request: Request, client: LLMHttpClient = Depends(get_llm_client)):
    """Relay a hard-coded chat completion from the upstream LLM as an SSE stream."""
    payload = {
        "model": "gpt-4.1-mini",
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    }
    # NOTE(review): "Bearer xxx" looks like a placeholder token — confirm this
    # is wired to real credentials before deploying.
    headers = {"Authorization": "Bearer xxx"}

    async def sse_events():
        upstream = client.call_stream(
            "POST", "/v1/chat/completions", json_body=payload, headers=headers
        )
        async for piece in upstream:
            # Stop forwarding as soon as the browser/client goes away.
            if await request.is_disconnected():
                break
            yield f"data: {piece}\n\n"

    return StreamingResponse(sse_events(), media_type="text/event-stream")