torch.cuda.empty_cache()使用场景

在 PyTorch 训练中,torch.cuda.empty_cache() 的正确使用需要结合具体场景。以下是 5 种典型使用场景和最佳实践建议:


1. 训练循环中显存碎片化严重时
适用场景
• 当出现 CUDA out of memory 但计算显存需求理论上足够时

• 使用动态计算图或频繁创建/释放临时张量

推荐位置

for epoch in range(epochs):
    for batch in dataloader:
        # 前向/反向计算...
        optimizer.step()
        
        if batch_idx % 100 == 0:  # 每100个batch清理一次
            torch.cuda.empty_cache()  # 清理未使用的缓存

注意事项
• 过于频繁调用(如每个batch)会降低性能(约5-10%速度损失)

• 建议配合 torch.cuda.memory_summary() 诊断碎片化程度


2. 大模型分阶段加载时
适用场景
• 使用梯度检查点 (Gradient Checkpointing)

• 模型太大需要分段加载

示例代码

def forward_segment(segment, x):
    torch.cuda.empty_cache()  # 加载新段前清理
    segment = segment.to(device)
    return segment(x)

for segment in large_model.segments:
    output = checkpoint(forward_segment, segment, inputs)

3. 多任务交替执行时
适用场景
• 同一个脚本交替执行训练和评估

• 不同模型共享GPU资源

代码示例

# 训练阶段
train(model_A)
torch.cuda.empty_cache()  # 训练后立即清理

# 评估阶段
evaluate(model_B)  # 确保model_B能获得足够显存

4. 数据预处理与训练混合时
适用场景
• 使用GPU加速数据增强

• 动态生成训练数据

推荐写法

for epoch in epochs:
    # GPU数据增强
    augmented_batch = gpu_augment(batch)  
    
    # 训练主模型
    train_step(model, augmented_batch)
    
    # 清理增强操作的中间缓存
    del augmented_batch
    torch.cuda.empty_cache()

5. 异常恢复后
适用场景
• 捕获 CUDA OOM 异常后尝试恢复

• 测试最大可用batch size时

代码实现

try:
    large_batch = next(oversized_loader)
    output = model(large_batch)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        torch.cuda.empty_cache()  # 尝试释放残留显存
        reduced_batch = large_batch[:half_size]
        # 重试...

最佳实践总结

场景 调用频率 是否必需 典型性能影响
常规训练 每N个batch ❌ 可选 <5% 减速
大模型加载 每次分段前 ✔️ 必需 可避免OOM
多任务切换 任务边界 ✔️ 推荐 可复用显存
异常恢复 按需 ✔️ 关键 恢复成功率+50%
调试阶段 任意位置 ❌ 避免 干扰内存分析

高级技巧

  1. 与内存分析工具配合:

    print(torch.cuda.memory_summary())  # 清理前
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())  # 清理后
    
  2. PyTorch Lightning 集成:

    class MyModel(LightningModule):
        def on_train_batch_end(self):
            if self.current_epoch % 10 == 0:
                torch.cuda.empty_cache()
    
  3. 显存碎片化监控:

    def check_fragmentation():
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        if reserved - allocated > 1e9:  # 碎片>1GB
            torch.cuda.empty_cache()
    

何时应该避免调用

  1. 在关键性能路径上:如高频调用的损失函数内
  2. 使用 torch.no_grad() 块时:此时无梯度缓存需要清理
  3. 确定无显存泄漏时:过度调用会导致不必要的同步点

合理使用此方法可将GPU利用率提升15-30%(特别是在大模型训练中),但需要结合具体场景权衡性能与显存占用的平衡。

posted @ 2025-04-30 09:52  学点技术儿  阅读(250)  评论(0)    收藏  举报