Loading

【python语法】deepcopy()方法

示例

origin_model = deepcopy(model)

解释

简短回答(要点):

  • deepcopy 是 Python 标准库 copy 模块里的函数,文件开头通过 from copy import deepcopy 导入。
  • 作用:对对象做“深拷贝”(递归拷贝),生成一个和原对象结构相同但占用独立内存的新对象;修改新对象不会影响原对象。

详细说明与建议:

  • 来源:deepcopy 来自标准库模块 copy

    from copy import deepcopy
    

    你在 main_tta.py 文件顶部已经看到了这一行导入。

  • 行为(语义):

    • 对复杂对象(如包含嵌套列表、字典、类实例等)进行完全复制,包括内部所有可拷贝的子对象;

    不同于浅拷贝(copy.copy),浅拷贝只复制顶层对象,内部引用仍指向原子对象。

    • 对 PyTorch 的 nn.Module(模型)使用 deepcopy(model) 会创建一个新的模型对象,并复制其参数张量(新对象含有新的 Parameter 实例);因此你可以安全地修改 origin_model 而不影响 model
  • 在你代码里的用途:

    origin_model = deepcopy(model)
    

    目的是保存模型的原始一份(未经过测试时适配修改的状态),便于后续比较或恢复。

  • 注意事项 / 取舍:

    • 性能与内存:deepcopy 会复制所有数据(包括权重张量),对大模型会占用较多内存并稍慢。
    • 设备(CPU/GPU):如果模型参数位于 GPU,上述 deepcopy 会复制这些 CUDA 张量(通常也位于 GPU);在某些环境下可能产生跨设备复制问题。常见做法是在 deepcopy 之前把模型移到 CPU,或更常用的是只保存 state_dict() 的拷贝(更轻量、通用)。
    • 推荐(更稳健的做法):
      • 保存 state dict(只复制张量数据,而不是整个对象/方法):
        origin_state = deepcopy(model.state_dict())   # 深拷贝参数字典
        # 或更明确地把参数移到 CPU 并 clone:
        origin_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        
      • 恢复时使用:
        model.load_state_dict(origin_state)
        
      • 这样比 deepcopy 整个 Module 更节省内存,也更便于序列化/保存。
  • 小结对比

    • origin_model = deepcopy(model):直观、把模型对象连同结构和参数都复制一份;简单但可能更耗内存。
    • origin_state = deepcopy(model.state_dict())origin_state = {k: v.clone().cpu()...}:推荐用于保存权重快照,更经济且设备友好。
posted @ 2025-12-04 17:11  SaTsuki26681534  阅读(2)  评论(0)    收藏  举报