PyTorch & Lightning 模型参数命名与权重提取详解

当然可以!下面是一篇整理好的、适合发布到博客平台的 Markdown 格式文章:


PyTorch & Lightning 模型参数命名与权重提取详解

在使用 PyTorch 或 PyTorch Lightning 进行模型训练和保存时,我们经常会遇到模型参数(state_dict)中的命名问题。比如,在 Lightning 官方文档中,提到了如下用法:

encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}

很多同学会疑惑:

  • 这些 encoder.decoder. 前缀是 Lightning 的特殊行为吗?
  • 如果用原生 nn.Module,保存 checkpoint 时是不是要手动指定所有字典键名?

本文将为你详细解答这些问题。


1. 参数名前缀是 PyTorch 的标准行为

无论你用的是原生 PyTorch 还是 PyTorch Lightning,模型的 state_dict 在保存参数时,都会自动将子模块的名字作为前缀添加到参数名上。

例子说明

假设你有如下原生 PyTorch 代码:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)

class Autoencoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

autoencoder = Autoencoder(Encoder(), Decoder())
state_dict = autoencoder.state_dict()
print(state_dict.keys())

输出结果类似:

dict_keys([
    'encoder.linear.weight',
    'encoder.linear.bias',
    'decoder.linear.weight',
    'decoder.linear.bias'
])

可以看到,encoder.decoder. 前缀是PyTorch 自动加上的,用于区分参数属于哪个子模块。


2. Lightning 的行为与 PyTorch 一致

PyTorch Lightning 的 LightningModule 继承自 nn.Module,它的 state_dict 行为和原生 PyTorch 完全一致。
唯一的区别是,Lightning 在保存 checkpoint 时,会把 state_dict 放到 checkpoint 文件的 state_dict 字段下(而不是直接保存为一个 dict),这是 Lightning 的约定。


3. 原生 nn.Module 保存 checkpoint 时需要手动指定键名吗?

不需要!

只要你用 model.state_dict(),PyTorch 会自动生成带有子模块前缀的参数名。

保存模型参数:

torch.save(model.state_dict(), "model.pth")

加载模型参数:

model.load_state_dict(torch.load("model.pth"))

只有在你自己拼装参数字典(比如只存一部分参数)时,才需要手动处理键名。


4. Lightning 的 checkpoint 格式

Lightning 保存 checkpoint 时,通常会包含如下内容:

  • state_dict(模型参数,带前缀)
  • optimizer_states
  • epochglobal_step 等训练相关信息

你可以通过 torch.load 读取 checkpoint,然后取出 ['state_dict'] 字段:

checkpoint = torch.load(CKPT_PATH)
state_dict = checkpoint["state_dict"]

5. 总结

  • 参数名带前缀是 PyTorch 的标准行为,不是 Lightning 的特殊行为。
  • Lightning 只是把 state_dict 包装进 checkpoint 文件的一个字段里。
  • 无论是 Lightning 还是原生 nn.Module,子模块参数名都会自动带上前缀,无需手动指定。

如需更详细的代码示例或遇到具体保存/加载的问题,欢迎留言交流!

posted @ 2025-06-13 16:23  Gold_stein  阅读(112)  评论(0)    收藏  举报