大词表导致训练时的巨大的显存占用

问题解析

如果 3w 的词表导致了显存爆炸,通常问题不在于参数量(Parameters)本身,而在于训练过程中计算 Loss 时产生的中间激活值(Activations),特别是在输出层(Logits)的计算上。

模型为50m,length=2048,batch 为64.
image
A100训练Speed: 1.12s/it

切片计算 Loss (Chunked Cross Entropy)

这是目前最直接、最有效的解决“显存爆炸”的方法。在标准的 Cross Entropy Loss 计算中,模型会生成一个巨大的 Logits 张量,形状为 [Batch_Size, Seq_Len, Vocab_Size]

假设 Batch=4, Seq=4096, Vocab=30000,使用 fp16。
这个 Tensor 大小约为:
\(4 \times 4096 \times 30000 \times 2 \text{ bytes} \approx 983 \text{ MB}\)
这只是 logits,还需要存储梯度,显存占用会瞬间翻倍。
如果是 Llama-3 的 128k 词表,这个数字会扩大 4 倍。

解决方案: 不要一次性把所有 logits 算出来。
使用 Fused Cross Entropy 或 Chunked Cross Entropy。
原理: 将 Logits 的计算和 Loss 的计算融合,或者分块(Chunk)进行。每次只计算一小部分 Token 的 Loss,反向传播后释放显存,再算下一块。

现成工具:
Liger Kernel : LinkedIn 开源的高效 Kernel,专门优化了 CrossEntropyLoss,可以大幅降低显存。

import torch
# 引入 Liger 的 Loss
from liger_kernel.transformers import LigerCrossEntropyLoss

# 替换原本的 torch.nn.CrossEntropyLoss
loss_fct = LigerCrossEntropyLoss()

[2] NVIDIA A100X | 62°C, 99 % | 30316 / 81920 MB
1.13s/it

降低了16g的显存。

Selective Cross Entropy

在标准的训练流程中,即使你对全量序列(Batch, Seq)做了 lm_head 投影,但在计算 CrossEntropyLoss 时,那些 label == -100 的位置产生的 Loss 权重是 0。
因此可以只推理被MASK的位置。

masked_lm_loss = None
if labels is not None:
	# 2. 确定哪些位置是需要预测的 (避开 labels == -100 的位置)
	# masked_bool_mask 形状: [Batch, Seq]
	masked_bool_mask = (labels != -100)

	# 3. 核心步骤:只提取被 Mask 位置的向量
	# relevant_hidden 形状: [N_masked, Hidden]
	# N_masked 远小于 Batch * Seq,显存压力瞬间释放
	relevant_hidden = sequence_output[masked_bool_mask]
	relevant_labels = labels[masked_bool_mask].to(relevant_hidden.device)

	if relevant_hidden.numel() > 0:
		# 4. 只对这部分“稀疏”的向量做全词表投影
		# prediction_scores 现在的形状是 [N_masked, 20000]
		# 而不是之前的 [Batch, Seq, 20000]
		prediction_scores = self.lm_head(relevant_hidden)

		# 5. 计算 Loss
		loss_fct = LigerCrossEntropyLoss() #nn.CrossEntropyLoss()
		masked_lm_loss = loss_fct(prediction_scores, relevant_labels)
	else:
		# 万一这批数据没被 mask (极端情况)
		masked_lm_loss = sequence_output.sum() * 0
	else:
		# 推理模式下,如果需要全量预测,则保留原样
		prediction_scores = self.lm_head(sequence_output)

NVIDIA A100X | 69°C, 100 % | 19274 / 81920 MB
1.05s/it

显存又降低了。

posted @ 2026-02-02 22:15  ylifs  阅读(0)  评论(0)    收藏  举报