大词表导致训练时的巨大的显存占用
问题解析
如果 3w 的词表导致了显存爆炸,通常问题不在于参数量(Parameters)本身,而在于训练过程中计算 Loss 时产生的中间激活值(Activations),特别是在输出层(Logits)的计算上。
模型为50m,length=2048,batch 为64.

A100训练Speed: 1.12s/it
切片计算 Loss (Chunked Cross Entropy)
这是目前最直接、最有效的解决“显存爆炸”的方法。在标准的 Cross Entropy Loss 计算中,模型会生成一个巨大的 Logits 张量,形状为 [Batch_Size, Seq_Len, Vocab_Size]。
假设 Batch=4, Seq=4096, Vocab=30000,使用 fp16。
这个 Tensor 大小约为:
\(4 \times 4096 \times 30000 \times 2 \text{ bytes} \approx 983 \text{ MB}\)。
这只是 logits,还需要存储梯度,显存占用会瞬间翻倍。
如果是 Llama-3 的 128k 词表,这个数字会扩大 4 倍。
解决方案: 不要一次性把所有 logits 算出来。
使用 Fused Cross Entropy 或 Chunked Cross Entropy。
原理: 将 Logits 的计算和 Loss 的计算融合,或者分块(Chunk)进行。每次只计算一小部分 Token 的 Loss,反向传播后释放显存,再算下一块。
现成工具:
Liger Kernel : LinkedIn 开源的高效 Kernel,专门优化了 CrossEntropyLoss,可以大幅降低显存。
import torch
# 引入 Liger 的 Loss
from liger_kernel.transformers import LigerCrossEntropyLoss
# 替换原本的 torch.nn.CrossEntropyLoss
loss_fct = LigerCrossEntropyLoss()
[2] NVIDIA A100X | 62°C, 99 % | 30316 / 81920 MB
1.13s/it
降低了16g的显存。
Selective Cross Entropy
在标准的训练流程中,即使你对全量序列(Batch, Seq)做了 lm_head 投影,但在计算 CrossEntropyLoss 时,那些 label == -100 的位置产生的 Loss 权重是 0。
因此可以只推理被MASK的位置。
masked_lm_loss = None
if labels is not None:
# 2. 确定哪些位置是需要预测的 (避开 labels == -100 的位置)
# masked_bool_mask 形状: [Batch, Seq]
masked_bool_mask = (labels != -100)
# 3. 核心步骤:只提取被 Mask 位置的向量
# relevant_hidden 形状: [N_masked, Hidden]
# N_masked 远小于 Batch * Seq,显存压力瞬间释放
relevant_hidden = sequence_output[masked_bool_mask]
relevant_labels = labels[masked_bool_mask].to(relevant_hidden.device)
if relevant_hidden.numel() > 0:
# 4. 只对这部分“稀疏”的向量做全词表投影
# prediction_scores 现在的形状是 [N_masked, 20000]
# 而不是之前的 [Batch, Seq, 20000]
prediction_scores = self.lm_head(relevant_hidden)
# 5. 计算 Loss
loss_fct = LigerCrossEntropyLoss() #nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores, relevant_labels)
else:
# 万一这批数据没被 mask (极端情况)
masked_lm_loss = sequence_output.sum() * 0
else:
# 推理模式下,如果需要全量预测,则保留原样
prediction_scores = self.lm_head(sequence_output)
NVIDIA A100X | 69°C, 100 % | 19274 / 81920 MB
1.05s/it
显存又降低了。

浙公网安备 33010602011771号