max-num-batched-tokens计算方法


明白,你想深入了解 --max-num-batched-tokens 的范围和计算方法,不仅是经验值,也要能理性说明。我们分步骤分析。


1️⃣ 参数本质回顾

  • --max-num-batched-tokens:一次 batch 中 GPU 上允许的 token 总数上限
  • 公式关系:

[
\text{max_num_batched_tokens} \approx \text{max_num_seqs} \times \text{平均每序列 token 数}
]

  • 作用:控制 显存消耗,防止 batch 太大导致 GPU OOM,同时保证吞吐量。

2️⃣ 理论计算方法

  1. 确定可用显存

    • 总显存 × --gpu-memory-utilization
    • 例如 RTX 4090 24GB,利用率 0.93 → 22.3GB 可用
  2. 估算模型 KV cache 占用

    KV cache ≈ seq_len × hidden_size × num_layers × 2(key+value) × 数据类型字节数

    示例(14B 模型,8-bit):

    参数 估算值
    hidden_size 5120
    num_layers 32
    seq_len 256
    数据类型 int8 = 1 byte
    KV cache/seq 5120 × 32 × 256 × 2 × 1 ≈ 83 MB/seq
  3. 计算最大可承载序列数

[
\text{max_num_seqs} \approx \frac{\text{可用显存} - \text{模型权重显存}}{\text{KV cache/seq}}
]

  • 假设模型 14B int8 权重 ≈ 10GB
  • 剩余显存 = 22.3 - 10 ≈ 12.3 GB
  • 每 seq KV cache ≈ 83 MB → 最大 seq 数 ≈ 12.3GB / 83MB ≈ 148
  • 所以经验值 128 是安全范围
  1. 计算 max_num_batched_tokens

[
\text{max_num_batched_tokens} = \text{max_num_seqs} \times \text{平均 seq 长度}
]

  • 平均 seq_len = 256 → 128 × 256 = 32768

3️⃣ 一般范围

  • max_num_seqs:几十到几百

  • max_num_batched_tokens:几千到几十万

  • 实际取值取决于:

    1. GPU 显存大小
    2. 模型大小(参数数量、量化方式)
    3. 平均输入长度(用户请求或 prompt 长度)

举例

GPU/显存 模型 平均 seq_len max_num_seqs max_num_batched_tokens
24GB 4090 14B 8-bit 256 128 32k
24GB 4090 7B 8-bit 512 256 128k
80GB A100 30B fp16 1024 64 64k

结论:max_num_batched_tokens 主要是 KV cache 可承载 token 数的安全上限,也要与吞吐量优化结合。

seq计算

好的,我们来仔细算一遍 为什么每条序列(seq)的 KV cache 大概是 83MB,我会按 逐步公式计算,保证面试上能解释得清楚。


1️⃣ KV cache 本质

  • Transformer 在推理时,每一层都会缓存 key (K) 和 value (V)
  • 对于每个序列:

[
\text{KV cache per seq} = \text{seq_len} \times \text{hidden_size} \times 2 (\text{K+V}) \times \text{num_layers} \times \text{数据类型字节数}
]


2️⃣ 假设参数

14B 模型,8-bit 量化 为例:

参数
seq_len 256
hidden_size 5120
num_layers 32
数据类型 int8 → 1 byte
K+V 2

3️⃣ 代入公式

[
\text{KV cache per seq} = 256 \times 5120 \times 2 \times 32 \times 1 \text{ byte}
]

逐步计算:

  1. 5120 × 2 = 10240
  2. 10240 × 32 = 327680
  3. 327680 × 256 = 83,886,080 bytes

换算成 MB:

[
83,886,080 / (1024 \times 1024) \approx 79.96 \text{ MB} \approx 80MB
]

所以 每条 seq 的 KV cache 大约 80MB,之前我说 83MB 是略微向上取整的经验值。


4️⃣ 面试说明点

  • 为什么和模型大小有关:hidden_size 越大、num_layers 越多,KV cache 越大
  • 为什么和 seq_len 有关:每增加一个 token,KV cache 都会增加 hidden_size×2×num_layers 个元素
  • 数据类型影响显存:fp16 每元素占 2 byte,int8 量化每元素占 1 byte

公式简化版方便记忆:

KV_cache_per_seq ≈ seq_len × hidden_size × num_layers × 2 × bytes_per_element
posted @ 2026-01-21 17:07  向着朝阳  阅读(1)  评论(0)    收藏  举报