PyTorch & Lightning 模型参数命名与权重提取详解
当然可以!下面是一篇整理好的、适合发布到博客平台的 Markdown 格式文章:
PyTorch & Lightning 模型参数命名与权重提取详解
在使用 PyTorch 或 PyTorch Lightning 进行模型训练和保存时,我们经常会遇到模型参数(state_dict)中的命名问题。比如,在 Lightning 官方文档中,提到了如下用法:
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
很多同学会疑惑:
- 这些 encoder.、decoder.前缀是 Lightning 的特殊行为吗?
- 如果用原生 nn.Module,保存 checkpoint 时是不是要手动指定所有字典键名?
本文将为你详细解答这些问题。
1. 参数名前缀是 PyTorch 的标准行为
无论你用的是原生 PyTorch 还是 PyTorch Lightning,模型的 state_dict 在保存参数时,都会自动将子模块的名字作为前缀添加到参数名上。
例子说明
假设你有如下原生 PyTorch 代码:
import torch
import torch.nn as nn
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
class Autoencoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
autoencoder = Autoencoder(Encoder(), Decoder())
state_dict = autoencoder.state_dict()
print(state_dict.keys())
输出结果类似:
dict_keys([
    'encoder.linear.weight',
    'encoder.linear.bias',
    'decoder.linear.weight',
    'decoder.linear.bias'
])
可以看到,encoder. 和 decoder. 前缀是PyTorch 自动加上的,用于区分参数属于哪个子模块。
2. Lightning 的行为与 PyTorch 一致
PyTorch Lightning 的 LightningModule 继承自 nn.Module,它的 state_dict 行为和原生 PyTorch 完全一致。
唯一的区别是,Lightning 在保存 checkpoint 时,会把 state_dict 放到 checkpoint 文件的 state_dict 字段下(而不是直接保存为一个 dict),这是 Lightning 的约定。
3. 原生 nn.Module 保存 checkpoint 时需要手动指定键名吗?
不需要!
只要你用 model.state_dict(),PyTorch 会自动生成带有子模块前缀的参数名。
保存模型参数:
torch.save(model.state_dict(), "model.pth")
加载模型参数:
model.load_state_dict(torch.load("model.pth"))
只有在你自己拼装参数字典(比如只存一部分参数)时,才需要手动处理键名。
4. Lightning 的 checkpoint 格式
Lightning 保存 checkpoint 时,通常会包含如下内容:
- state_dict(模型参数,带前缀)
- optimizer_states
- epoch、- global_step等训练相关信息
你可以通过 torch.load 读取 checkpoint,然后取出 ['state_dict'] 字段:
checkpoint = torch.load(CKPT_PATH)
state_dict = checkpoint["state_dict"]
5. 总结
- 参数名带前缀是 PyTorch 的标准行为,不是 Lightning 的特殊行为。
- Lightning 只是把 state_dict包装进 checkpoint 文件的一个字段里。
- 无论是 Lightning 还是原生 nn.Module,子模块参数名都会自动带上前缀,无需手动指定。
如需更详细的代码示例或遇到具体保存/加载的问题,欢迎留言交流!

 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号