Qwen2.5-3B 奖励模型显存瘦身术:轻量运行,性能不减
要优化 Qwen2.5-3B 奖励模型的显存占用,需从模型加载、数据处理、训练策略三个维度综合优化。以下是具体可落地的方案,按显存节省效果排序:
一、缩短序列长度(显存优化最显著)
Qwen2.5-3B 在 8192 tokens 时显存需求极高,优先缩短序列长度:
python运行
# 在tokenizer编码时减小max_length(根据数据实际长度调整,如4096或2048)
chosen_encoding = tokenizer(
text=batch["chosen_inputs"],
padding=True,
truncation=True,
max_length=4096, # 从8192→4096(显存占用可减少约50%)
return_tensors="pt"
)
rejected_encoding = tokenizer(
text=batch["rejected_inputs"],
padding=True,
truncation=True,
max_length=4096, # 同步修改
return_tensors="pt"
)
原理:Transformer 的显存占用与序列长度的平方成正比(
O(n²)
),缩短长度能显著降低自注意力矩阵的计算量和存储需求。二、启用梯度检查点(以计算换显存)
牺牲 20%-30% 的计算速度,换取 50% 左右的显存节省:
python运行
# 加载基础模型时启用梯度检查点
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map={"": accelerator.device},
gradient_checkpointing=True # 关键:不保存中间激活值,反向传播时重新计算
)
# 手动启用梯度检查点(部分模型需要显式调用)
base_model.gradient_checkpointing_enable()
原理:默认情况下,模型会保存前向传播的所有中间激活值用于反向传播,梯度检查点仅保存部分关键值,反向传播时重新计算其他值,大幅减少显存占用。
三、冻结基础模型,仅训练奖励头
若只需微调奖励头(而非整个基础模型),可冻结基础模型参数:
python运行
# 初始化奖励模型后,冻结基础模型参数
for param in reward_model.base_model.parameters():
param.requires_grad = False # 基础模型参数不更新
# 仅奖励头参数可训练(已默认开启)
for param in reward_model.reward_head.parameters():
param.requires_grad = True
原理:Qwen2.5-3B 的 3B 参数中,奖励头仅占约 1%(隐藏层大小 ×1,如 768×1),冻结基础模型后,需要更新的参数从 3B→768,显存需求大幅降低。
四、优化批次大小与梯度累积
在保持总批次等效的前提下,进一步减小单步批次:
python运行
# 1. 减小DataLoader的batch_size(最小为1)
train_loader = torch.utils.data.DataLoader(
processed_dataset,
batch_size=1, # 已为最小值,无法再减小
shuffle=True,
pin_memory=True
)
# 2. 增大梯度累积步数(保持总批次=batch_size×gradient_accumulation_steps)
accelerator = Accelerator(
mixed_precision="fp16",
gradient_accumulation_steps=8 # 从4→8,总批次不变,但单步显存需求降低
)
原理:单步批次越小,单次前向 / 反向传播的显存占用越低,梯度累积通过多步累积梯度来保持总批次等效,平衡显存与训练稳定性。
五、启用 DeepSpeed ZeRO-3 分布式优化(多 GPU 场景)
若使用多 GPU 训练(如用户的 4 卡环境),通过 DeepSpeed 的 ZeRO-3 分片参数、梯度和优化器状态:
- 创建 ds_zero3_config.yaml:
train_batch_size: 8 # 总批次=batch_size×gradient_accumulation_steps×num_gpus
gradient_accumulation_steps: 8
optimizer:
type: Adam
params:
lr: 5e-7
betas: [0.9, 0.95]
eps: 1e-8
zero_optimization:
stage: 3 # 关键:分片模型参数、梯度、优化器状态
offload_optimizer:
device: cpu # 优化器状态移至CPU(可选,进一步节省GPU显存)
contiguous_gradients: true
overlap_comm: true # 通信与计算重叠
- 启动命令:
CUDA_VISIBLE_DEVICES=0,1,2,3 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
accelerate launch --num_processes=4 --config_file ds_zero3_config.yaml qwen_3B_train_0721_mix_checkpoint_8192.py
原理:ZeRO-3 将模型参数、梯度和优化器状态分片到多个 GPU/CPU,每个 GPU 仅保存部分数据,4 卡环境下理论显存占用可降至 1/4。
六、清理无效显存与减少碎片
在训练循环中主动释放未使用的张量和缓存:
# 在每个batch处理结束后添加
# 1. 手动删除不再需要的中间变量
del chosen_rewards, rejected_rewards, loss
# 2. 清理CUDA缓存(释放未使用的显存块)
torch.cuda.empty_cache()
原理:训练过程中,未显式删除的张量会占用显存,
empty_cache()
可释放已标记为 “未使用” 但未实际回收的显存,减少碎片。七、使用更高效的精度(如 bf16)
若 GPU 支持 bfloat16(如 A100、H100),可替换 fp16 进一步节省显存:
python运行
accelerator = Accelerator(
mixed_precision="bf16", # 从fp16→bf16(显存占用相近,但数值范围更大,训练更稳定)
gradient_accumulation_steps=8
)
# 模型加载时同步使用bf16
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16, # 与混合精度匹配
trust_remote_code=True,
device_map={"": accelerator.device},
gradient_checkpointing=True
)
原理:bf16 与 fp16 显存占用相同(均为 2 字节),但无需像 fp16 那样处理数值范围限制,可减少因溢出导致的显存浪费。
优化效果对比
优化方法 | 显存节省比例 | 训练速度影响 | 适用场景 |
---|---|---|---|
缩短序列长度(8192→4096) | ~50% | 提升 | 数据允许缩短的场景 |
梯度检查点 | ~40% | 降低 20-30% | 显存紧张但计算资源充足 |
冻结基础模型 | ~90% | 提升 | 仅需微调奖励头的场景 |
ZeRO-3(4 卡) | ~75% | 略降 | 多 GPU 分布式训练 |
建议优先实施缩短序列长度 + 梯度检查点 + ZeRO-3的组合方案,可在保证训练效果的前提下将显存占用降低 70% 以上。
本文来自博客园,作者:limingqi,转载请注明原文链接:https://www.cnblogs.com/limingqi/p/19046068