[mysql] Async MySQL client

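Suppose you need to access databases, including MySQL, from Python in an enterprise production environment, with the following capabilities:
- Timeout policy: separate connect / read (streaming) / total timeouts (for streaming, total generally should not be too small)
- Retry policy: retry only "retryable" errors (network errors, 429, some 5xx), with exponential backoff plus jitter
- Concurrency and connection pooling: per-host concurrency limit, global concurrency limit, keep-alive
- Rate limiting and queuing: token bucket / leaky bucket; avoid bursts that trigger a flood of 429s
- Backpressure for streaming: don't let memory blow up when the consumer is slow (process chunk by chunk, with timeouts and reconnect recovery)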

Which tools are best for this?

Then don't split this into a pile of scattered helpers; build a single async DB client layer.

Recommended stack: SQLAlchemy 2.x Async + asyncmy + Tenacity + asyncio.Queue/Semaphore

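The reasons:

  • SQLAlchemy 2.x officially provides async MySQL dialects and supports both asyncmy and aiomysql. (SQLAlchemy docs)
  • asyncmy is an asyncio driver for MySQL/MariaDB; its project description explicitly bills it as a "fast asyncio MySQL/MariaDB driver". (GitHub)
  • Tenacity officially supports exponential backoff with jitter; wait_random_exponential and wait_exponential_jitter both suit scenarios where many distributed clients contend for the same resource. (Tenacity)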
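To make "exponential backoff with jitter" concrete, here is a minimal stdlib-only sketch of the idea. It is not Tenacity's implementation; the names `backoff_upper_bounds` and `full_jitter_delay` and the default values are made up for illustration.

```python
import random
from typing import List, Optional


def backoff_upper_bounds(attempts: int, multiplier: float = 0.5,
                         cap: float = 8.0) -> List[float]:
    """Upper bound of the delay before each retry: doubles, then hits the cap."""
    return [min(cap, multiplier * (2 ** (n - 1))) for n in range(1, attempts + 1)]


def full_jitter_delay(attempt: int, multiplier: float = 0.5, cap: float = 8.0,
                      rng: Optional[random.Random] = None) -> float:
    """Pick a random delay in [0, upper bound] for a 1-based attempt number.

    Randomizing the whole interval spreads out clients that failed at the
    same moment, so they do not all retry in lockstep.
    """
    if attempt < 1:
        raise ValueError("attempt is 1-based")
    upper = min(cap, multiplier * (2 ** (attempt - 1)))
    uniform = rng.uniform if rng is not None else random.uniform
    return uniform(0.0, upper)
```

With these illustrative defaults, the upper bounds for six attempts are 0.5, 1, 2, 4, 8 and 8 seconds; each actual delay is drawn uniformly below its bound.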

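In FastAPI, I'd suggest this layering:

Structure

  • DBClient

    • Manages the AsyncEngine
    • Manages the async_sessionmaker
    • Encapsulates retries, total timeout, and concurrency control
    • Exposes session() / execute() / stream()
  • FastAPI

    • Initializes the client on startup
    • Disposes the engine on shutdown
    • Provides an AsyncSession to handlers via dependency injection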

Key conclusions

1) Driver choice

Prefer asyncmy, and keep aiomysql as a compatible fallback.

SQLAlchemy's official async MySQL dialects support both. (SQLAlchemy docs)

2) Connection pool

Leave connection pooling to the SQLAlchemy engine:

  • pool_size
  • max_overflow
  • pool_timeout
  • pool_recycle
  • pool_pre_ping=True

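Of these, pool_pre_ping (test a connection when it is checked out) and pool_recycle (replace connections older than a given number of seconds) are the usual settings for dealing with stale connections and aging long-lived connections. (SQLAlchemy docs)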

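3) Retries

Retry only retryable exceptions:

  • Connection establishment failures
  • Transient network blips
  • Connections dropped by the server
  • Deadlocks / lock wait timeouts, for read-only or idempotent requests that you have explicitly judged safe to retry

Do not blindly auto-retry writes that have "already entered a transaction and may have partially committed".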
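The classification above can be sketched by MySQL error code. This is a hedged illustration, not part of any library: `mysql_errno` and `should_retry` are hypothetical helpers, and the code assumes a PyMySQL-style driver exception whose `args[0]` is the numeric error code, wrapped by SQLAlchemy and exposed as `.orig`. Verify both assumptions against your driver.

```python
from typing import Optional

# Could not connect at all: nothing was sent, so a retry cannot double-apply.
CONNECT_FAILED = {2002, 2003}
# Server gone away / lost connection during query: outcome of a write is unknown.
CONNECTION_LOST = {2006, 2013}
# Lock wait timeout / deadlock.
LOCK_CONFLICT = {1205, 1213}


def mysql_errno(exc: BaseException) -> Optional[int]:
    """Extract the numeric MySQL error code, or None if there is none.

    Assumption: the driver exception stores the code in args[0], and a
    SQLAlchemy wrapper (if any) exposes the driver exception as `.orig`.
    """
    orig = getattr(exc, "orig", None) or exc
    args = getattr(orig, "args", ())
    if args and isinstance(args[0], int):
        return args[0]
    return None


def should_retry(exc: BaseException, *, idempotent: bool) -> bool:
    """Decide whether to retry, given whether the request is idempotent."""
    code = mysql_errno(exc)
    if code in CONNECT_FAILED:
        return True
    if code in CONNECTION_LOST or code in LOCK_CONFLICT:
        # Only replay work that is safe to run twice.
        return idempotent
    return False
```

The `idempotent` flag is supplied by the caller, which keeps the "is this safe to replay" decision out of the error classifier.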

4) Streaming

For large result sets, don't use all() / fetchall(); fetch in batches and process batch by batch.
SQLAlchemy supports streaming results. (SQLAlchemy docs)

5) FastAPI integration

Don't create a new engine per request in FastAPI.
The engine is an application-level singleton; the session is request-scoped.


A practical client skeleton

1. db_client.py

import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Optional

from sqlalchemy import text
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from tenacity import (
    AsyncRetrying,
    retry_if_exception,
    stop_after_attempt,
    wait_random_exponential,
)


@dataclass(slots=True)
class DBClientConfig:
    dsn: str  # e.g. mysql+asyncmy://user:pass@host:3306/dbname?charset=utf8mb4
    pool_size: int = 20
    max_overflow: int = 20
    pool_timeout: int = 10
    pool_recycle: int = 1800
    pool_pre_ping: bool = True

    connect_timeout: float = 3.0
    read_timeout: float = 30.0
    write_timeout: float = 30.0
    total_timeout: float = 60.0

    max_retries: int = 3

    global_concurrency: int = 100
    host_concurrency: int = 50

    stream_chunk_size: int = 500


class DBRetryableError(Exception):
    pass


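def is_retryable_db_error(exc: BaseException) -> bool:
    """
    Deliberately conservative classification:
    - asyncio.TimeoutError raised by the per-attempt wait_for
    - DBAPIError with connection_invalidated=True
    - OperationalError (network / connection interruptions)

    Note: a pool checkout timeout raises sqlalchemy.exc.TimeoutError, which is
    not matched here; add it if you want that case retried as well.
    """
    if isinstance(exc, asyncio.TimeoutError):
        return True

    if isinstance(exc, DBAPIError):
        if getattr(exc, "connection_invalidated", False):
            return True

    if isinstance(exc, OperationalError):
        # Refine this by driver / error code as needed
        return True

    if isinstance(exc, DBRetryableError):
        return True

    return False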


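class TokenBucket:
    """
    Simplified async token bucket:
    rate = tokens added per second
    capacity = bucket capacity
    """
    def __init__(self, rate: float, capacity: int):
        self._rate = rate
        self._capacity = capacity
        self._tokens = float(capacity)
        # Set on the first acquire() so the bucket can be built without a running loop.
        self._updated: Optional[float] = None
        self._lock = asyncio.Lock()

    async def acquire(self, tokens: float = 1.0) -> None:
        if tokens > self._capacity:
            # Otherwise the loop below could never be satisfied.
            raise ValueError("requested tokens exceed bucket capacity")
        loop = asyncio.get_running_loop()
        # The lock is held while sleeping, so waiters are served one at a time.
        async with self._lock:
            while True:
                now = loop.time()
                if self._updated is None:
                    self._updated = now
                elapsed = now - self._updated
                self._updated = now
                self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)

                if self._tokens >= tokens:
                    self._tokens -= tokens
                    return

                need = (tokens - self._tokens) / self._rate
                await asyncio.sleep(max(need, 0.001))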


class DBClient:
    def __init__(self, config: DBClientConfig):
        self.config = config
        self.engine: AsyncEngine = create_async_engine(
            config.dsn,
            pool_size=config.pool_size,
            max_overflow=config.max_overflow,
            pool_timeout=config.pool_timeout,
            pool_recycle=config.pool_recycle,
            pool_pre_ping=config.pool_pre_ping,
            pool_use_lifo=True,
            connect_args={
                # Timeout parameters that asyncmy commonly accepts; verify against your driver version
                "connect_timeout": config.connect_timeout,
                "read_timeout": config.read_timeout,
                "write_timeout": config.write_timeout,
            },
        )
        self.session_factory = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
        )

        self._global_sem = asyncio.Semaphore(config.global_concurrency)
        self._host_sem = asyncio.Semaphore(config.host_concurrency)
        self._rate_limiter = TokenBucket(rate=50, capacity=100)  # example values

    async def close(self) -> None:
        await self.engine.dispose()

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        async with self._global_sem, self._host_sem:
            await self._rate_limiter.acquire(1)
            async with self.session_factory() as session:
                try:
                    yield session
                except Exception:
                    await session.rollback()
                    raise

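    async def _run_with_retry(self, func: Callable[[], Any]) -> Any:
        # func must return a fresh awaitable on every call (one per attempt).
        # stop_after_attempt counts attempts, so max_retries=3 means at most
        # 3 tries in total. total_timeout bounds each attempt, not the whole
        # loop: the worst case is roughly max_retries * total_timeout plus backoff.
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(self.config.max_retries),
            wait=wait_random_exponential(multiplier=0.5, max=8),
            retry=retry_if_exception(is_retryable_db_error),
            reraise=True,
        ):
            with attempt:
                return await asyncio.wait_for(func(), timeout=self.config.total_timeout)

    async def execute(
        self,
        sql: str,
        params: Optional[dict[str, Any]] = None,
        *,
        commit: bool = False,
    ):
        """
        Run one statement with retries.

        AsyncSession.execute() returns a buffered Result, which is why callers
        read rows from it after the session has been released.
        Statements run with commit=True are retried too, so only pass writes
        that are safe to replay.
        """
        async def _op():
            async with self.session() as session:
                result = await session.execute(text(sql), params or {})
                if commit:
                    await session.commit()
                return result

        return await self._run_with_retry(_op)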

    async def fetch_all(
        self,
        sql: str,
        params: Optional[dict[str, Any]] = None,
    ) -> list[dict[str, Any]]:
        result = await self.execute(sql, params)
        return [dict(row) for row in result.mappings().all()]

    async def fetch_one(
        self,
        sql: str,
        params: Optional[dict[str, Any]] = None,
    ) -> Optional[dict[str, Any]]:
        result = await self.execute(sql, params)
        row = result.mappings().first()
        return dict(row) if row else None

    async def stream(
        self,
        sql: str,
        params: Optional[dict[str, Any]] = None,
        *,
        chunk_size: Optional[int] = None,
    ) -> AsyncIterator[list[dict[str, Any]]]:
        """
        Yield rows in batches so the full result set is never held in memory
        """
        chunk_size = chunk_size or self.config.stream_chunk_size

        async with self._global_sem, self._host_sem:
            await self._rate_limiter.acquire(1)

            async with self.session_factory() as session:
                try:
                    result = await asyncio.wait_for(
                        session.stream(text(sql), params or {}),
                        timeout=self.config.total_timeout,
                    )

                    batch: list[dict[str, Any]] = []
                    async for row in result.mappings():
                        batch.append(dict(row))
                        if len(batch) >= chunk_size:
                            yield batch
                            batch = []

                    if batch:
                        yield batch
                except Exception:
                    await session.rollback()
                    raise

2. FastAPI integration

from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, Request
from sqlalchemy.ext.asyncio import AsyncSession

from .db_client import DBClient, DBClientConfig


@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.db = DBClient(
        DBClientConfig(
            dsn="mysql+asyncmy://user:pass@127.0.0.1:3306/app?charset=utf8mb4",
            pool_size=20,
            max_overflow=20,
            pool_timeout=10,
            pool_recycle=1800,
            connect_timeout=3,
            read_timeout=30,
            write_timeout=30,
            total_timeout=60,
            max_retries=3,
            global_concurrency=100,
            host_concurrency=50,
            stream_chunk_size=500,
        )
    )
    try:
        yield
    finally:
        await app.state.db.close()


app = FastAPI(lifespan=lifespan)


def get_db_client(request: Request) -> DBClient:
    return request.app.state.db


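# Async generator dependency: FastAPI resumes it after the response to release
# the session. (An async generator does not return an AsyncSession, so the
# original "-> AsyncSession" annotation is dropped.)
async def get_db_session(
    db: DBClient = Depends(get_db_client),
):
    async with db.session() as session:
        yield session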

3. A regular query

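from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

# get_db_session is the dependency defined in the FastAPI integration code above;
# import it from wherever that module lives in your project.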

router = APIRouter()


@router.get("/users/{user_id}")
async def get_user(user_id: int, session: AsyncSession = Depends(get_db_session)):
    result = await session.execute(
        text("SELECT id, name, email FROM users WHERE id=:id"),
        {"id": user_id},
    )
    row = result.mappings().first()
    return dict(row) if row else {"message": "not found"}

4. Streaming export without blowing up memory

import json

from fastapi import Depends
from fastapi.responses import StreamingResponse

# app, DBClient and get_db_client come from db_client.py and the FastAPI
# integration code above.


@app.get("/export")
async def export_data(db: DBClient = Depends(get_db_client)):
    async def gen():
        async for batch in db.stream(
            "SELECT id, name, created_at FROM big_table ORDER BY id"
        ):
            for item in batch:
                # default=str: created_at is a datetime, which json cannot serialize natively
                yield (json.dumps(item, ensure_ascii=False, default=str) + "\n").encode("utf-8")

    return StreamingResponse(gen(), media_type="application/x-ndjson")

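Benefits of this pattern:

  • Only a small batch is held at a time
  • When the downstream client reads slowly, yield naturally applies backpressure
  • The full result set never piles up in memory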

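A few more things you still need in production

A. Retry boundaries

Limit automatic retries to:

  • SELECT
  • Idempotent read requests
  • Operations that are explicitly replayable

By default, don't auto-retry INSERT/UPDATE/DELETE, especially multi-statement writes inside a transaction, unless you have defined your own business idempotency key.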
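One way to enforce this boundary is a gate that runs before any retry logic. The sketch below is deliberately naive and conservative: `is_auto_retry_safe` is a hypothetical helper, it only looks at the first keyword after leading /* ... */ comments, and anything it does not recognize (CTEs, parenthesized selects, stored procedure calls) is treated as not retryable.

```python
import re
from typing import Optional

# Statement types treated as read-only for retry purposes (illustrative list).
_READ_ONLY_KEYWORDS = ("SELECT", "SHOW", "DESCRIBE", "EXPLAIN")
# Matches leading whitespace and any number of /* ... */ comments.
_LEADING_COMMENTS = re.compile(r"^\s*(?:/\*.*?\*/\s*)*", re.DOTALL)


def is_auto_retry_safe(sql: str, *, idempotency_key: Optional[str] = None) -> bool:
    """Return True if the statement may be retried automatically.

    A caller-supplied idempotency key marks a write as replayable; it is the
    caller's job to make the write actually honor that key.
    """
    if idempotency_key:
        return True
    stripped = _LEADING_COMMENTS.sub("", sql, count=1)
    first_word = stripped.split(None, 1)[0].upper() if stripped.strip() else ""
    return first_word in _READ_ONLY_KEYWORDS
```

Erring on the side of "not retryable" means a misclassified read only loses a retry, while a misclassified write could be applied twice.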

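B. Separate the total timeout from streaming timeouts

As you said earlier, the total timeout for streaming reads should not be too small.
So a more sensible split is:

  • query_total_timeout
  • stream_open_timeout
  • stream_idle_timeout

That is:

  • A timeout for opening the result stream
  • An idle timeout between batches
  • A longer allowance for the stream overall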
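The idle timeout between batches can be sketched as a wrapper around any async iterator. This is a minimal illustration using only asyncio; `with_idle_timeout` is a hypothetical helper name, and the source it wraps could be any async generator of batches.

```python
import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def with_idle_timeout(source: AsyncIterator[T],
                            idle_timeout: float) -> AsyncIterator[T]:
    """Yield items from `source`, failing if the gap between items is too long.

    The timer restarts for every item, so a long stream is fine as long as
    it keeps making progress. Raises asyncio.TimeoutError on a stall.
    """
    iterator = source.__aiter__()
    while True:
        try:
            item = await asyncio.wait_for(iterator.__anext__(), timeout=idle_timeout)
        except StopAsyncIteration:
            return
        yield item
```

Time the consumer spends between pulls is not counted, because the timer only runs while waiting for the next item. A stall in a slow consumer therefore does not trip the database-side idle timeout.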

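C. Backpressure

If you have a "producer queries the DB -> consumer writes files / pushes to Kafka" pipeline, put this in between:

queue = asyncio.Queue(maxsize=100)

Have the DB-query coroutine put batches into the bounded queue; when the consumer is slow, the producer blocks naturally instead of consuming memory without limit.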
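A minimal sketch of that producer/consumer wiring, using only asyncio. The names `pump` and `handle` are hypothetical; `batches` stands for any async iterator of batches, and `handle` for whatever writes a batch to a file or a message broker.

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable

_DONE = object()  # sentinel marking the end of the stream


async def pump(batches: AsyncIterator[list],
               handle: Callable[[list], Awaitable[None]],
               maxsize: int = 100) -> int:
    """Move batches to `handle` through a bounded queue; return the batch count."""
    queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)

    async def producer() -> None:
        try:
            async for batch in batches:
                await queue.put(batch)  # suspends while the queue is full
            await queue.put(_DONE)
        except Exception as exc:
            await queue.put(exc)  # forward producer failures to the consumer

    producer_task = asyncio.create_task(producer())
    handled = 0
    try:
        while True:
            item = await queue.get()
            if item is _DONE:
                return handled
            if isinstance(item, Exception):
                raise item
            await handle(item)
            handled += 1
    finally:
        # If the consumer failed, the producer may be stuck in put(); stop it.
        producer_task.cancel()
        await asyncio.gather(producer_task, return_exceptions=True)
```

At any moment at most `maxsize` batches sit in the queue, plus one being handled and one the producer is waiting to enqueue, so memory use is bounded regardless of how slow the consumer is.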

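D. Observability

Add to the client:

  • Query latency
  • Retry count
  • Pool wait time
  • Current concurrency
  • Streaming batch size
  • Timeout / disconnect counts

Otherwise, when slow queries or a saturated connection pool hit production, they are very hard to diagnose.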
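As a starting point, a few of these numbers can be collected with an in-process recorder before wiring up a real metrics backend. This is a sketch under that assumption; `DBMetrics` and its counter names are hypothetical and cover only latency, in-flight count, and timeout/error counts.

```python
import asyncio
import time
from collections import Counter
from contextlib import contextmanager
from typing import Iterator, List


class DBMetrics:
    """Tiny in-memory recorder for DB call latency and outcome counts."""

    def __init__(self) -> None:
        self.counters: Counter = Counter()
        self.durations: List[float] = []  # seconds, one entry per tracked call
        self.in_flight = 0                # calls currently inside track()

    @contextmanager
    def track(self, kind: str = "query") -> Iterator[None]:
        self.in_flight += 1
        started = time.perf_counter()
        try:
            yield
        except (TimeoutError, asyncio.TimeoutError):
            self.counters[f"{kind}.timeout"] += 1
            raise
        except Exception:
            self.counters[f"{kind}.error"] += 1
            raise
        finally:
            self.in_flight -= 1
            self.durations.append(time.perf_counter() - started)
            self.counters[f"{kind}.total"] += 1
```

Because track() is a plain context manager, it can wrap awaited calls directly (`with metrics.track(): await ...`). Retry counts can be added by incrementing a counter from the retry loop.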


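My final recommendation

In FastAPI, settle on this setup:

SQLAlchemy Async + asyncmy + Tenacity + an application-level DBClient singleton
SQLAlchemy officially supports async MySQL dialects, with asyncmy/aiomysql available; Tenacity provides exponential backoff with jitter; the engine is an application singleton and the session is request-scoped. This is the most robust way to use this combination. (SQLAlchemy docs)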

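If you'd like, in my next message I can give you a more complete production-grade version that fills in all of these:

  • Classification of retryable error codes
  • Read/write splitting interface
  • Bounded-queue backpressure
  • Stream reconnect recovery
  • FastAPI dependency injection and middleware logging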
posted @ 2026-03-17 14:21  X1OO