Qwen2.5-3B 奖励模型显存瘦身术:轻量运行,性能不减

要优化 Qwen2.5-3B 奖励模型的显存占用,需从模型加载、数据处理、训练策略三个维度综合优化。以下是具体可落地的方案,按显存节省效果排序:

一、缩短序列长度(显存优化最显著)

Qwen2.5-3B 在 8192 tokens 时显存需求极高,优先缩短序列长度:

python运行
# 在tokenizer编码时减小max_length(根据数据实际长度调整,如4096或2048)
chosen_encoding = tokenizer(
    text=batch["chosen_inputs"],
    padding=True,
    truncation=True,
    max_length=4096,  # 从8192→4096(显存占用可减少约50%)
    return_tensors="pt"
)
rejected_encoding = tokenizer(
    text=batch["rejected_inputs"],
    padding=True,
    truncation=True,
    max_length=4096,  # 同步修改
    return_tensors="pt"
)
 
原理:Transformer 的显存占用与序列长度的平方成正比(O(n²)),缩短长度能显著降低自注意力矩阵的计算量和存储需求。

二、启用梯度检查点(以计算换显存)

牺牲 20%-30% 的计算速度,换取 50% 左右的显存节省:

python运行
# 加载基础模型时启用梯度检查点
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map={"": accelerator.device},
    gradient_checkpointing=True  # 关键:不保存中间激活值,反向传播时重新计算
)
# 手动启用梯度检查点(部分模型需要显式调用)
base_model.gradient_checkpointing_enable()
原理:默认情况下,模型会保存前向传播的所有中间激活值用于反向传播,梯度检查点仅保存部分关键值,反向传播时重新计算其他值,大幅减少显存占用。

三、冻结基础模型,仅训练奖励头

若只需微调奖励头(而非整个基础模型),可冻结基础模型参数:

python运行
# 初始化奖励模型后,冻结基础模型参数
for param in reward_model.base_model.parameters():
    param.requires_grad = False  # 基础模型参数不更新
# 仅奖励头参数可训练(已默认开启)
for param in reward_model.reward_head.parameters():
    param.requires_grad = True
 
原理:Qwen2.5-3B 的 3B 参数中,奖励头仅占约 1%(隐藏层大小 ×1,如 768×1),冻结基础模型后,需要更新的参数从 3B→768,显存需求大幅降低。

四、优化批次大小与梯度累积

在保持总批次等效的前提下,进一步减小单步批次:

python运行 
# 1. 减小DataLoader的batch_size(最小为1)
train_loader = torch.utils.data.DataLoader(
    processed_dataset,
    batch_size=1,  # 已为最小值,无法再减小
    shuffle=True,
    pin_memory=True
)

# 2. 增大梯度累积步数(保持总批次=batch_size×gradient_accumulation_steps)
accelerator = Accelerator(
    mixed_precision="fp16",
    gradient_accumulation_steps=8  # 从4→8,总批次不变,但单步显存需求降低
)
 
原理:单步批次越小,单次前向 / 反向传播的显存占用越低,梯度累积通过多步累积梯度来保持总批次等效,平衡显存与训练稳定性。

五、启用 DeepSpeed ZeRO-3 分布式优化(多 GPU 场景)

若使用多 GPU 训练(如用户的 4 卡环境),通过 DeepSpeed 的 ZeRO-3 分片参数、梯度和优化器状态:

  1. 创建 ds_zero3_config.yaml:
train_batch_size: 8  # 总批次=batch_size×gradient_accumulation_steps×num_gpus
gradient_accumulation_steps: 8
optimizer:
  type: Adam
  params:
    lr: 5e-7
    betas: [0.9, 0.95]
    eps: 1e-8
zero_optimization:
  stage: 3  # 关键:分片模型参数、梯度、优化器状态
  offload_optimizer:
    device: cpu  # 优化器状态移至CPU(可选,进一步节省GPU显存)
  contiguous_gradients: true
  overlap_comm: true  # 通信与计算重叠
  1. 启动命令:
CUDA_VISIBLE_DEVICES=0,1,2,3 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
accelerate launch --num_processes=4 --config_file ds_zero3_config.yaml qwen_3B_train_0721_mix_checkpoint_8192.py
原理:ZeRO-3 将模型参数、梯度和优化器状态分片到多个 GPU/CPU,每个 GPU 仅保存部分数据,4 卡环境下理论显存占用可降至 1/4。

六、清理无效显存与减少碎片

在训练循环中主动释放未使用的张量和缓存:
# 在每个batch处理结束后添加
# 1. 手动删除不再需要的中间变量
del chosen_rewards, rejected_rewards, loss
# 2. 清理CUDA缓存(释放未使用的显存块)
torch.cuda.empty_cache()

原理:训练过程中,未显式删除的张量会占用显存,empty_cache()可释放已标记为 “未使用” 但未实际回收的显存,减少碎片。

七、使用更高效的精度(如 bf16)

若 GPU 支持 bfloat16(如 A100、H100),可替换 fp16 进一步节省显存:

python运行
accelerator = Accelerator(
    mixed_precision="bf16",  # 从fp16→bf16(显存占用相近,但数值范围更大,训练更稳定)
    gradient_accumulation_steps=8
)

# 模型加载时同步使用bf16
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # 与混合精度匹配
    trust_remote_code=True,
    device_map={"": accelerator.device},
    gradient_checkpointing=True
)
 
原理:bf16 与 fp16 显存占用相同(均为 2 字节),但无需像 fp16 那样处理数值范围限制,可减少因溢出导致的显存浪费。

优化效果对比

优化方法显存节省比例训练速度影响适用场景
缩短序列长度(8192→4096) ~50% 提升 数据允许缩短的场景
梯度检查点 ~40% 降低 20-30% 显存紧张但计算资源充足
冻结基础模型 ~90% 提升 仅需微调奖励头的场景
ZeRO-3(4 卡) ~75% 略降 多 GPU 分布式训练

建议优先实施缩短序列长度 + 梯度检查点 + ZeRO-3的组合方案,可在保证训练效果的前提下将显存占用降低 70% 以上。
 

image

 

image

 

 

posted on 2025-08-19 10:20  limingqi  阅读(95)  评论(0)    收藏  举报

导航