大模型-大模型权重文件的加载-62

safetensor 权重文件的加载 详细过程

深入理解大模型 Safetensors 权重文件的加载过程
作为一名大模型算法工程师,你肯定知道权重文件对于模型的运行至关重要。传统的 PyTorch 模型权重通常保存为 .bin 或 .pth 格式,这些格式在加载时存在一些安全隐患和效率问题。而 Safetensors 文件的出现就是为了解决这些痛点。

Safetensors 是一种序列化格式,专为安全、快速地加载深度学习模型权重而设计。它主要解决了以下几个问题:

安全性:传统 PyTorch 文件可能包含任意代码执行的风险,而 Safetensors 严格限制了序列化的内容,只允许张量数据,从而避免了恶意代码的注入。
速度:Safetensors 通过内存映射(memory-mapping)的方式加载文件,这意味着它不会将整个文件加载到内存中,而是直接从磁盘读取所需的部分。这显著提高了加载速度,尤其对于大型模型。
内存效率:由于采用内存映射,Safetensors 在加载时占用的内存更少。
兼容性:Safetensors 可以被 PyTorch、TensorFlow、JAX 等主流深度学习框架原生支持。
Safetensors 文件的内部结构
Safetensors 文件本质上由两部分组成:

JSON 头部(JSON Header):这是一个小的 JSON 字典,包含了文件中所有张量的元数据(metadata)。这些元数据包括张量的名称、数据类型(dtype)、形状(shape)以及在文件中的偏移量(data_offsets)。
张量数据(Tensor Data):紧跟在 JSON 头部之后的是原始的二进制张量数据。这些数据是连续存储的,没有额外的填充或分隔符。
Safetensors 权重文件的加载过程详解
当一个大模型从 Safetensors 文件中加载权重时,会经历以下详细步骤:

打开文件并读取 JSON 头部:

加载器首先以二进制读取模式打开 .safetensors 文件。
它会读取文件末尾的 8 字节,这 8 字节指示了 JSON 头部的大小。
根据这个大小,加载器从文件的相应位置读取整个 JSON 头部。
将读取到的 JSON 数据解析为一个 Python 字典,其中包含了所有张量的元信息。
内存映射文件(Memory Mapping):

这是 Safetensors 实现高效加载的关键。加载器会使用操作系统的内存映射功能(例如 Linux 上的 mmap)将整个 Safetensors 文件映射到进程的虚拟地址空间中。
这意味着文件内容在内存中是可见的,但实际上并没有完全加载到 RAM 中。只有当程序访问某个特定地址时,操作系统才会将对应的文件块加载到物理内存中。
迭代张量元数据并按需加载:

加载器会遍历 JSON 头部中存储的每一个张量元数据。
对于每个张量,它会获取其名称、数据类型、形状以及最重要的 数据偏移量(data_offsets)。
data_offsets 是一对整数,表示该张量数据在 Safetensors 文件中起始和结束的字节偏移量。
按需加载(Lazy Loading):当需要访问某个特定的张量时(例如,当模型的前向传播需要该层的权重时),加载器会:
利用内存映射,直接跳转到该张量在文件中的偏移量。
根据张量的形状和数据类型,从内存映射区域中读取相应的字节数据。
将这些字节数据转换为对应的张量对象(例如 PyTorch 的 torch.Tensor)。
由于是内存映射,这个过程非常快,并且避免了一次性加载所有权重,从而节省了大量内存。
张量校验(Optional):

一些实现可能会在加载张量后进行可选的校验,例如检查张量的数据类型或形状是否与预期一致,以确保数据的完整性。
构建模型状态字典(State Dictionary):

加载的每个张量(及其对应的名称)会被存储在一个字典中,通常称为 state_dict。
这个 state_dict 随后会被加载到模型的对应层中,完成模型的权重初始化。
示例代码(使用 Hugging Face safetensors 库)
在实际开发中,我们通常会使用像 Hugging Face 的 safetensors 库来处理 Safetensors 文件,它封装了上述加载逻辑,使得使用非常简便。

from safetensors import safe_open
import torch

# 假设你有一个名为 "model.safetensors" 的权重文件
file_path = "model.safetensors"

# 1. 打开 Safetensors 文件
# "safetensors.safe_open" 会处理底层的内存映射和 JSON 头部读取
with safe_open(file_path, framework="pt", device="cpu") as f:
    # 2. 获取所有张量的名称
    tensor_names = f.keys()
    print(f"文件中包含的张量名称:{tensor_names}")

    # 3. 按需加载特定张量
    # 这里的 "model.embed_tokens.weight" 只是一个示例张量名称
    # 实际名称取决于你的模型结构
    if "model.embed_tokens.weight" in tensor_names:
        embedding_weight = f.get_tensor("model.embed_tokens.weight")
        print(f"嵌入层权重形状:{embedding_weight.shape}")
        print(f"嵌入层权重数据类型:{embedding_weight.dtype}")
    else:
        print("未找到 'model.embed_tokens.weight' 张量。")

    # 4. 加载所有张量到 PyTorch 状态字典
    state_dict = {}
    for k in tensor_names:
        state_dict[k] = f.get_tensor(k)
    print("所有张量已加载到 state_dict 中。")

# 接下来,你可以将这个 state_dict 加载到你的 PyTorch 模型中
# 例如:model.load_state_dict(state_dict)

模型的权重中 state_dict 是什么

在 PyTorch 模型中,state_dict 是一个 Python 字典(dict),它存储了模型中所有可学习参数的名称和对应的 Tensor 值。你可以把它看作是模型所有权重的快照。

更具体地说,state_dict 包含以下信息:

键 (Key):每个键都是一个字符串,表示模型中一个特定层的名称。这些名称遵循一定的命名约定,例如 conv1.weight 表示第一层卷积层的权重,fc2.bias 表示第二层全连接层的偏置。
值 (Value):每个值都是一个 torch.Tensor 对象,包含了相应层的权重或偏置的值。
state_dict 的作用
state_dict 在 PyTorch 中扮演着至关重要的角色,主要体现在以下几个方面:

保存模型权重 (Saving Model Weights):

当你训练好一个模型后,你需要将它的权重保存到磁盘,以便后续加载和使用。
state_dict 提供了这种能力。你可以使用 torch.save(model.state_dict(), filepath) 将模型的 state_dict 保存到文件中。
对于大型模型,通常会使用 Safetensors 格式来保存 state_dict,因为它更安全、更高效。
加载模型权重 (Loading Model Weights):

当你需要加载预训练模型或之前保存的模型权重时,你需要使用 model.load_state_dict(state_dict) 将 state_dict 中的权重加载到模型中。
如果使用 Safetensors 文件,你需要先使用 safetensors 库加载权重到 state_dict,然后再加载到模型。
模型参数的访问和修改 (Accessing and Modifying Model Parameters):

你可以通过 model.state_dict() 访问模型的所有参数。
这允许你检查模型的权重,或者在某些情况下,手动修改它们。
模型迁移学习 (Transfer Learning):

在迁移学习中,你通常会加载预训练模型的 state_dict,然后只训练模型的部分层。
state_dict 使得你可以方便地加载和更新模型的特定部分的权重。
state_dict 与 Safetensors
虽然 state_dict 是 PyTorch 中表示模型权重的标准方式,但它本身只是一种内存中的数据结构。当你需要将 state_dict 保存到磁盘时,可以选择不同的文件格式。

传统的 PyTorch 模型权重通常保存为 .pth 或 .bin 格式。
Safetensors 是一种更安全、更高效的权重文件格式,它专门设计用于存储 state_dict。
Safetensors 文件包含一个 JSON 头部,描述了每个张量的元数据(名称、形状、数据类型等),以及实际的张量数据。
使用 safetensors 库可以方便地加载和保存 state_dict 到 Safetensors 文件。

import torch
from safetensors import safe_open

# 假设你有一个 PyTorch 模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 20)
        self.linear2 = torch.nn.Linear(20, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = MyModel()

# 1. 获取模型的 state_dict
state_dict = model.state_dict()
print("模型的 state_dict:", state_dict.keys())  # 打印 state_dict 中的键

# 2. 保存 state_dict 到 Safetensors 文件
torch.save(state_dict, "model.pth")  # 保存为 .pth 文件 (不推荐用于大型模型)
# 使用 safetensors 库保存 (推荐)
# from safetensors.torch import save_model, load_model
# save_model(model, "model.safetensors")

# 3. 从 Safetensors 文件加载 state_dict
# loaded_state_dict = torch.load("model.pth")  # 加载 .pth 文件
# 使用 safetensors 库加载 (推荐)
# loaded_state_dict = load_model("model.safetensors")
# model.load_state_dict(loaded_state_dict)  # 将加载的 state_dict 应用于模型

posted @ 2025-06-19 00:10  jack-chen666  阅读(110)  评论(0)    收藏  举报