.\models\swinv2\__init__.py
# 版权声明和许可证信息,指明代码版权归 HuggingFace Team 所有,使用 Apache License, Version 2.0 进行许可
# 如果不符合许可证要求,不能使用此文件中的代码
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入必要的类型检查模块
from typing import TYPE_CHECKING
# 导入必要的自定义异常和模块
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
# 定义导入结构,包含需要导入的模块和对象
_import_structure = {
"configuration_swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"],
}
# 检查是否存在 torch 库,如果不存在则抛出自定义异常
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
# 如果 torch 存在,添加额外的模块到导入结构中
_import_structure["modeling_swinv2"] = [
"SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Swinv2ForImageClassification",
"Swinv2ForMaskedImageModeling",
"Swinv2Model",
"Swinv2PreTrainedModel",
"Swinv2Backbone",
]
# 如果是类型检查阶段
if TYPE_CHECKING:
# 导入配置和模型相关的对象,用于类型检查
from .configuration_swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config
# 再次检查 torch 库是否存在,如果不存在则抛出自定义异常
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
# 导入模型相关的对象,用于类型检查
from .modeling_swinv2 import (
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Swinv2Backbone,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Model,
Swinv2PreTrainedModel,
)
# 如果不是类型检查阶段
else:
import sys
# 使用 LazyModule 模式加载模块,延迟导入相关对象
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2022, Google and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Switch Transformers model configuration"""
from ...configuration_utils import PretrainedConfig # 导入预训练配置类
from ...utils import logging # 导入日志模块
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/switch-base-8": "https://huggingface.co/google/switch-base-8/blob/main/config.json",
}
class SwitchTransformersConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to
instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the
SwitchTransformers [google/switch-base-8](https://huggingface.co/google/switch-base
):
# 初始化 Transformer 参数
self.vocab_size = vocab_size # 设置词汇表大小
self.d_model = d_model # 设置 Transformer 模型的维度大小
self.d_kv = d_kv # 设置键和值的维度大小
self.d_ff = d_ff # 设置前馈网络的隐藏层大小
self.num_sparse_encoder_layers = num_sparse_encoder_layers # 编码器中稀疏层的数量
self.num_layers = num_layers # 总层数
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # 解码器层数,默认与总层数相同
self.num_sparse_decoder_layers = num_sparse_decoder_layers # 解码器中稀疏层的数量
# 每隔多少层设置一个稀疏层,用于编码器
if self.num_sparse_encoder_layers > 0:
self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers
else:
self.encoder_sparse_step = self.num_layers # 如果没有稀疏层,则步长为总层数,这会创建0个稀疏层
# 每隔多少层设置一个稀疏层,用于解码器
if self.num_sparse_decoder_layers > 0:
self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers
else:
self.decoder_sparse_step = self.num_decoder_layers # 如果没有稀疏层,则步长为总层数,这会创建0个稀疏层
self.num_heads = num_heads # 设置注意力头的数量
self.num_experts = num_experts # 设置专家的数量
self.expert_capacity = expert_capacity # 设置每个专家的容量
self.router_bias = router_bias # 设置路由器偏置
self.router_jitter_noise = router_jitter_noise # 设置路由器抖动噪声
if router_dtype not in ["float32", "float16", "bfloat16"]:
raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}")
self.router_dtype = router_dtype # 设置路由器数据类型
self.router_ignore_padding_tokens = router_ignore_padding_tokens # 是否忽略填充标记的路由
self.relative_attention_num_buckets = relative_attention_num_buckets # 相对注意力的桶数量
self.relative_attention_max_distance = relative_attention_max_distance # 相对注意力的最大距离
self.dropout_rate = dropout_rate # 设置丢弃率
self.layer_norm_epsilon = layer_norm_epsilon # 层归一化的 epsilon 值
self.initializer_factor = initializer_factor # 初始化因子
self.use_cache = use_cache # 是否使用缓存
self.add_router_probs = add_router_probs # 是否添加路由概率
self.router_z_loss_coef = router_z_loss_coef # 路由 Z 损失系数
self.router_aux_loss_coef = router_aux_loss_coef # 路由辅助损失系数
self.dense_act_fn = dense_act_fn # 密集层的激活函数
super().__init__(
pad_token_id=pad_token_id, # 填充标记 ID
eos_token_id=eos_token_id, # 终止标记 ID
is_encoder_decoder=is_encoder_decoder, # 是否是编码解码器
**kwargs, # 其它关键字参数
)
# 导入必要的库
import argparse # 命令行参数解析库
import json # JSON 数据处理库
import os # 系统操作库
import tensorstore as ts # TensorStore 库
import torch # PyTorch 深度学习库
from flax import serialization # Flax 序列化库
from flax.traverse_util import flatten_dict, unflatten_dict # Flax 的字典扁平化和反扁平化工具
from tensorflow.io import gfile # TensorFlow 文件操作库
from transformers.modeling_utils import dtype_byte_size # 计算数据类型字节大小的工具函数
from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
rename_keys, # 从 Switch Transformers 原始 Flax 检查点转换到 PyTorch 的键重命名函数
)
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME # Transformers 模型权重文件索引名称和通用权重名称
from transformers.utils.hub import convert_file_size_to_int # 将文件大小转换为整数的函数
def rename_base_flax_keys(flax_key_tuple, flax_tensor):
"""
对基本 JAX 键进行重命名以适配 PyTorch。
"""
if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3:
# 对专家层的特定处理
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = torch.permute(flax_tensor, (0, 2, 1))
elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple):
# 对线性层的特定处理
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = flax_tensor.T
elif flax_key_tuple[-1] in ["scale", "embedding"]:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
return flax_key_tuple, flax_tensor
def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path):
"""
获取键和 TensorStore 字典。
"""
if "metadata" in layer:
split_layer = layer.split("metadata")
curr_real_layer_name = "".join(split_layer[0])[:-1]
split_layer = [tuple(("metadata" + split_layer[1]).split("/"))]
elif "kvstore" in layer:
split_layer = layer.split("kvstore")
curr_real_layer_name = "".join(split_layer[0])[:-1]
split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))]
else:
split_layer = layer.split("/")
curr_real_layer_name = "/".join(split_layer[:-1])
split_layer[-1] = (split_layer[-1],)
if "kvstore/path" in layer:
content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}"
elif "kvstore/driver" in layer:
content = "file"
else:
content = checkpoint_info[layer]
return curr_real_layer_name, split_layer, content
def rename_and_save_block(current_block, save_path):
"""
重命名当前块的键并保存。
"""
current_block = rename_keys(current_block)
new_current_block = {}
for k, v in current_block.items():
new_current_block[k.replace("/", ".")] = v
current_block = new_current_block
torch.save(current_block, save_path)
def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME):
"""
动态分片检查点文件。
"""
max_shard_size = convert_file_size_to_int(max_shard_size) # 将最大分片大小转换为整数
sharded_state_dicts = [] # 存储分片后的状态字典列表
current_block = {} # 当前块的状态字典
current_block_size = 0 # 当前块的大小
total_size = 0 # 总共的大小
os.makedirs(dump_path, exist_ok=True) # 确保转储路径存在,不存在则创建
# 从检查点文件中恢复信息并扁平化
with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp:
checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"]
checkpoint_info = flatten_dict(checkpoint_info, sep="/")
all_layers = {} # 所有层的字典,用于存储层信息
# 遍历检查点信息中的每个层名称
for layer in checkpoint_info.keys():
# 获取真实的层名称、分割后的层名称及内容,通过函数获取
curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(
layer, checkpoint_info, switch_checkpoint_path
)
# 如果当前真实层名称已经存在于所有层的字典中
if curr_real_layer_name in all_layers:
# 将内容存入已有的真实层名称对应的字典中的分割层中的最后一部分
all_layers[curr_real_layer_name][split_layer[-1]] = content
else:
# 创建新的真实层名称键,并存入内容
all_layers[curr_real_layer_name] = {split_layer[-1]: content}
# 遍历所有层的键
for key in all_layers.keys():
# 使用 tensorstore 打开未展开的字典格式的所有层的数据
raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
# 将原始权重数据转换为 PyTorch 的张量格式
raw_weights = torch.tensor(raw_weights)
# 计算权重张量的字节大小
weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
# 使用小型转换脚本中的重命名模式对键和原始权重进行重命名
key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
# 重新连接重命名后的键
key = "/".join(key)
# 如果当前块的大小加上权重大小超过了最大碎片大小
if current_block_size + weight_size > max_shard_size:
# 构建保存路径,包含碎片编号
save_path = os.path.join(
dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")
)
# 重命名并保存当前块
rename_and_save_block(current_block, save_path)
# 添加当前块的键到碎片状态字典中
sharded_state_dicts.append(current_block.keys())
# 删除当前块
del current_block
# 重新创建空的当前块和当前块大小
current_block = {}
current_block_size = 0
# 将处理后的原始权重数据添加到当前块中,转换为指定的数据类型
current_block[key] = raw_weights.to(getattr(torch, dtype))
# 更新当前块大小
current_block_size += weight_size
# 更新总大小
total_size += weight_size
# 添加最后一个块
save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin"))
rename_and_save_block(current_block, save_path)
sharded_state_dicts.append(current_block.keys())
# 如果只有一个碎片,直接返回
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None
# 否则,构建索引
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
# 构建每个碎片文件的名称,包含碎片编号和总碎片数
shard_file = weights_name.replace(
".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
)
# 临时文件名,用于重命名到最终的碎片文件
temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
# 实际重命名文件到最终的碎片文件
os.rename(temp_filename, os.path.join(dump_path, shard_file))
# 记录每个碎片文件对应的碎片状态字典
shards[shard_file] = shard
# 遍历每个碎片的键
for key in shard:
# 记录每个键对应的碎片文件名称
weight_map[key] = shard_file
# 添加元数据
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
# 将索引写入文件
with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
# 返回元数据和索引
return metadata, index
if __name__ == "__main__":
# 创建一个参数解析器对象
parser = argparse.ArgumentParser()
# 添加必需的参数
parser.add_argument(
"--switch_t5x_checkpoint_path",
default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",
type=str,
required=False,
help="Path to a directory containing a folder per layer. Follows the original Google format.",
)
# 添加可选参数 max_shard_size,用于指定最大分片大小,默认为 "10GB"
parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size")
# 添加可选参数 dtype,用于指定保存模型的数据类型,默认为 "bfloat16"
parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model")
# 添加可选参数 pytorch_dump_folder_path,用于指定 PyTorch 模型输出的路径
parser.add_argument(
"--pytorch_dump_folder_path",
default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted",
type=str,
required=False,
help="Path to the output pytorch model.",
)
# 解析命令行参数并存储到 args 对象中
args = parser.parse_args()
# 调用 shard_on_the_fly 函数,传递解析后的参数进行处理
shard_on_the_fly(
args.switch_t5x_checkpoint_path,
args.pytorch_dump_folder_path,
args.max_shard_size,
args.dtype,
)
def sanity_check():
# 导入所需的类和函数
from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, T5Tokenizer
# 加载 Switch 模型的配置文件
config = SwitchTransformersConfig.from_pretrained("google/switch-base-8")
# 将配置保存到指定路径
config.save_pretrained("/home/arthur_huggingface_co/transformers/switch_converted")
# 加载转换后的 Switch 模型
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"/home/arthur_huggingface_co/transformers/switch_converted", device_map="auto"
)
# 加载 T5Tokenizer,用于处理文本输入
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
# 指定一个文本输入
text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
# 使用 tokenizer 对文本进行编码,生成输入的 token ids
input_ids = tokenizer(text, return_tensors="pt").input_ids
# 使用模型生成输出
out = model.generate(input_ids, decoder_start_token_id=0)
# 解码输出并打印结果
print(tokenizer.decode(out[0]))
# 设置文件编码为 UTF-8
# 版权声明,声明代码版权归 The HuggingFace Inc. team 所有
#
# 根据 Apache 许可证 2.0 版本,使用本文件需要遵循许可证的规定
# 详细信息请参考 http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律另有规定或书面同意,本软件是基于“按原样提供”的基础分发的,不提供任何明示或暗示的担保或条件。
# 有关详细信息,请参阅许可证条款。
"""将 SwitchTransformersX 仓库的检查点转换为 JAX/FLAX 模型。"""
import argparse # 导入用于解析命令行参数的模块
import re # 导入正则表达式模块
from flax.traverse_util import flatten_dict, unflatten_dict # 导入用于扁平化和反扁平化字典的工具函数
from t5x import checkpoints # 导入 SwitchTransformersX 仓库的检查点处理模块
from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration # 导入 Switch Transformers 相关模型配置和生成模型类
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model # 导入用于加载 FLAX 权重到 PyTorch 模型的工具函数
from transformers.utils import logging # 导入日志记录模块
logging.set_verbosity_info() # 设置日志输出级别为 INFO
# 应该不包括由 `from_pt` 参数已经完成的内容
# 定义从原始模型到 Switch Transformers 的层名称映射字典
MOE_LAYER_NAME_MAPPING = {
"/attention/": "/0/SelfAttention/",
"/self_attention/": "/0/SelfAttention/",
"/encoder_decoder_attention/": "/1/EncDecAttention/",
"value": "v",
"query": "q",
"key": "k",
"out": "o",
"pre_self_attention_layer_norm": "0/layer_norm",
"pre_cross_attention_layer_norm": "1/layer_norm",
"pre_attention_layer_norm": "0/layer_norm", # 先前为 1,但似乎是错误的
"token_embedder": "shared",
"encoder_norm": "final_layer_norm",
"decoder_norm": "final_layer_norm",
"relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight",
"router/router_weights/w/": "router/classifier/",
"roer/roer_weights/w/": "router/classifier/",
"logits_dense": "lm_head",
}
def rename_keys(s_dict):
# 在 HF T5 中,我们有 block.{x}.layer.{y}. 对应于原始模型中的 layer.{x}
# 返回字典 s_dict 的键列表
keys = list(s_dict.keys())
# 1. Convert keys based on specified patterns
for key in keys:
# Define pattern to match and transform "layers_<number>" to "block/<number>/layer"
layer_to_block_of_layer = r".*/layers_(\d+)"
new_key = key
if re.match(layer_to_block_of_layer, key):
new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key)
# Define pattern to match and transform "encoder/" or "decoder/" paths
layer_to_block_of_layer = r"(encoder|decoder)\/"
if re.match(layer_to_block_of_layer, key):
groups = re.match(layer_to_block_of_layer, new_key).groups()
if groups[0] == "encoder":
new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key)
new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key)
elif groups[0] == "decoder":
new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key)
new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key)
# 2. Convert keys using predefined mapping dictionary MOE_LAYER_NAME_MAPPING
for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items():
if old_key in new_key:
new_key = new_key.replace(old_key, temp_key)
# Print the transformation from original key to new key
print(f"{key} -> {new_key}")
# Replace the original key in the dictionary with the transformed new_key
s_dict[new_key] = s_dict.pop(key)
# Adjust specific entries in the dictionary based on their keys
if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict:
s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[
"encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"
].T
if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict:
s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[
"decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"
].T
# 3. Handle keys containing "expert" separately
for key in list(s_dict.keys()):
if "expert" in key:
# Extract the number of experts and their weights
num_experts = s_dict[key].shape[0]
expert_weights = s_dict[key]
# Iterate over each expert, renaming and adding to the dictionary
for idx in range(num_experts):
s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weights[idx]
print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}")
# Remove the original "expert" key from the dictionary
s_dict.pop(key)
# Return the modified dictionary
return s_dict
# GIN_TO_CONFIG_MAPPING 定义了从 GIN 配置参数到 SwitchTransformersConfig 参数的映射关系
GIN_TO_CONFIG_MAPPING = {
"NUM_ENCODER_LAYERS": "num_layers",
"NUM_DECODER_LAYERS": "num_decoder_layers",
"NUM_HEADS": "num_heads",
"HEAD_DIM": "d_kv",
"EMBED_DIM": "d_model",
"MLP_DIM": "d_ff",
"NUM_SELECTED_EXPERTS": "num_selected_experts",
"NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers",
"NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers",
"dense.MlpBlock.activations": "feed_forward_proj",
}
def convert_gin_to_config(gin_file, num_experts):
# 将 Google 风格的配置文件转换为 Hugging Face 格式的配置
import regex as re
# 从文件中读取 GIN 配置内容
with open(gin_file, "r") as f:
raw_gin = f.read()
# 使用正则表达式匹配参数和值
regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin)
args = {}
for param, value in regex_match:
# 根据预定义的映射将参数名转换为 SwitchTransformersConfig 的参数名,并将值转换为相应类型
if param in GIN_TO_CONFIG_MAPPING and value != "":
args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value)
# 提取激活函数类型,并添加到参数字典中
activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0]
args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1])
# 添加 num_experts 参数到参数字典中
args["num_experts"] = num_experts
# 使用参数创建 SwitchTransformersConfig 对象
config = SwitchTransformersConfig(**args)
return config
def convert_flax_checkpoint_to_pytorch(
flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8
):
# 初始化 PyTorch 模型
# 打印正在加载的 flax 权重路径
print(f"Loading flax weights from : {flax_checkpoint_path}")
# 加载 flax 模型的参数
flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path)
if gin_file is not None:
# 如果提供了 gin 文件,则根据 gin 文件和 num_experts 转换为 SwitchTransformersConfig 对象
config = convert_gin_to_config(gin_file, num_experts)
else:
# 否则根据 config_file 创建 SwitchTransformersConfig 对象
config = SwitchTransformersConfig.from_pretrained(config_file)
# 使用配置文件创建 SwitchTransformersForConditionalGeneration 模型
pt_model = SwitchTransformersForConditionalGeneration(config)
# 将 flax 参数扁平化,重命名键名后再还原为字典
flax_params = flax_params["target"]
flax_params = flatten_dict(flax_params, sep="/")
flax_params = rename_keys(flax_params)
flax_params = unflatten_dict(flax_params, sep="/")
# 加载 flax 参数到 PyTorch 模型中
load_flax_weights_in_pytorch_model(pt_model, flax_params)
# 打印保存 PyTorch 模型的路径
print(f"Save PyTorch model to {pytorch_dump_path}")
pt_model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# 必选参数
parser.add_argument(
"--switch_t5x_checkpoint_path",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the"
" model architecture. If not provided, a `gin_file` has to be provided."
),
)
# 可选参数
parser.add_argument(
"--gin_file",
default=None,
type=str,
required=False,
help="Path to the gin config file. If not provided, a `config_file` has to be passed ",
)
parser.add_argument(
"--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model."
)
parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts")
args = parser.parse_args()
convert_flax_checkpoint_to_pytorch(
args.switch_t5x_checkpoint_path,
args.config_name,
args.gin_file,
args.pytorch_dump_folder_path,
args.num_experts,
)
# 添加一个命令行参数,指定输出 PyTorch 模型的路径,参数为必填项
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model."
)
# 添加一个命令行参数,指定专家数量,默认为 8,非必填项
parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts")
# 解析命令行参数并将其存储在 args 变量中
args = parser.parse_args()
# 调用函数 convert_flax_checkpoint_to_pytorch,将 Flax 模型转换为 PyTorch 模型
convert_flax_checkpoint_to_pytorch(
args.switch_t5x_checkpoint_path, # Flax 模型的路径
args.config_name, # 配置名称
args.gin_file, # GIN 文件路径
args.pytorch_dump_folder_path, # 输出的 PyTorch 模型路径
args.num_experts, # 专家数量
)
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
r"""
计算负载平衡损失函数。
负载平衡损失函数用于Switch Transformers中,旨在通过调整专家分配概率来实现负载平衡,以优化模型性能。
Args:
router_probs (torch.Tensor):
形状为 [batch_size, sequence_length, num_experts] 的输入概率张量,表示每个位置选择每个专家的概率。
expert_indices (torch.Tensor):
形状为 [batch_size, sequence_length] 的整数张量,表示每个位置选择的专家索引。
Returns:
float:
标量,表示计算得到的负载平衡损失值。
"""
num_groups, tokens_per_group, _ = router_probs.shape
# 计算对数概率的和,对应于每个位置的专家选择概率
log_z = torch.logsumexp(router_probs, dim=-1)
# 计算负载平衡损失,以提高模型的稳定性
balancing_loss = log_z**2
# 返回平均负载平衡损失
return torch.sum(balancing_loss) / (num_groups * tokens_per_group)
# 计算辅助负载平衡损失,类似于Switch Transformer中的实现,使用PyTorch实现。
# 查看Switch Transformer论文(https://arxiv.org/abs/2101.03961)以获取更多细节。
# 此函数实现论文中第4到第6方程中的损失函数,旨在惩罚专家之间路由过于不平衡的情况。
# 参数:
# router_probs (`torch.Tensor`):
# 每个令牌分配给每个专家的概率。形状为 [batch_size, seqeunce_length, num_experts]。
# expert_indices (`torch.Tensor`):
# 形状为 [batch_size, seqeunce_length] 的索引张量,用于标识每个令牌选择的专家。
# 返回:
# 辅助损失值。
num_experts = router_probs.shape[-1]
# 将专家索引转换为int64类型,否则独热编码将失败
if expert_indices.dtype != torch.int64:
expert_indices = expert_indices.to(torch.int64)
# 如果专家索引张量的维度为2,则添加一个维度以匹配独热编码的要求
if len(expert_indices.shape) == 2:
expert_indices = expert_indices.unsqueeze(2)
# 创建独热编码,标识每个令牌是否分配给特定专家
expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
# 对于每个令牌,确定其是否被路由到某个专家
expert_mask = torch.max(expert_mask, axis=-2).values
# 将独热编码张量转换为float32类型,否则计算平均值时会失败
expert_mask = expert_mask.to(torch.float32)
# 计算每个组和专家的令牌比例,用于平均计算
tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
# 计算每个组和专家的路由概率,用于平均计算
router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
# 返回平均令牌比例和路由概率乘积的平均值,乘以专家数量的平方
return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
class SwitchTransformersTop1Router(nn.Module):
"""
Router using tokens choose top-1 experts assignment.
This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
token is processed by an expert**, or that each expert receives at least one token.
"""
def __init__(self, config: SwitchTransformersConfig):
super().__init__()
self.num_experts = config.num_experts # 设置专家数量
self.expert_capacity = config.expert_capacity # 每个专家的容量
self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
self.jitter_noise = config.router_jitter_noise # 噪声抖动大小
self.ignore_padding_tokens = config.router_ignore_padding_tokens # 是否忽略填充标记
self.dtype = getattr(torch, config.router_dtype) # 指定的张量数据类型
def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Computes router probabilities from input hidden states.
Args:
hidden_states (`torch.Tensor`):
(batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
Returns:
router_probabilities (`torch.Tensor`):
Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
token and expert. Used for routing tokens to experts.
router_logits (`torch.Tensor`):
Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
This is used later for computing router z-loss.
"""
# 使用float32以确保稳定性。参见https://arxiv.org/abs/2101.03961中关于“选择性精度”的讨论。
# 还存储先前的dtype以便将输出转换回先前的dtype
self.input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(self.dtype) # 将输入张量转换为指定的数据类型
if self.training and self.jitter_noise > 0:
# 如果处于训练模式且设置了抖动噪声,则通过乘以均匀分布的随机数添加噪声
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
# Shape: [num_groups, tokens_per_group, num_experts]
self._cast_classifier() # 调用内部函数_cast_classifier()
router_logits = self.classifier(hidden_states) # 使用分类器计算路由器的逻辑输出
# 应用softmax并将数据类型转回原始的`dtype`
router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
return router_probabilities, router_logits # 返回路由器概率和logits
def _cast_classifier(self):
r"""
`bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an
instance of the `Linear8bitLt` class by checking special attributes.
"""
# 检查 self.classifier 是否为 Linear8bitLt 类的实例,如果不是,则转换其数据类型
if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
self.classifier = self.classifier.to(self.dtype)
def forward(self, hidden_states: torch.Tensor) -> Tuple:
r"""
Generic forward function for every Router class. Each Router expects to have the same input hidden states
(`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the
number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert.
Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and
`router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned
to an expert. Then each Router class will have to define its own `_compute_routing_instructions`.
Args:
hidden_states (`torch.Tensor`) :
[num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs
and the router logits. The router probabilities and logits are required to compute the loss.
"""
# 计算路由概率和路由 logits
router_probs, router_logits = self._compute_router_probabilities(hidden_states)
# 根据概率选择每个 token 分配的 expert 索引
expert_index = torch.argmax(router_probs, dim=-1)
expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
# 通过累加判断是否超出 expert 的容量限制,并进行掩码处理
token_priority = torch.cumsum(expert_index, dim=-2)
expert_capacity_mask = token_priority <= self.expert_capacity
expert_index = expert_index * expert_capacity_mask
# 计算最大的路由概率,用于计算损失
router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
return expert_index, router_probs, router_logits
# 复制并修改自transformers.models.t5.modeling_t5.py中的T5LayerNorm类,更名为SwitchTransformersLayerNorm
class SwitchTransformersLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
初始化层规范化模块,LayerNorm按照SwitchTransformers风格处理,不使用偏差和平均值减去操作。
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) # 初始化缩放参数
self.variance_epsilon = eps # 防止除以零的情况,设置一个很小的正数
def forward(self, hidden_states):
"""
执行前向传播,计算并应用在hidden_states上的层规范化。
无需减去平均值,使用根均方计算变差。
"""
# 将隐藏状态转化为半精度浮点数,进行计算后再转换回来进行层规范化
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# 如果权重使用半精度浮点数或bfloat16,将隐藏状态转换到相同类型
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
# 将SwitchTransformersLayerNorm添加到ALL_LAYERNORM_LAYERS列表中以供使用
# 复制并修改自transformers.models.t5.modeling_t5.py中的T5DenseActDense类,更名并修改为SwitchTransformersDenseActDense
class SwitchTransformersDenseActDense(nn.Module):
def __init__(self, config: SwitchTransformersConfig):
"""
初始化稠密激活密集层模块,包含输入线性变换、激活函数应用、模版速率下采样线性变换。
"""
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) # 第一层线性变换
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) # 输出层线性变换
self.dropout = nn.Dropout(config.dropout_rate) # 带dropout的下采样层
self.act = ACT2FN[config.dense_act_fn] # 激活函数
def forward(self, hidden_states):
"""
执行前向传播,通过线性变换、激活函数、dropout和线性变换逐层更新hidden_states。
注意转换类型要和权重一致,对于tensor的转换也会基于数据类型的规则进行处理。
"""
hidden_states = self.wi(hidden_states) # 应用线性变换
hidden_states = self.act(hidden_states) # 激活函数处理
hidden_states = self.dropout(hidden_states) # 降低过拟合的dropout步骤
# 确保权重和计算结果数据类型一致,在某些情况下需要转换类型
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states) # 应用最终的线性变换
return hidden_states
# 定义SparseMLP类,实现Switch Transformers的稀疏多层感知机(MLP)模块
class SwitchTransformersSparseMLP(nn.Module):
"""
实现了Switch Transformers稀疏多层感知机(MLP)模块的特异性参数和结构,包括路由模块和专家模块。
"""
def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
"""
根据配置初始化SparseMLP类,包括所需的路由器和专家模块。
"""
super().__init__()
# 根据不同策略获取路由层
self.router = SwitchTransformersTop1Router(config)
# 初始化并配置专家模块列表
self.experts = nn.ModuleDict()
for idx in range(config.num_experts):
expert_name = f"expert_{idx}"
self.experts[expert_name] = expert_class(config) # 注册指定配置的专家模块
def forward(self, hidden_states):
r"""
Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following:
1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor).
2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each
expert the corresponding hidden states.
"""
# Step 1: Get the router_mask from the router as well as the probabilities
# 从路由器获取路由器掩码、概率和对数概率
router_mask, router_probs, router_logits = self.router(hidden_states)
# 根据路由器掩码的最大值索引确定每个 token 分配给哪个专家
expert_index = torch.argmax(router_mask, dim=-1)
# 备注: 由于引入的路由器可能并不总是将所有 token 映射到一个路由器上,因此有些隐藏状态可能在层与层之间保持不变。
# 因此在更新之前需要克隆隐藏状态。
# 克隆隐藏状态,准备更新仅选择的部分
next_states = hidden_states.clone()
# 遍历专家列表,为每个专家分配对应的隐藏状态
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool() # 获取当前专家对应的 token 索引
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
# 更新隐藏状态,乘以路由器概率作为缩放因子
hidden_states = router_probs * next_states
# 返回更新后的隐藏状态以及路由器的输出信息(对数概率和专家索引)
return hidden_states, (router_logits, expert_index)
# 定义一个Switch Transformers的注意力模块,继承自PyTorch的nn.Module类
class SwitchTransformersAttention(nn.Module):
"""
Switch Transformers Attention module, based on PyTorch's nn.Module.
This module is responsible for handling the attention mechanism within Switch Transformers.
"""
def __init__(self, config: SwitchTransformersConfig):
super().__init__()
# 初始化时根据给定的配置参数创建注意力层
self.self = SwitchTransformersSelfAttention(config)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
"""
Forward pass for the Switch Transformers Attention module.
Args:
hidden_states (`torch.Tensor`):
The input hidden states for the attention calculation.
attention_mask (`torch.Tensor`, optional):
Mask to avoid performing attention on padding tokens.
head_mask (`torch.Tensor`, optional):
Mask to nullify specific heads of the attention tensors.
output_attentions (`bool`, optional):
Whether to return attentions tensors.
Returns:
`torch.Tensor` or (`torch.Tensor`, `torch.Tensor`):
Depending on `output_attentions`, either returns the contextualized representation
or a tuple with the contextualized representation and the attention tensors.
"""
# 使用Switch Transformers自注意力层计算注意力
self_output = self.self(hidden_states, attention_mask, head_mask, output_attentions)
# 如果需要输出注意力张量,则将其返回;否则只返回上下文表示
if output_attentions:
attention_outputs = self_output[1] # attention_outputs is a tuple (self_output[1])
outputs = self_output[0] # self_output[0] is the contextualized representation
return outputs, attention_outputs
else:
return self_output # Return just the contextualized representation
# 初始化函数,接受一个配置对象和一个布尔值参数,用于设置是否具有相对注意力偏置
def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False):
# 调用父类的初始化函数
super().__init__()
# 设置当前层是否为解码器
self.is_decoder = config.is_decoder
# 设置是否具有相对注意力偏置
self.has_relative_attention_bias = has_relative_attention_bias
# 设置相对注意力的桶数量
self.relative_attention_num_buckets = config.relative_attention_num_buckets
# 设置相对注意力的最大距离
self.relative_attention_max_distance = config.relative_attention_max_distance
# 设置模型的维度
self.d_model = config.d_model
# 设置键值投影的维度
self.key_value_proj_dim = config.d_kv
# 设置注意力头的数量
self.n_heads = config.num_heads
# 设置dropout率
self.dropout = config.dropout_rate
# 计算内部维度,即注意力头乘以键值投影的维度
self.inner_dim = self.n_heads * self.key_value_proj_dim
# 初始化线性层,用于查询、键、值、输出
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
# 如果需要相对注意力偏置,初始化相对注意力偏置的嵌入层
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
# 初始化已剪枝的注意力头集合为空集
self.pruned_heads = set()
# 禁用梯度检查点
self.gradient_checkpointing = False
# 剪枝注意力头的方法
def prune_heads(self, heads):
# 如果没有需要剪枝的头,则直接返回
if len(heads) == 0:
return
# 找到需要剪枝的头以及它们的索引
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# 对线性层进行剪枝
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
# 更新超参数
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
# 将剪枝的头添加到已剪枝的头集合中
self.pruned_heads = self.pruned_heads.union(heads)
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor - the relative position between memory and query positions
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer - number of buckets to categorize relative positions
max_distance: an integer - maximum distance considered for bucketing
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
# Initialize the relative bucket index
relative_buckets = 0
# Adjust num_buckets if bidirectional attention is disabled
if bidirectional:
num_buckets //= 2
# Calculate relative_buckets based on whether relative_position is positive
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
# Ensure relative_position is non-positive if bidirectional=False
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# relative_position is now in the range [0, inf)
# Determine if the relative_position is small (less than max_exact)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# Calculate relative_position_if_large for positions larger than max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
# Clamp relative_position_if_large to num_buckets - 1
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
# Determine final relative_buckets based on whether relative_position is small
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
# 如果未指定设备,则使用相对注意力偏置权重的设备
if device is None:
device = self.relative_attention_bias.weight.device
# 创建表示上下文位置的张量,范围是 [0, query_length)
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
# 创建表示记忆位置的张量,范围是 [0, key_length)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
# 计算相对位置,形状为 (query_length, key_length)
relative_position = memory_position - context_position
# 对相对位置进行分桶化处理,返回形状为 (query_length, key_length) 的张量
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
# 使用相对注意力偏置计算相对位置的值,形状为 (query_length, key_length, num_heads)
values = self.relative_attention_bias(relative_position_bucket)
# 调整维度顺序,形状变为 (1, num_heads, query_length, key_length)
values = values.permute([2, 0, 1]).unsqueeze(0)
# 返回计算得到的相对位置值
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers
class SwitchTransformersLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
# 初始化自注意力层,使用SwitchTransformersAttention替代T5中的SelfAttention
self.SelfAttention = SwitchTransformersAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
# 初始化层归一化模块,使用SwitchTransformersLayerNorm替代T5中的LayerNorm
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 初始化dropout层,使用config中的dropout_rate
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
# 对隐藏状态进行层归一化处理
normed_hidden_states = self.layer_norm(hidden_states)
# 使用SwitchTransformersAttention进行自注意力计算
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# 将原始的隐藏状态与注意力输出加权和,然后应用dropout
hidden_states = hidden_states + self.dropout(attention_output[0])
# 构建输出元组,包含加权和后的隐藏状态及可能的注意力结果
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers
class SwitchTransformersLayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
# 初始化跨注意力层,使用SwitchTransformersAttention替代T5中的EncDecAttention
self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False)
# 初始化层归一化模块,使用SwitchTransformersLayerNorm替代T5中的LayerNorm
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 初始化dropout层,使用config中的dropout_rate
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
# 对隐藏状态进行层归一化处理
normed_hidden_states = self.layer_norm(hidden_states)
# 使用SwitchTransformersAttention进行跨注意力计算
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
# 将原始的隐藏状态与注意力输出加权和,然后应用dropout
layer_output = hidden_states + self.dropout(attention_output[0])
# 构建输出元组,包含加权和后的隐藏状态及可能的注意力结果
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
class SwitchTransformersBlock(nn.Module):
# 该类尚未完成,需要继续实现Switch Transformers中的Block结构
pass
# 初始化方法,用于创建一个Switch Transformers层的模型
def __init__(self, config, has_relative_attention_bias=False, is_sparse=False):
# 调用父类的初始化方法
super().__init__()
# 根据配置确定是否为解码器
self.is_decoder = config.is_decoder
# 设置是否为稀疏模式
self.is_sparse = is_sparse
# 创建一个空的模块列表,用于存储各层的模块
self.layer = nn.ModuleList()
# 将自注意力层添加到模块列表中
self.layer.append(
SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
)
# 如果是解码器,则添加交叉注意力层
if self.is_decoder:
self.layer.append(SwitchTransformersLayerCrossAttention(config))
# 添加前馈神经网络层到模块列表中
self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse))
# 前向传播方法,用于模型推断
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
output_router_logits=True,
return_dict=True,
# 定义 SwitchTransformersPreTrainedModel 类,继承自 PreTrainedModel
class SwitchTransformersPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# 设置配置类为 SwitchTransformersConfig
config_class = SwitchTransformersConfig
# 设置基础模型前缀为 "switch_transformers"
base_model_prefix = "switch_transformers"
# 支持梯度检查点
supports_gradient_checkpointing = True
# 不需要分割的模块
_no_split_modules = ["SwitchTransformersBlock"]
# 定义属性 dummy_inputs,返回一个包含输入和注意力掩码的字典
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
# 定义内部函数 _shift_right,用于将输入 ids 向右移动
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
# 如果 decoder_start_token_id 未定义,抛出数值错误
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set"
" to the pad_token_id. See SwitchTransformers docs for more information"
)
# 将输入向右移动
if is_torch_fx_proxy(input_ids):
# 对于代理对象,不支持原生的项目赋值
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
# 如果 pad_token_id 未定义,抛出数值错误
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# 用 pad_token_id 替换标签中可能存在的 -100 值
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
# 定义 SwitchTransformersStack 类,继承自 SwitchTransformersPreTrainedModel
class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
def __init__(self, config, embed_tokens=None):
super().__init__(config) # 调用父类的构造函数,初始化模型基础配置
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) # 初始化词嵌入层
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight # 如果传入了预训练的词嵌入,使用传入的词嵌入权重
self.is_decoder = config.is_decoder # 标记模型是否为解码器
sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step # 设置稀疏步长
config.num_layers = config.num_decoder_layers if self.is_decoder else config.num_layers # 设置层数
self.block = nn.ModuleList() # 创建模块列表,用于存储多层块
# 循环创建多层块
for i in range(config.num_layers):
is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False # 判断当前层是否为稀疏层
self.block.append(
SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse)
) # 将创建的块添加到模块列表中
self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) # 初始化最终的层归一化层
self.dropout = nn.Dropout(config.dropout_rate) # 初始化丢弃层,用于防止过拟合
# 初始化权重并应用最终处理
self.post_init() # 执行额外的初始化步骤
self.device_map = None # 设备映射设为None
self.gradient_checkpointing = False # 梯度检查点设为False
def get_input_embeddings(self):
return self.embed_tokens # 返回输入的词嵌入层
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings # 设置新的输入词嵌入层
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_router_logits=True,
return_dict=None,
# SWITCH_TRANSFORMERS_START_DOCSTRING 是一个长字符串,包含了关于 SWITCH_TRANSFORMERS 模型的详细介绍和相关文献引用。
# 该模型由 William Fedus、Barret Zoph 和 Noam Shazeer 提出,是一种类似于 T5 的编码-解码模型,具有稀疏的前馈结构,采用专家混合 (MoE) 架构。
# 继承自 PreTrainedModel 类,可以查看超类文档以了解该库为所有模型实现的通用方法,如下载或保存模型、调整输入嵌入大小、修剪头等。
# 也是 PyTorch 的 torch.nn.Module 子类,可以像普通 PyTorch 模块一样使用,并参考 PyTorch 文档了解一般使用和行为。
# 参数:
# config ([SwitchTransformersConfig]): 模型配置类,包含模型的所有参数。使用配置文件初始化不会加载与模型关联的权重,只加载配置。查看 ~PreTrainedModel.from_pretrained 方法以加载模型权重。
SWITCH_TRANSFORMERS_START_DOCSTRING = r"""
The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with
Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William
Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret
Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam
Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model
with sparse Feed Forward that stands for Mixture of Experts (MoE) architecture.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# SWITCH_TRANSFORMERS_INPUTS_DOCSTRING 是一个空字符串,可能用于描述 SWITCH_TRANSFORMERS 模型的输入。
SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r"""
"""
# SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING 是一个空字符串,可能用于描述 SWITCH_TRANSFORMERS 编码器的输入。
SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
# 输入序列标记在词汇表中的索引。SWITCH_TRANSFORMERS 是一个带有相对位置嵌入的模型,因此可以在左右两侧填充输入。
# 可以使用 `AutoTokenizer` 获取索引。详见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__`。
# 如需了解如何为预训练准备 `input_ids`,请查看 [SWITCH_TRANSFORMERS Training](./switch_transformers#training)。
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
# 避免对填充的标记索引执行注意力操作的掩码。掩码值在 `[0, 1]` 之间:
# - 1 表示**不掩盖**的标记,
# - 0 表示**掩盖**的标记。
# [什么是注意力掩码?](../glossary#attention-mask)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
# 用于使自注意力模块中选择的头部失效的掩码。掩码值在 `[0, 1]` 之间:
# - 1 表示头部**不被掩盖**,
# - 0 表示头部**被掩盖**。
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
# 可选项,直接传递嵌入表示而不是传递 `input_ids`。如果您希望更好地控制如何将 `input_ids` 索引转换为相关联的向量,而不是使用模型的内部嵌入查找矩阵,则此选项很有用。
output_attentions (`bool`, *optional*):
# 是否返回所有注意力层的注意力张量。有关详细信息,请查看返回张量下的 `attentions`。
output_hidden_states (`bool`, *optional*):
# 是否返回所有层的隐藏状态。有关详细信息,请查看返回张量下的 `hidden_states`。
output_router_logits (`bool`, *optional*):
# 是否返回所有路由器的 logits。这对计算路由器损失很有用,在推断期间不应返回。
return_dict (`bool`, *optional*):
# 是否返回 `~utils.ModelOutput` 而不是普通元组。
"""
# 未来警告消息:head_mask 参数已分为两个输入参数 - head_mask 和 decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
@add_start_docstrings(
"The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.",
SWITCH_TRANSFORMERS_START_DOCSTRING,
)
class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 创建编码器和解码器的配置副本,并设置相关参数
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 初始化编码器
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
# 初始化解码器
self.decoder = SwitchTransformersStack(decoder_config, self.shared)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行化
self.device_map = None
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
# 设置新的输入嵌入层
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
# 如果配置要求词嵌入层权重共享,则共享编码器和解码器的嵌入层权重
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# 剪枝模型的注意力头
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC)
# 定义模型的前向传播方法,接收多个输入参数并返回模型输出
def forward(
self,
# 输入序列的 token IDs,类型为长整型张量,可选
input_ids: Optional[torch.LongTensor] = None,
# 输入序列的注意力掩码,类型为浮点型张量,可选
attention_mask: Optional[torch.FloatTensor] = None,
# 解码器输入序列的 token IDs,类型为长整型张量,可选
decoder_input_ids: Optional[torch.LongTensor] = None,
# 解码器输入序列的注意力掩码,类型为布尔型张量,可选
decoder_attention_mask: Optional[torch.BoolTensor] = None,
# 头部掩码,类型为浮点型张量,用于控制哪些头部不参与计算,可选
head_mask: Optional[torch.FloatTensor] = None,
# 解码器头部掩码,类型为浮点型张量,可选
decoder_head_mask: Optional[torch.FloatTensor] = None,
# 跨注意力头部掩码,类型为张量,用于跨注意力层的头部控制,可选
cross_attn_head_mask: Optional[torch.Tensor] = None,
# 编码器输出,类型为元组的元组,可选
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
# 过去的键值对,类型为元组的元组,用于缓存,可选
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
# 输入的嵌入向量,类型为张量,可选
inputs_embeds: Optional[torch.Tensor] = None,
# 解码器输入的嵌入向量,类型为张量,可选
decoder_inputs_embeds: Optional[torch.Tensor] = None,
# 是否使用缓存,类型为布尔值,可选
use_cache: Optional[bool] = None,
# 是否输出注意力权重,类型为布尔值,可选
output_attentions: Optional[bool] = None,
# 是否输出隐藏状态,类型为布尔值,可选
output_hidden_states: Optional[bool] = None,
# 是否输出路由器 logits,类型为布尔值,可选
output_router_logits: Optional[bool] = None,
# 是否以字典形式返回结果,类型为布尔值,可选
return_dict: Optional[bool] = None,
# 函数参数列表结束,以下为函数体的实现
# 添加注释至类的开头,描述该类的主要功能为带有语言建模头的 SWITCH_TRANSFORMERS 模型
@add_start_docstrings(
"""SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING
)
class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
# 定义共享权重的键列表,这些权重将被绑定
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: SwitchTransformersConfig):
# 调用父类的初始化方法
super().__init__(config)
# 设置模型维度为配置中的 d_model 参数
self.model_dim = config.d_model
# 创建共享的嵌入层,用于输入编码器和解码器的词汇表
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制编码器配置,设定为非解码器模式,并关闭缓存和编码解码器模式
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 初始化编码器堆栈,传入编码器配置和共享的嵌入层
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
# 复制解码器配置,设定为解码器模式,并关闭编码解码器模式
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
# 初始化解码器堆栈,传入解码器配置和共享的嵌入层
self.decoder = SwitchTransformersStack(decoder_config, self.shared)
# 创建语言模型头部,线性层,输入维度为 d_model,输出维度为词汇表大小,无偏置
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# 设置路由器的 Z 损失系数和辅助损失系数
self.router_z_loss_coef = config.router_z_loss_coef
self.router_aux_loss_coef = config.router_aux_loss_coef
# 初始化权重并应用最终处理
self.post_init()
# 模型并行化相关,设备映射设为 None
self.device_map = None
# 获取输入嵌入层对象
def get_input_embeddings(self):
return self.shared
# 设置新的输入嵌入层对象,并将其应用于编码器和解码器
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
# 绑定编码器和解码器的权重,如果配置要求绑定词嵌入权重
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# 设置新的输出嵌入层对象
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# 获取输出嵌入层对象
def get_output_embeddings(self):
return self.lm_head
# 获取编码器对象
def get_encoder(self):
return self.encoder
# 获取解码器对象
def get_decoder(self):
return self.decoder
# 添加模型前向传播的文档注释,并替换返回文档注释为 Seq2SeqMoEOutput 类型,配置类为 _CONFIG_FOR_DOC
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC)
# 定义一个方法用于模型的前向传播,接收多个可选参数来处理输入和解码器相关的信息
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None,
):
# 返回模型前向传播需要的所有参数,这些参数包括输入的编码和解码信息
# 这个方法用于模型的前向计算,处理输入数据并生成输出结果
...
# 定义一个内部方法,用于解析路由器输出,提取路由器的逻辑和专家索引
def _unpack_router_logits(self, router_outputs):
total_router_logits = []
total_expert_indexes = []
for router_output in router_outputs:
if len(router_output[0].shape) > 1:
router_logits, expert_indexes = router_output
total_router_logits.append(router_logits)
total_expert_indexes.append(expert_indexes)
# 将所有路由器的逻辑和专家索引拼接在一起,并按指定的维度连接
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
# 定义一个方法,用于为生成准备输入数据
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# 如果使用了过去的键值(past_key_values),则需要裁剪decoder_input_ids
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# 一些生成方法已经仅传递了最后一个输入ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# 默认使用旧的行为:仅保留最后一个ID
remove_prefix_length = input_ids.shape[1] - 1
# 裁剪输入的ID,保留从remove_prefix_length到末尾的部分
input_ids = input_ids[:, remove_prefix_length:]
# 返回一个包含生成所需输入的字典
return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
# 定义一个方法,从标签(labels)中准备解码器的输入ID
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
# 调用内部方法 _shift_right,将标签(labels)右移一位,作为解码器的输入
return self._shift_right(labels)
# 重新排序缓存中的过去键值对,根据beam_idx参数进行重排序
def _reorder_cache(self, past_key_values, beam_idx):
# 如果decoder的过去状态未包含在输出中
# 禁用快速解码,并且无需重新排序
if past_key_values is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past_key_values
# 初始化重新排序后的decoder过去状态的元组
reordered_decoder_past = ()
# 遍历每一层的过去状态
for layer_past_states in past_key_values:
# 初始化重新排序后的当前层过去状态的元组
reordered_layer_past_states = ()
# 遍历当前层的每一个过去状态
for layer_past_state in layer_past_states:
# 根据beam_idx参数选择正确的批次索引,以获取正确的过去状态
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
# 检查重新排序后的当前层过去状态的形状是否与原始的过去状态一致
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
"expected reordered_layer_past_states to have the same shape than layer_past_states, "
f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}"
)
# 检查重新排序后的当前层过去状态的长度是否与原始的过去状态一致
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
"expected layer_past_states to have the same length as reordered_layer_past_states, "
f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}"
)
# 将当前层重新排序后的过去状态添加到总的重新排序过的decoder过去状态中
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
# 返回重新排序后的decoder过去状态
return reordered_decoder_past
# 使用装饰器为类添加文档字符串,描述此类是一个 SWITCH_TRANSFORMERS 模型的编码器,输出编码器的原始隐藏状态而不包含特定的顶部头信息
@add_start_docstrings(
"The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head"
" on top.",
SWITCH_TRANSFORMERS_START_DOCSTRING,
)
class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
# 定义权重绑定的键列表,这些权重将与 encoder.embed_tokens.weight 共享
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
# 调用父类的初始化方法,传入配置参数
super().__init__(config)
# 创建一个共享的嵌入层,将词汇表大小和模型配置中的 d_model 作为参数
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制配置并调整为不使用缓存、非编码-解码模式的编码器配置
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 创建 SwitchTransformersStack 对象作为编码器,并传入共享的嵌入层
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行化设置,设备映射初始化为 None
self.device_map = None
# 返回共享的嵌入层对象
def get_input_embeddings(self):
return self.shared
# 设置新的输入嵌入层,并更新编码器中的共享嵌入层
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
# 如果配置要求,将编码器的嵌入层权重与共享的嵌入层权重进行绑定
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
# 返回编码器对象
def get_encoder(self):
return self.encoder
# 对模型的头部进行修剪,heads_to_prune 是一个字典,格式为 {层号: 需要在此层中修剪的头部列表}
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
# 在指定层中的自注意力模块中修剪头部
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
# 使用装饰器为前向传播方法添加文档字符串,描述其输入参数和返回类型
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoEModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None,
# 省略了函数体,因为代码截断在此处,应继续注释直到代码结束
) -> Union[Tuple[torch.FloatTensor], MoEModelOutput]:
r"""
Returns the encoder outputs based on the given inputs and optional configurations.
Example:
```
>>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Determine whether to use the return_dict based on the provided argument or default configuration
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Pass input arguments to the encoder module for processing
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
)
# Return the outputs generated by the encoder module
return encoder_outputs
# 导入必要的模块和函数
from typing import TYPE_CHECKING
# 从 utils 模块导入所需的异常和函数
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
# 定义模块的导入结构,包括配置和模型相关内容
_import_structure = {
"configuration_switch_transformers": [
"SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SwitchTransformersConfig",
"SwitchTransformersOnnxConfig",
]
}
# 检查是否 Torch 可用,如果不可用则引发 OptionalDependencyNotAvailable 异常
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
# 如果 Torch 可用,则导入模型相关内容到模块导入结构中
_import_structure["modeling_switch_transformers"] = [
"SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwitchTransformersEncoderModel",
"SwitchTransformersForConditionalGeneration",
"SwitchTransformersModel",
"SwitchTransformersPreTrainedModel",
"SwitchTransformersTop1Router",
"SwitchTransformersSparseMLP",
]
# 如果是类型检查阶段,则从相应的模块导入配置和模型内容
if TYPE_CHECKING:
from .configuration_switch_transformers import (
SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP,
SwitchTransformersConfig,
SwitchTransformersOnnxConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_switch_transformers import (
SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
SwitchTransformersEncoderModel,
SwitchTransformersForConditionalGeneration,
SwitchTransformersModel,
SwitchTransformersPreTrainedModel,
SwitchTransformersSparseMLP,
SwitchTransformersTop1Router,
)
# 如果不是类型检查阶段,则将当前模块设为延迟加载模块
else:
import sys
# 使用 _LazyModule 将当前模块设为延迟加载模块,确保按需导入内容
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\t5\configuration_t5.py
# coding=utf-8
# 引入必要的模块和类
# 版权声明和许可协议
# 版权所有 2020 年,T5 作者和 HuggingFace 公司
#
# 根据 Apache 许可协议 2.0 版本(“许可证”),除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据“原样”分发,
# 没有任何形式的担保或条件,无论是明示的还是默示的。
# 有关更多信息,请参阅许可协议。
""" T5 模型配置 """
# 导入必要的类型注解
from typing import Mapping
# 从相关模块导入所需的类和函数
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging
# 获取 logger 对象,用于日志记录
logger = logging.get_logger(__name__)
# 预训练配置文件的映射表,包含了不同预训练模型的配置文件 URL
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google-t5/t5-small": "https://huggingface.co/google-t5/t5-small/resolve/main/config.json",
"google-t5/t5-base": "https://huggingface.co/google-t5/t5-base/resolve/main/config.json",
"google-t5/t5-large": "https://huggingface.co/google-t5/t5-large/resolve/main/config.json",
"google-t5/t5-3b": "https://huggingface.co/google-t5/t5-3b/resolve/main/config.json",
"google-t5/t5-11b": "https://huggingface.co/google-t5/t5-11b/resolve/main/config.json",
}
# T5Config 类,继承自 PretrainedConfig 类
class T5Config(PretrainedConfig):
r"""
这是一个配置类,用于存储 [`T5Model`] 或 [`TFT5Model`] 的配置。它用于根据指定的参数实例化 T5 模型,定义模型架构。
使用默认参数实例化配置对象将产生类似于 T5 [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) 架构的配置。
配置对象继承自 [`PretrainedConfig`],可以用于控制模型输出。有关更多信息,请参阅 [`PretrainedConfig`] 的文档。
```
# 模型类型设定为 "t5"
model_type = "t5"
# 推断阶段忽略的关键字列表,这里包括 "past_key_values"
keys_to_ignore_at_inference = ["past_key_values"]
# 属性映射字典,将模型属性名映射到通用命名
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
# 初始化函数,用于初始化Transformer模型的各种参数和配置
def __init__(
self,
vocab_size=32128, # 词汇表大小,默认为32128
d_model=512, # 模型的维度,默认为512
d_kv=64, # 键和值的维度,默认为64
d_ff=2048, # Feed Forward层的维度,默认为2048
num_layers=6, # Transformer层的数量,默认为6
num_decoder_layers=None, # 解码器层的数量,默认为num_layers的值,保持对称性
num_heads=8, # 多头注意力机制中的头数,默认为8
relative_attention_num_buckets=32, # 相对注意力机制中的桶数,默认为32
relative_attention_max_distance=128, # 相对注意力机制的最大距离,默认为128
dropout_rate=0.1, # Dropout率,默认为0.1
layer_norm_epsilon=1e-6, # Layer Normalization中的epsilon,默认为1e-6
initializer_factor=1.0, # 初始化因子,默认为1.0
feed_forward_proj="relu", # 前向传播投影层的激活函数,默认为'relu'
is_encoder_decoder=True, # 是否是编码器-解码器结构,默认为True
use_cache=True, # 是否使用缓存,默认为True
pad_token_id=0, # 填充token的ID,默认为0
eos_token_id=1, # EOS(句子结束)token的ID,默认为1
classifier_dropout=0.0, # 分类器中的Dropout率,默认为0.0
**kwargs, # 其他关键字参数
):
self.vocab_size = vocab_size # 初始化词汇表大小
self.d_model = d_model # 初始化模型的维度
self.d_kv = d_kv # 初始化键和值的维度
self.d_ff = d_ff # 初始化Feed Forward层的维度
self.num_layers = num_layers # 初始化Transformer层的数量
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # 解码器层数,默认为num_layers的值,保持对称性
self.num_heads = num_heads # 初始化多头注意力机制中的头数
self.relative_attention_num_buckets = relative_attention_num_buckets # 初始化相对注意力机制中的桶数
self.relative_attention_max_distance = relative_attention_max_distance # 初始化相对注意力机制的最大距离
self.dropout_rate = dropout_rate # 初始化Dropout率
self.classifier_dropout = classifier_dropout # 初始化分类器中的Dropout率
self.layer_norm_epsilon = layer_norm_epsilon # 初始化Layer Normalization中的epsilon
self.initializer_factor = initializer_factor # 初始化初始化因子
self.feed_forward_proj = feed_forward_proj # 初始化前向传播投影层的激活函数
self.use_cache = use_cache # 初始化是否使用缓存
act_info = self.feed_forward_proj.split("-") # 拆分前向传播投影激活函数的信息
self.dense_act_fn = act_info[-1] # 设置密集层的激活函数
self.is_gated_act = act_info[0] == "gated" # 判断是否为门控激活函数
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
# 如果激活函数信息长度大于1且不是'gated'或长度大于2,则引发值错误
raise ValueError(
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
"'gated-gelu' or 'relu'"
)
# 为了向后兼容性
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new" # 设置密集层的激活函数为'gelu_new'
super().__init__(
pad_token_id=pad_token_id, # 初始化填充token的ID
eos_token_id=eos_token_id, # 初始化EOS(句子结束)token的ID
is_encoder_decoder=is_encoder_decoder, # 初始化是否是编码器-解码器结构
**kwargs, # 其他关键字参数
)
class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# 定义 T5 模型的配置类,继承自带过去信息的 Seq2Seq ONNX 配置类
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
# 定义输入属性,返回一个映射类型,将字符串键映射到包含整数和字符串的字典
# 常见的输入配置,包括输入 ID 和注意力掩码
common_inputs = {
"input_ids": {0: "batch", 1: "encoder_sequence"},
"attention_mask": {0: "batch", 1: "encoder_sequence"},
}
# 如果使用过去信息
if self.use_past:
# 调整注意力掩码以包含过去的编码器序列信息
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
# 添加解码器输入 ID 的配置
common_inputs["decoder_input_ids"] = {0: "batch"}
# 添加解码器注意力掩码的配置,包括过去的解码器序列信息和当前序列信息
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
# 添加默认的解码器输入 ID 配置
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
# 添加默认的解码器注意力掩码配置
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
# 如果使用过去信息,调用内部方法填充键值对
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
# 返回最终的输入配置字典
return common_inputs
@property
def default_onnx_opset(self) -> int:
# 定义默认的 ONNX 运算集版本号为 13
return 13
.\models\t5\convert_t5x_checkpoint_to_flax.py
# 设置脚本的编码格式为 UTF-8
# 版权声明,声明脚本归 HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证 2.0 版本使用本文件,除非符合许可证要求,否则不得使用该文件
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,本软件是基于“原样”提供的,不提供任何明示或暗示的担保或条件
# 请查阅许可证以获取详细的权利和限制条款
#
"""Convert T5X checkpoints from the original repository to JAX/FLAX model."""
# 导入必要的库和模块
import argparse
# 从 t5x 模块中导入 checkpoints
from t5x import checkpoints
# 从 transformers 库中导入 FlaxT5ForConditionalGeneration 类和 T5Config 类
from transformers import FlaxT5ForConditionalGeneration, T5Config
# 定义函数,将 T5X 检查点转换为 Flax 模型
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
# 使用指定的配置名称创建 T5Config 对象
config = T5Config.from_pretrained(config_name)
# 使用配置对象创建 FlaxT5ForConditionalGeneration 模型
flax_model = FlaxT5ForConditionalGeneration(config=config)
# 加载给定路径上的 T5X 检查点,返回 T5X 模型对象
t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
# 检查是否在 t5x_model 的目标部分的编码器层中存在名为 "wi_0" 的子项
split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
# Encoder 部分的转换操作未提供,留待后续补充
# 遍历配置中指定数量的层
for layer_index in range(config.num_layers):
# 构建当前层的名称,格式为 "layers_<层索引>"
layer_name = f"layers_{str(layer_index)}"
# 获取当前层的自注意力机制相关参数
t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
# 获取当前层的自注意力机制前的归一化参数
t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
# 根据条件选择当前层的多层感知机的参数
if split_mlp_wi:
t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
else:
t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
# 获取当前层的多层感知机的输出参数
t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
# 获取当前层多层感知机前的归一化参数
t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
# 将 T5X 模型中的参数赋值给 Flax 模型的对应位置
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][
"kernel"
] = t5x_attention_key
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][
"kernel"
] = t5x_attention_out
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][
"kernel"
] = t5x_attention_query
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][
"kernel"
] = t5x_attention_value
# 设置 Flax 模型当前层的注意力机制前的归一化参数
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][
"weight"
] = t5x_attention_layer_norm
# 根据条件选择并设置 Flax 模型当前层的多层感知机的参数
if split_mlp_wi:
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][
"kernel"
] = t5x_mlp_wi_0
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][
"kernel"
] = t5x_mlp_wi_1
else:
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"][
"kernel"
] = t5x_mlp_wi
# 设置 Flax 模型当前层多层感知机的输出参数
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"][
"kernel"
] = t5x_mlp_wo
# 设置 Flax 模型当前层多层感知机前的归一化参数
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][
"weight"
] = t5x_mlp_layer_norm
# 仅适用于第一层(layer 0)的操作:获取 T5X 模型的编码器相对位置偏置的嵌入矩阵的转置
t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
# 将 t5x_encoder_rel_embedding 赋值给 flax_model.params 中的特定路径
flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
"embedding"
] = t5x_encoder_rel_embedding
# 将 t5x_model 中的特定路径的值赋给 flax_model.params 中的特定路径
t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
# 将 t5x_model 中的特定路径的值赋给 flax_model.params 中的特定路径
# 这是针对解码器的最终归一化层
tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
# 将 t5x_model 中的特定路径的值赋给 flax_model.params 中的特定路径
# 仅对解码器的第一个层的第一个自注意力模块使用相对注意力偏置
t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
"embedding"
] = t5x_decoder_rel_embedding
# 将 t5x_model 中的特定路径的值赋给 flax_model.params 中的特定路径
# 这是共享的令牌嵌入
tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
flax_model.params["shared"]["embedding"] = tx5_token_embeddings
# 如果在 t5x_model 的特定路径中存在 "logits_dense",则将其值赋给 flax_model.params 中的特定路径
# 这是语言模型头部的内核,仅适用于版本 v1.1 的检查点
if "logits_dense" in t5x_model["target"]["decoder"]:
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
# 将转换后的模型保存到指定路径
flax_model.save_pretrained(flax_dump_folder_path)
# 打印成功转换的消息
print("T5X Model was sucessfully converted!")
if __name__ == "__main__":
# 如果当前脚本作为主程序运行,则执行以下代码块
parser = argparse.ArgumentParser()
# 创建参数解析器对象
# Required parameters
# 添加必需的命令行参数
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
)
# t5x_checkpoint_path 参数,指定了 TX5 模型的检查点路径
parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of T5 model.")
# config_name 参数,指定了 T5 模型的配置名称
parser.add_argument(
"--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
)
# flax_dump_folder_path 参数,指定了输出 FLAX 模型的文件夹路径
# 解析命令行参数
args = parser.parse_args()
# 调用函数 convert_t5x_checkpoint_to_flax,将命令行参数传递给函数
convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
.\models\t5\convert_t5x_checkpoint_to_pytorch.py
# 指定文件编码为 UTF-8
# 版权声明,版权归谷歌有限责任公司和HuggingFace公司所有
#
# 根据Apache许可证2.0版进行许可;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按“原样”分发,
# 没有任何明示或暗示的担保或条件。
# 有关许可证的详细信息,请参阅许可证。
"""
将T5X检查点转换为PyTorch格式
步骤:
- 根据https://cloud.google.com/storage/docs/gsutil_install 安装gsutil
- 在https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints获取T5X检查点 示例:
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
- 创建或下载相应模型的配置。例如,对于T5 v1.1 small,您可以使用
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
- 转换:
```
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
--pytorch_dump_path=$HOME/t5_1_1_small_pt
```
"""
import argparse # 导入命令行参数解析模块
import collections # 导入collections模块
import torch # 导入PyTorch库
from flax import traverse_util # 导入flax库的traverse_util模块
from t5x import checkpoints # 从t5x库导入checkpoints模块
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration # 从transformers库导入必要的类
from transformers.utils import logging # 从transformers库导入logging模块用于日志记录
logging.set_verbosity_info() # 设置日志级别为信息级别
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
"""返回(self-)attention的KOQV参数,不进行转置。"""
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
return k, o, q, v
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""返回层的MLP参数,不进行转置。"""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
return wi, wo
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
"""返回层的层归一化参数。"""
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool):
"""将T5X-Flax的参数转换为Transformers-PyTorch格式。"""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
# v1.1模型具有具有wi_0和wi_1而不是wi的门控GeLU
# 检查旧模型中是否存在指定路径,判断是否要分离 MLP 的权重
split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old
# 打印是否分离 MLP 的信息
print("Split MLP:", split_mlp_wi)
# 创建一个新的有序字典用于存储转换后的模型参数
new = collections.OrderedDict()
# 共享的嵌入层权重
new["shared.weight"] = old["token_embedder/embedding"]
# 编码器部分的参数转换
for i in range(num_layers):
# 第 i 个块,第 0 层(自注意力层)
# 获取自注意力层前的层归一化权重
layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm")
# 获取自注意力层中的注意力权重(k, o, q, v)
k, o, q, v = t5x_attention_lookup(old, i, "encoder", "attention")
new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
# 第 i 个块,第 1 层(MLP 层)
# 获取 MLP 层前的层归一化权重
layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm")
# 获取 MLP 层中的权重(wi, wo)
wi, wo = t5x_mlp_lookup(old, i, "encoder", split_mlp_wi)
new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
if split_mlp_wi:
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T
else:
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T
new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T
# 编码器的第一个块的自注意力层的相对注意力偏置权重
new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"encoder/relpos_bias/rel_embedding"
].T
# 编码器最终层的归一化权重
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
if not is_encoder_only:
# 如果不是仅编码器模式,则执行解码器部分
# 解码器部分
for i in range(num_decoder_layers):
# 对于每个解码器层 i:
# Block i, layer 0 (Self Attention).
# 第 i 块,第 0 层 (自注意力)
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
# 获取旧模型中解码器第 i 块的预自注意力层规范化参数
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
# 获取旧模型中解码器第 i 块的自注意力参数 k, o, q, v
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
# 设置新模型中解码器第 i 块的第 0 层的层规范化权重
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
# 设置新模型中解码器第 i 块的第 0 层自注意力的 k 权重
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
# 设置新模型中解码器第 i 块的第 0 层自注意力的 o 权重
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
# 设置新模型中解码器第 i 块的第 0 层自注意力的 q 权重
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
# 设置新模型中解码器第 i 块的第 0 层自注意力的 v 权重
# Block i, layer 1 (Cross Attention).
# 第 i 块,第 1 层 (交叉注意力)
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
# 获取旧模型中解码器第 i 块的预交叉注意力层规范化参数
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
# 获取旧模型中解码器第 i 块的编码器-解码器注意力参数 k, o, q, v
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
# 设置新模型中解码器第 i 块的第 1 层的层规范化权重
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
# 设置新模型中解码器第 i 块的第 1 层交叉注意力的 k 权重
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
# 设置新模型中解码器第 i 块的第 1 层交叉注意力的 o 权重
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
# 设置新模型中解码器第 i 块的第 1 层交叉注意力的 q 权重
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
# 设置新模型中解码器第 i 块的第 1 层交叉注意力的 v 权重
# Block i, layer 2 (MLP).
# 第 i 块,第 2 层 (多层感知机)
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
# 获取旧模型中解码器第 i 块的预多层感知机层规范化参数
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
# 获取旧模型中解码器第 i 块的多层感知机参数 wi, wo
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
# 设置新模型中解码器第 i 块的第 2 层的层规范化权重
if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
# 设置新模型中解码器第 i 块的第 2 层多层感知机的 wi_0 权重
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
# 设置新模型中解码器第 i 块的第 2 层多层感知机的 wi_1 权重
else:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
# 设置新模型中解码器第 i 块的第 2 层多层感知机的 wi 权重
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
# 设置新模型中解码器第 i 块的第 2 层多层感知机的 wo 权重
# 解码器最终层规范化权重
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
# 设置新模型中解码器的最终层规范化权重
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"decoder/relpos_bias/rel_embedding"
].T
# 设置新模型中解码器第 0 块第 0 层自注意力的相对注意力偏置权重
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
# 语言模型头部(仅在 v1.1 版本的检查点中,v1.0 版本使用嵌入代替)
if "decoder/logits_dense/kernel" in old:
# 如果旧模型中存在 "decoder/logits_dense/kernel"
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
# 设置新模型中的 lm_head 权重为旧模型中 logits_dense/kernel 的转置
# 返回新模型
return new
def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model."""
# 创建一个有序字典的状态字典,使用 torch.from_numpy 将每个参数的副本转换为张量
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
# 如果缺少 "encoder.embed_tokens.weight",则用 "shared.weight" 补充
if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
# 如果不仅仅是编码器,还需处理解码器和语言模型头部的参数
if not is_encoder_only:
# 如果缺少 "decoder.embed_tokens.weight",则用 "shared.weight" 补充
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
# 对于旧版本 1.0 的模型,如果缺少 "lm_head.weight",则打印警告并用 "shared.weight" 补充
if "lm_head.weight" not in state_dict: # For old 1.0 models.
print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"]
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model with the T5X converted params."""
# 加载 T5X checkpoint 中的变量并进行转换为 PyTorch 格式
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(
variables,
num_layers=config.num_layers,
num_decoder_layers=config.num_decoder_layers,
is_encoder_only=is_encoder_only,
)
# 生成状态字典并加载到模型中
state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# 从配置文件加载配置并初始化 PyTorch 模型
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# 根据是否仅为编码器,选择初始化 T5EncoderModel 或 T5ForConditionalGeneration 模型
if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# 从 TensorFlow checkpoint 加载权重到 PyTorch 模型
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# 保存转换后的 PyTorch 模型
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
# 验证是否成功加载保存的检查点
model.from_pretrained(pytorch_dump_path)
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
# 必需参数
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
# 添加命令行参数 `--pytorch_dump_path`,指定输出的 PyTorch 模型路径,该参数是必须的
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
# 添加命令行参数 `--is_encoder_only`,表示是否为编码器-解码器模型的标志,该参数为布尔类型,默认为 False
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
# 解析命令行参数,并将结果保存在 args 变量中
args = parser.parse_args()
# 调用函数 convert_t5x_checkpoint_to_pytorch,将给定的 T5X 模型转换为 PyTorch 模型
convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
这段代码片段用于解析命令行参数,并调用一个函数来执行 T5X 模型到 PyTorch 模型的转换。
.\models\t5\convert_t5_original_tf_checkpoint_to_pytorch.py
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
import argparse # 导入解析命令行参数的模块
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 # 导入转换模型相关的类和函数
from transformers.utils import logging # 导入日志记录工具
logging.set_verbosity_info() # 设置日志记录级别为INFO
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# 初始化一个PyTorch模型配置
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}") # 打印正在根据配置构建PyTorch模型的消息
model = T5ForConditionalGeneration(config) # 使用配置创建T5条件生成模型
# 从TensorFlow的检查点文件中加载权重
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
# 保存PyTorch模型
print(f"Save PyTorch model to {pytorch_dump_path}") # 打印保存PyTorch模型到指定路径的消息
model.save_pretrained(pytorch_dump_path) # 将模型保存为PyTorch可用的格式
if __name__ == "__main__":
parser = argparse.ArgumentParser() # 创建命令行参数解析器
# 必填参数
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True,
help="Path to the TensorFlow checkpoint path." # TensorFlow检查点文件的路径
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained T5 model. \n"
"This specifies the model architecture."
), # 预训练T5模型对应的配置JSON文件,指定了模型的架构
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True,
help="Path to the output PyTorch model." # 输出PyTorch模型的路径
)
args = parser.parse_args() # 解析命令行参数
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) # 执行转换操作
.\models\t5\modeling_flax_t5.py
# coding=utf-8
# Copyright 2021 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax T5 model."""
import copy
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
_CONFIG_FOR_DOC = "T5Config"
remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
# 初始化一个与input_ids相同形状的零张量
shifted_input_ids = jnp.zeros_like(input_ids)
# 将input_ids向右移动一位
shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
# 在首位插入decoder_start_token_id
shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
# 将所有-100的位置替换为pad_token_id
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
class FlaxT5LayerNorm(nn.Module):
hidden_size: int
dtype: jnp.dtype = jnp.float32
eps: float = 1e-6
weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
def setup(self):
# 创建权重参数
self.weight = self.param("weight", self.weight_init, (self.hidden_size,))
def __call__(self, hidden_states):
"""
Construct a layernorm module in the T5 style; No bias and no subtraction of mean.
"""
# layer norm should always be calculated in float32
# 计算隐藏状态的方差,并在最后一个轴上求均值,保持维度不变
variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
# 对隐藏状态进行标准化,除以标准差(方差的平方根),加上小的常量 self.eps 避免除以零
hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
# 返回加权后的标准化隐藏状态
return self.weight * hidden_states
# 定义一个名为 FlaxT5DenseActDense 的神经网络模块类
class FlaxT5DenseActDense(nn.Module):
# 配置属性,指定为 T5Config 类型
config: T5Config
# 数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32
# 模块设置方法,用于初始化模块的各个组件
def setup(self):
# 计算初始化权重标准差,根据配置的初始化因子和模型维度
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
# 创建第一个全连接层 wi,用于 d_ff 到 d_model 的映射
self.wi = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std), # 使用正态分布初始化权重
dtype=self.dtype,
)
# 创建第二个全连接层 wo,用于 d_model 到 d_ff 的映射
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std), # 使用正态分布初始化权重
dtype=self.dtype,
)
# 创建一个 Dropout 层,用于在训练时随机置零部分输入单元,防止过拟合
self.dropout = nn.Dropout(self.config.dropout_rate)
# 根据配置选择激活函数,ACT2FN 是一个激活函数映射字典
self.act = ACT2FN[self.config.dense_act_fn]
# 定义模块的调用方法,用于执行前向传播
def __call__(self, hidden_states, deterministic=True):
# 将输入 hidden_states 经过第一个全连接层 wi
hidden_states = self.wi(hidden_states)
# 使用配置中指定的激活函数进行激活
hidden_states = self.act(hidden_states)
# 对激活后的结果进行 Dropout 操作
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
# 经过第二个全连接层 wo,得到最终输出
hidden_states = self.wo(hidden_states)
return hidden_states
# 定义一个名为 FlaxT5DenseGatedActDense 的神经网络模块类,继承自 nn.Module
class FlaxT5DenseGatedActDense(nn.Module):
# 配置属性,指定为 T5Config 类型
config: T5Config
# 数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
# 模块设置方法,用于初始化模块的各个组件
def setup(self):
# 计算初始化权重标准差,根据配置的初始化因子和模型维度
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
# 创建第一个全连接层 wi_0,用于 d_ff 到 d_model 的映射
self.wi_0 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std), # 使用正态分布初始化权重
dtype=self.dtype,
)
# 创建第二个全连接层 wi_1,用于 d_ff 到 d_model 的映射
self.wi_1 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std), # 使用正态分布初始化权重
dtype=self.dtype,
)
# 创建第三个全连接层 wo,用于 d_model 到 d_ff 的映射
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std), # 使用正态分布初始化权重
dtype=self.dtype,
)
# 创建一个 Dropout 层,用于在训练时随机置零部分输入单元,防止过拟合
self.dropout = nn.Dropout(self.config.dropout_rate)
# 根据配置选择激活函数,ACT2FN 是一个激活函数映射字典
self.act = ACT2FN[self.config.dense_act_fn]
# 定义模块的调用方法,用于执行前向传播
def __call__(self, hidden_states, deterministic):
# 经过第一个全连接层 wi_0,并使用配置中指定的激活函数进行激活
hidden_gelu = self.act(self.wi_0(hidden_states))
# 经过第二个全连接层 wi_1
hidden_linear = self.wi_1(hidden_states)
# gated activation function:将 gelu 激活后的结果与 linear 相乘
hidden_states = hidden_gelu * hidden_linear
# 对结果进行 Dropout 操作
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
# 经过第三个全连接层 wo,得到最终输出
hidden_states = self.wo(hidden_states)
return hidden_states
# 定义一个名为 FlaxT5LayerFF 的神经网络模块类,继承自 nn.Module
class FlaxT5LayerFF(nn.Module):
# 配置属性,指定为 T5Config 类型
config: T5Config
# 数据类型,默认为 jnp.float32 用于计算的数据类型
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
# 在对象初始化时设置网络结构
def setup(self):
# 如果配置要求使用带门控激活函数的 DenseReluDense 模块
if self.config.is_gated_act:
# 使用带门控激活函数的 DenseReluDense 初始化
self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
else:
# 否则使用普通的 DenseActDense 初始化
self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)
# 初始化 LayerNorm 层,设置隐藏层维度和 epsilon 值
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
# 初始化 Dropout 层,设置丢弃率
self.dropout = nn.Dropout(self.config.dropout_rate)
# 对象调用方法,对隐藏状态进行处理
def __call__(self, hidden_states, deterministic=True):
# 对隐藏状态进行 LayerNorm 处理
forwarded_states = self.layer_norm(hidden_states)
# 通过 DenseReluDense 模块进行前向传播处理
forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
# 使用 Dropout 处理后的前向传播结果,与原始隐藏状态相加
hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个名为 FlaxT5Attention 的神经网络模块,继承自 nn.Module
class FlaxT5Attention(nn.Module):
# 类属性:配置信息,来自于 T5Config 类
config: T5Config
# 是否包含相对注意力偏置的标志,默认为 False
has_relative_attention_bias: bool = False
# 是否是因果注意力(causal attention)的标志,默认为 False
causal: bool = False
# 计算过程中使用的数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
# 初始化方法
def setup(self):
# 相对注意力的桶数,从配置中获取
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
# 相对注意力的最大距离,从配置中获取
self.relative_attention_max_distance = self.config.relative_attention_max_distance
# 模型维度,从配置中获取
self.d_model = self.config.d_model
# 键值投影的维度,从配置中获取
self.key_value_proj_dim = self.config.d_kv
# 注意力头的数量,从配置中获取
self.n_heads = self.config.num_heads
# 丢弃率,从配置中获取
self.dropout = self.config.dropout_rate
# 内部维度,等于注意力头数量乘以键值投影维度
self.inner_dim = self.n_heads * self.key_value_proj_dim
# 初始化权重的标准差
q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
# 创建查询(q)、键(k)、值(v)和输出(o)的全连接层
self.q = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(q_init_std), # 正态分布初始化权重
dtype=self.dtype,
)
self.k = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std), # 正态分布初始化权重
dtype=self.dtype,
)
self.v = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std), # 正态分布初始化权重
dtype=self.dtype,
)
self.o = nn.Dense(
self.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(o_init_std), # 正态分布初始化权重
dtype=self.dtype,
)
# 如果有相对注意力偏置,创建相对注意力偏置的嵌入层
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embed(
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std), # 正态分布初始化嵌入层权重
dtype=self.dtype,
)
# 静态方法
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
"""
relative_buckets = 0
# 如果允许双向注意力,则将桶的数量减半,并根据相对位置的正负决定桶的偏移
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
# 如果不允许双向注意力,则将相对位置限制在非正数范围内
relative_position = -jnp.clip(relative_position, a_max=0)
# 现在,relative_position 的范围是 [0, inf)
# 将一半的桶用于精确增量位置
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# 另一半的桶用于对数增量位置,直到 max_distance
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
return relative_buckets.astype("i4")
def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias"""
# 创建上下文位置矩阵和记忆位置矩阵
context_position = jnp.arange(query_length, dtype="i4")[:, None]
memory_position = jnp.arange(key_length, dtype="i4")[None, :]
# 计算相对位置并将其转换为桶索引
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.causal), # 根据是否因果关系来决定是否双向
num_buckets=self.relative_attention_num_buckets, # 桶的数量
max_distance=self.relative_attention_max_distance, # 最大距离
)
# 计算相对注意力偏置值
values = self.relative_attention_bias(relative_position_bucket)
values = values.transpose((2, 0, 1))[None, :, :, :] # 转置并扩展维度以匹配模型输出格式
return values
def _split_heads(self, hidden_states):
# 将隐藏状态重塑为多头注意力的形状
return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
# 将输入的隐藏状态重塑为指定维度的形状
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
# 使用 nn.compact 装饰器定义的方法,用于将单个输入令牌的投影键、值状态与先前步骤中的缓存状态连接起来
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# 检测是否通过缺少现有缓存数据进行初始化
is_initialized = self.has_variable("cache", "cached_key")
# 初始化缓存的键和值,如果未初始化则为零数组
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
# 缓存索引,用于跟踪缓存的位置
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# 使用新的一维空间切片更新键和值缓存
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
# 更新缓存中的键和值
cached_key.value = key
cached_value.value = value
# 更新缓存索引以反映已更新的缓存向量数量
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# 用于缓存的因果掩码:我们的单个查询位置应仅参与已生成和缓存的键位置的自注意力,而不是剩余的零元素
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
# 返回连接后的键、值以及更新后的注意力掩码
return key, value, attention_mask
# 创建位置偏置的方法,接受关键状态、查询状态、注意力掩码、初始化缓存、序列长度和因果注意力掩码偏移量作为参数
def _create_position_bias(
self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
):
# 检查缓存是否已填充,条件包括因果关系和存在特定缓存键,且不是初始化缓存
cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
# 计算键的长度
key_length = key_states.shape[1]
# 如果缓存已填充,则查询长度等于键的长度,否则等于查询状态的长度
query_length = key_length if cache_is_filled else query_states.shape[1]
# 如果模型支持相对注意力偏置,则计算位置偏置
if self.has_relative_attention_bias:
position_bias = self.compute_bias(query_length, key_length)
# 否则,如果存在注意力掩码,则创建与其相同形状的零张量
elif attention_mask is not None:
position_bias = jnp.zeros_like(attention_mask)
# 否则,默认创建形状为 (1, self.n_heads, query_length, key_length) 的零张量
else:
position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)
# 如果缓存已填充,则仅需取最后一个查询位置的偏置
if cache_is_filled:
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
position_bias = jax.lax.dynamic_slice(
position_bias,
(0, 0, causal_attention_mask_shift, 0),
(1, self.n_heads, seq_length, max_decoder_length),
)
# 返回位置偏置张量
return position_bias
# 在调用实例时,接收隐藏状态、注意力掩码、键值状态、位置偏置等多个参数
def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
position_bias=None,
use_cache=False,
output_attentions=False,
deterministic=True,
init_cache=False,
# 定义一个 FlaxT5Block 类,继承自 nn.Module
class FlaxT5Block(nn.Module):
# T5 模型的配置参数
config: T5Config
# 是否具有相对注意力偏置,默认为 False
has_relative_attention_bias: bool = False
# 计算中使用的数据类型,默认为 jnp.float32
# 初始化方法,设置模块的组件
def setup(self):
# 创建自注意力层对象 SelfAttention,使用 FlaxT5Attention 类
self.SelfAttention = FlaxT5Attention(
self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
causal=self.config.causal,
dtype=self.dtype,
)
# 创建层归一化对象 layer_norm,使用 FlaxT5LayerNorm 类
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
# 创建 Dropout 层对象 dropout,使用 nn.Dropout 类,设置丢弃率为 config.dropout_rate
self.dropout = nn.Dropout(self.config.dropout_rate)
# 调用方法,定义模块的前向传播逻辑
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
output_attentions=False,
deterministic=True,
init_cache=False,
):
# 对输入的隐藏状态进行层归一化处理
normed_hidden_states = self.layer_norm(hidden_states)
# 使用 SelfAttention 层处理归一化后的隐藏状态
attention_output = self.SelfAttention(
normed_hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
# 将原始隐藏状态与经 Dropout 处理后的注意力输出相加,实现残差连接
hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
# 构建输出元组,包括更新后的隐藏状态和可能的注意力输出
outputs = (hidden_states,) + attention_output[1:] # 如果有输出注意力信息,将其添加到输出中
return outputs
# 初始化方法,设置模型的相关属性和层次结构
def setup(self):
# 获取配置中的causal参数
self.causal = self.config.causal
# 初始化self.layer为一个元组,包含FlaxT5LayerSelfAttention层对象
self.layer = (
FlaxT5LayerSelfAttention(
self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
name=str(0),
dtype=self.dtype,
),
)
# 初始化feed_forward_index为1
feed_forward_index = 1
# 如果causal为True,则添加FlaxT5LayerCrossAttention层对象到self.layer中
if self.causal:
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
feed_forward_index += 1
# 添加FlaxT5LayerFF层对象到self.layer中,名称为feed_forward_index
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
# 模型调用方法,执行自注意力和交叉注意力计算,并返回相应的输出
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
return_dict=True,
deterministic=True,
init_cache=False,
):
# 执行自注意力层计算
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
# 更新hidden_states为自注意力输出的第一个元素
hidden_states = self_attention_outputs[0]
# 保留自注意力输出和相对位置权重
attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
# 如果需要执行交叉注意力计算
do_cross_attention = self.causal and encoder_hidden_states is not None
if do_cross_attention:
# 执行交叉注意力层计算
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
)
# 更新hidden_states为交叉注意力输出的第一个元素
hidden_states = cross_attention_outputs[0]
# 将交叉注意力输出和相对位置权重添加到attention_outputs中
attention_outputs = attention_outputs + cross_attention_outputs[1:]
# 应用Feed Forward层计算
hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
# 初始化输出为包含hidden_states的元组
outputs = (hidden_states,)
# 将attention_outputs添加到输出元组中
outputs = outputs + attention_outputs
# 返回包含hidden-states、present_key_value_states、(self-attention position bias)、
# (self-attention weights)、(cross-attention position bias)、(cross-attention weights)的元组
return outputs
class FlaxT5LayerCollection(nn.Module):
config: T5Config
has_relative_attention_bias: bool
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
# 初始化单个 T5 层
self.layer = FlaxT5Block(
self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
)
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
deterministic=True,
init_cache=False,
):
# 调用单个 T5 层进行计算
return self.layer(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
class FlaxT5BlockCollection(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def setup(self):
# 根据配置初始化 T5 块集合
self.causal = self.config.causal
if self.gradient_checkpointing:
# 如果启用梯度检查点,则使用 remat 函数包装 FlaxT5LayerCollection
FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
self.blocks = [
FlaxT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
else:
# 否则,创建普通的 FlaxT5LayerCollection 实例列表
self.blocks = [
FlaxT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
def __call__(
self,
hidden_states=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions: bool = False,
output_hidden_states: bool = False,
deterministic: bool = True,
init_cache: bool = False,
):
# 调用 T5 块集合进行前向传播
# 这里假设 self.blocks 包含了多个 T5 块,每个块处理一部分输入数据
# 返回的结果取决于具体的 T5 模型结构和参数设置
pass # 此处的 pass 语句表示函数没有具体的返回内容,实际使用时需根据具体需求实现
):
# 如果需要输出隐藏状态,初始化一个空元组
all_hidden_states = () if output_hidden_states else None
# 如果需要输出注意力权重,初始化一个空元组
all_attentions = () if output_attentions else None
# 如果需要输出交叉注意力权重且模型是因果的,初始化一个空元组
all_cross_attentions = () if (output_attentions and self.causal) else None
# 初始化位置偏置为 None
position_bias = None
# 初始化编码器-解码器位置偏置为 None
encoder_decoder_position_bias = None
# 遍历每个 Transformer 模块层
for i, layer_module in enumerate(self.blocks):
# 如果需要输出隐藏状态,将当前隐藏状态添加到 all_hidden_states 元组中
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# 调用 Transformer 层进行前向传播
layer_outputs = layer_module(
hidden_states,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
deterministic,
init_cache,
)
# 更新隐藏状态为当前层的输出隐藏状态
hidden_states = layer_outputs[0]
# 更新位置偏置为当前层的自注意力位置偏置
position_bias = layer_outputs[1]
# 如果模型是因果的且存在编码器隐藏状态,更新编码器-解码器位置偏置
if self.causal and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
# 如果需要输出注意力权重,将当前层的注意力权重添加到 all_attentions 元组中
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
# 如果模型是因果的,将当前层的交叉注意力权重添加到 all_cross_attentions 元组中
if self.causal:
all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
# 返回 Transformer 模型的输出,包括最终的隐藏状态、所有层的隐藏状态、所有层的注意力权重和交叉注意力权重
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
class FlaxT5Stack(nn.Module):
config: T5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32 # 计算中使用的数据类型,默认为 jnp.float32
gradient_checkpointing: bool = False # 是否启用梯度检查点
def setup(self):
self.causal = self.config.causal # 是否是因果关系模型
# 初始化 T5 模型的块集合
self.block = FlaxT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# 最终的层归一化
self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
self.dropout = nn.Dropout(self.config.dropout_rate) # 用于随机失活的 Dropout 层
def __call__(
self,
input_ids=None, # 输入的 token IDs
attention_mask=None, # 注意力掩码
encoder_hidden_states=None, # 编码器隐藏状态
encoder_attention_mask=None, # 编码器注意力掩码
output_attentions: bool = False, # 是否输出注意力权重
output_hidden_states: bool = False, # 是否输出所有隐藏状态
return_dict: bool = True, # 是否返回字典格式结果
deterministic: bool = True, # 是否确定性计算
init_cache: bool = False, # 是否初始化缓存
):
hidden_states = self.embed_tokens(input_ids) # 嵌入 token IDs 得到隐藏状态
hidden_states = self.dropout(hidden_states, deterministic=deterministic) # 使用 Dropout 层
outputs = self.block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
deterministic=deterministic,
init_cache=init_cache,
) # 使用 T5 模型块进行前向传播计算
hidden_states = outputs[0] # 取得输出中的隐藏状态
hidden_states = self.final_layer_norm(hidden_states) # 最终的层归一化
hidden_states = self.dropout(hidden_states, deterministic=deterministic) # 再次应用 Dropout
# 添加最后一层
all_hidden_states = None
if output_hidden_states:
all_hidden_states = outputs.hidden_states
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
if output_hidden_states:
return (
hidden_states,
all_hidden_states,
) + outputs[2:] # 返回不同类型的输出
return (hidden_states,) + outputs[1:]
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
) # 返回带有注意力和交叉注意力的最终模型输出
T5_ENCODE_INPUTS_DOCSTRING = r"""
# 接收输入参数:
# - input_ids (`jnp.ndarray`,形状为 `(batch_size, sequence_length)`):
# 表示输入序列标记在词汇表中的索引。T5 是一个带有相对位置嵌入的模型,因此可以在左右两侧对输入进行填充。
# 可以使用 [`AutoTokenizer`] 获取索引。详见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`]。
# 若要了解有关预训练中如何准备 `input_ids` 的更多信息,请查看 [T5 Training](./t5#training)。
# - attention_mask (`jnp.ndarray`,形状为 `(batch_size, sequence_length)`),*可选*:
# 遮盖掩码,用于避免在填充标记索引上执行注意力操作。遮盖值在 `[0, 1]` 之间:
# - 1 表示 **未遮盖** 的标记,
# - 0 表示 **遮盖** 的标记。
# [什么是注意力遮盖?](../glossary#attention-mask)
# - output_attentions (`bool`,*可选*):
# 是否返回所有注意力层的注意力张量。查看返回的张量中的 `attentions` 以获取更多细节。
# - output_hidden_states (`bool`,*可选*):
# 是否返回所有层的隐藏状态。查看返回的张量中的 `hidden_states` 以获取更多细节。
# - return_dict (`bool`,*可选*):
# 是否返回 [`~utils.ModelOutput`] 而不是普通的元组。
"""
定义一个文档字符串常量,用于描述T5模型输入的参数说明文档。
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
解码器输入序列标记在词汇表中的索引。
可以使用[`AutoTokenizer`]获取。详见[`PreTrainedTokenizer.encode`]和[`PreTrainedTokenizer.__call__`]。
[decoder_input_ids是什么?](../glossary#decoder-input-ids)
在训练时,应提供`decoder_input_ids`。
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
元组包含(`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)
`last_hidden_state`的形状为`(batch_size, sequence_length, hidden_size)`,*可选*是编码器最后一层的隐藏状态序列。
用于解码器的交叉注意力。
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *可选*):
遮罩,避免对填充标记索引执行注意力计算。遮罩中的值选择在`[0, 1]`:
- 1表示**未遮罩**的标记,
- 0表示**已遮罩**的标记。
[注意力遮罩是什么?](../glossary#attention-mask)
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *可选*):
默认行为:生成一个张量,忽略`decoder_input_ids`中的填充标记。默认情况下也将使用因果遮罩。
如果要更改填充行为,应根据需要进行修改。有关默认策略的更多信息,请参见[论文中的图1](https://arxiv.org/abs/1910.13461)。
past_key_values (`Dict[str, np.ndarray]`, *可选*, 由`init_cache`返回或传递先前的`past_key_values`):
预先计算的隐藏状态字典(注意力块中的键和值)。可用于快速自回归解码。预计算的键和值隐藏状态的形状为*[batch_size, max_length]*。
output_attentions (`bool`, *可选*):
是否返回所有注意力层的注意力张量。有关返回张量中`attentions`的更多细节,请参见文档。
output_hidden_states (`bool`, *可选*):
是否返回所有层的隐藏状态。有关返回张量中`hidden_states`的更多细节,请参见文档。
return_dict (`bool`, *可选*):
是否返回[`~utils.ModelOutput`]而不是普通元组。
"""
# 初始化方法,用于创建一个新的模型实例
def __init__(
self,
config: T5Config,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
# 使用给定的配置和参数实例化模块对象
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
# 调用父类的初始化方法,传入配置、模块对象以及其他参数
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# 启用梯度检查点功能的方法
def enable_gradient_checkpointing(self):
# 更新模块对象,设置梯度检查点为 True
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# 初始化模型权重的方法
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# 初始化输入张量 input_ids,全零张量
input_ids = jnp.zeros(input_shape, dtype="i4")
# 创建与 input_ids 形状相同的全一张量 attention_mask
attention_mask = jnp.ones_like(input_ids)
args = [input_ids, attention_mask]
# 如果模块类不是 FlaxT5EncoderModule,则初始化解码器相关输入张量
if self.module_class not in [FlaxT5EncoderModule]:
decoder_input_ids = jnp.ones_like(input_ids)
decoder_attention_mask = jnp.ones_like(input_ids)
args.extend([decoder_input_ids, decoder_attention_mask])
# 切分随机数生成器 rng,用于参数和 dropout
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# 使用随机数生成器和输入张量初始化模块参数,返回随机初始化后的参数
random_params = self.module.init(
rngs,
*args,
)["params"]
# 如果传入了已有的参数 params,则将缺失的参数填充为随机初始化的参数值
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
# 返回填充后的参数冻结字典
return freeze(unflatten_dict(params))
else:
# 否则直接返回随机初始化的参数
return random_params
# 覆盖父类的 __call__ 方法,定义模型的前向传播
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: jnp.ndarray = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
# 如果 `output_attentions` 参数未指定,则使用配置中的默认值
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 如果 `output_hidden_states` 参数未指定,则使用配置中的默认值
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 如果 `return_dict` 参数未指定,则使用配置中的默认值
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 如果缺少 `decoder_input_ids` 参数,则抛出数值错误异常
if decoder_input_ids is None:
raise ValueError(
"Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
" here."
)
# 准备编码器输入的注意力掩码
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# 准备解码器输入的注意力掩码
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 处理可能存在的伪随机数生成器
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# 调用模型的应用方法,传递参数和输入数据
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
# 初始化缓存函数,用于自动回归解码的快速初始化
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
用于快速自动回归解码的批大小。定义了初始化缓存时的批大小。
max_length (`int`):
自动回归解码的最大可能长度。定义了初始化缓存时的序列长度。
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` 包括 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
`last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,
*可选*: 是编码器最后一层的隐藏状态序列。在解码器的交叉注意力中使用。
"""
# 初始化用于检索缓存的输入变量
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
# 获取解码器模块
decoder_module = module._get_decoder_module()
# 调用解码器模块进行前向传播
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
**kwargs,
)
# 使用指定方法进行初始化,仅需调用解码器以初始化缓存
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward,
)
# 返回解冻后的缓存变量
return unfreeze(init_variables["cache"])
@add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, return_tensors="np")
>>> encoder_outputs = model.encode(**inputs)
```"""
# 根据参数设置或者默认配置决定是否输出注意力机制
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 根据参数设置或者默认配置决定是否输出隐藏状态
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 根据参数设置或者默认配置决定是否返回字典格式的输出
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 如果未提供注意力掩码,则创建一个全为1的掩码,与输入张量维度相同
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# 如果有需要,处理任何的伪随机数生成器
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# 定义编码器前向函数
def _encoder_forward(module, input_ids, attention_mask, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, **kwargs)
# 调用模型的应用方法,传入参数并执行编码器前向计算
return self.module.apply(
{"params": params or self.params}, # 使用给定的参数或者默认参数
input_ids=jnp.array(input_ids, dtype="i4"), # 将输入张量转换为JAX数组
attention_mask=jnp.array(attention_mask, dtype="i4"), # 将注意力掩码转换为JAX数组
output_attentions=output_attentions, # 是否输出注意力机制
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否以字典格式返回输出
deterministic=not train, # 是否确定性计算,即非训练模式
rngs=rngs, # 传入的伪随机数生成器
method=_encoder_forward, # 调用的方法,即编码器的前向计算函数
)
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
"""
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
text-to-text denoising generative setting.
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`T5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class FlaxT5Module(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
"""
Retrieve the encoder module of the T5 model.
"""
return self.encoder
def _get_decoder_module(self):
"""
Retrieve the decoder module of the T5 model.
"""
return self.decoder
# 初始化模型参数,包括共享的嵌入层和编码器、解码器的配置
def setup(self):
# 初始化共享的嵌入层,使用给定的词汇大小和模型维度,使用正态分布进行初始化
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
# 复制编码器配置,并设置非因果性(causal=False)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
# 初始化编码器模型,使用共享的嵌入层和复制后的配置
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# 复制解码器配置,并设置因果性(causal=True),以及解码器层数
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
# 初始化解码器模型,使用共享的嵌入层和复制后的配置
self.decoder = FlaxT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# 模型调用函数,用于执行编码和解码操作
def __call__(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
deterministic: bool = True,
):
# 确定是否返回字典格式的输出
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 编码阶段(训练和第一次预测通道)
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 解码阶段
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0], # 使用编码器的隐藏状态作为输入
encoder_attention_mask=attention_mask, # 使用编码器的注意力掩码
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 如果不返回字典格式的输出,则将编码器和解码器的输出合并返回
if not return_dict:
return decoder_outputs + encoder_outputs
# 返回字典格式的输出,包括编码器和解码器的各种隐藏状态和注意力分布
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
# 使用 FlaxT5PreTrainedModel 作为基类定义 FlaxT5Model 类
class FlaxT5Model(FlaxT5PreTrainedModel):
# 将 module_class 属性设置为 FlaxT5Module
module_class = FlaxT5Module
# 调用 append_call_sample_docstring 函数,为 FlaxT5Model 类添加示例和文档字符串
append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
# 定义 FLAX_T5_MODEL_DOCSTRING 变量,包含返回值和示例的文档字符串
FLAX_T5_MODEL_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxT5Model
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5Model.from_pretrained("google-t5/t5-small")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="np"
... ).input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# 调用 overwrite_call_docstring 函数,为 FlaxT5Model 类替换调用时的文档字符串
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
# 调用 append_replace_return_docstrings 函数,为 FlaxT5Model 类添加返回值文档字符串
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
# 使用 add_start_docstrings 函数为 FlaxT5EncoderModule 类添加类注释和初始文档字符串
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
# 定义 FlaxT5EncoderModule 类,继承自 nn.Module
class FlaxT5EncoderModule(nn.Module):
# 包含 T5Config 类型的 config 属性和 jnp.float32 类型的 dtype 属性
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
# 定义 setup 方法,初始化模型的共享参数和编码器
def setup(self):
# 创建共享的嵌入层,使用正态分布初始化
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
# 复制配置以用于编码器,并设置相关属性
encoder_config = copy.deepcopy(self.config)
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.causal = False
# 创建 FlaxT5Stack 编码器
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# 定义 __call__ 方法,处理模型的正向传播
def __call__(
self,
input_ids=None,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict: bool = True,
deterministic: bool = True,
):
# 如果需要编码(训练或第一次预测),调用编码器进行处理
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 返回编码器的输出
return encoder_outputs
# 将模块类设置为FlaxT5EncoderModule
module_class = FlaxT5EncoderModule
# 使用装饰器为__call__方法添加文档字符串,文档字符串来源于T5_ENCODE_INPUTS_DOCSTRING
@add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray, # 输入的token IDs,作为JAX数组
attention_mask: Optional[jnp.ndarray] = None, # 可选的注意力遮罩,如果为None则设为全1数组
output_attentions: Optional[bool] = None, # 是否输出注意力权重,如果为None则使用self.config.output_attentions
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态,如果为None则使用self.config.output_hidden_states
return_dict: Optional[bool] = None, # 是否返回字典格式的输出,如果为None则使用self.config.return_dict
train: bool = False, # 是否为训练模式
params: dict = None, # 参数字典,默认为None
dropout_rng: PRNGKey = None, # dropout的随机数生成器,如果为None则表示不使用dropout
):
# 确定是否输出注意力权重
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 确定是否输出隐藏状态
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 确定是否返回字典格式的输出
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 准备编码器的输入
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids) # 如果注意力遮罩为None,则全设为1的数组
# 处理可能存在的任何PRNG(伪随机数生成器)
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# 调用self.module的apply方法,对输入参数进行编码
return self.module.apply(
{"params": params or self.params}, # 模型参数字典
input_ids=jnp.array(input_ids, dtype="i4"), # token IDs转换为JAX数组,数据类型为32位整数
attention_mask=jnp.array(attention_mask, dtype="i4"), # 注意力遮罩转换为JAX数组,数据类型为32位整数
output_attentions=output_attentions, # 是否输出注意力权重
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否返回字典格式的输出
deterministic=not train, # 是否确定性运行,即非训练模式
rngs=rngs, # PRNG(伪随机数生成器)字典
)
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
# 返回编码器模块
return self.encoder
def _get_decoder_module(self):
# 返回解码器模块
return self.decoder
def setup(self):
self.model_dim = self.config.d_model
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 初始化编码器模型
self.encoder = FlaxT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
# 初始化解码器模型
self.decoder = FlaxT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# 初始化语言模型头部
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
def __call__(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
deterministic: bool = True,
# T5 条件生成模型的调用函数
# 参数说明:
# input_ids: 输入序列的 token IDs
# attention_mask: 注意力遮罩,指示哪些位置是 padding 的
# decoder_input_ids: 解码器的输入 token IDs
# decoder_attention_mask: 解码器的注意力遮罩
# encoder_outputs: 编码器的输出
# output_attentions: 是否输出注意力权重
# output_hidden_states: 是否输出隐藏状态
# return_dict: 是否返回字典格式的输出
# deterministic: 是否使用确定性推断
# 函数主体根据输入参数执行条件生成任务,输出生成的结果
):
# 如果 return_dict 参数为 None,则根据配置决定是否使用默认值 self.config.use_return_dict
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 编码阶段
encoder_outputs = self.encoder(
input_ids=input_ids, # 输入的 token IDs
attention_mask=attention_mask, # 注意力掩码,指示哪些位置是有效的
output_attentions=output_attentions, # 是否输出注意力权重
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否返回字典格式的输出
deterministic=deterministic, # 是否确定性计算
)
hidden_states = encoder_outputs[0] # 获取编码器的隐藏状态
# 解码阶段
decoder_outputs = self.decoder(
input_ids=decoder_input_ids, # 解码器的输入 token IDs
attention_mask=decoder_attention_mask, # 解码器的注意力掩码
encoder_hidden_states=hidden_states, # 编码器的隐藏状态,作为解码器的输入
encoder_attention_mask=attention_mask, # 编码器的注意力掩码,用于解码器的注意力机制
output_attentions=output_attentions, # 是否输出注意力权重
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否返回字典格式的输出
deterministic=deterministic, # 是否确定性计算
)
sequence_output = decoder_outputs[0] # 获取解码器的输出序列
if self.config.tie_word_embeddings:
# 如果配置中指定共享词嵌入,则在投影到词汇表之前进行输出的重新缩放
# 参考:https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
if self.config.tie_word_embeddings:
# 如果配置中指定共享词嵌入,则从共享的变量中获取嵌入层参数,并应用于 lm_head
shared_embedding = self.shared.variables["params"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
else:
lm_logits = self.lm_head(sequence_output) # 否则直接将输出序列传递给 lm_head
if not return_dict:
# 如果不需要返回字典格式的输出,则返回一组元组
return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
# 返回 FlaxSeq2SeqLMOutput 类型的对象,包含详细的输出信息
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
module_class = FlaxT5ForConditionalGenerationModule
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config)
# 定义解码方法,用于生成模型输出
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
# 准备用于生成的输入数据,返回生成所需的上下文和注意力掩码
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
# 初始化缓存,准备生成所需的过去键值
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# 创建扩展的注意力掩码,用于遮蔽输入之外的位置
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
extended_attention_mask = jax.lax.dynamic_update_slice(
extended_attention_mask, decoder_attention_mask, (0, 0)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
}
# 更新用于生成的输入,将模型输出的过去键值更新到输入参数中
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
return model_kwargs
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
返回:
生成模型的输出结果。
示例:
```
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
>>> # 生成摘要
>>> summary_ids = model.generate(inputs["input_ids"]).sequences
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
"""
overwrite_call_docstring(
# 导入FlaxT5ForConditionalGeneration类,并合并文档字符串常量T5_INPUTS_DOCSTRING和FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
# 将文档字符串追加到指定类的方法中,并替换已有的文档字符串(如果存在),然后返回文档字符串修饰后的方法。
append_replace_return_docstrings(
FlaxT5ForConditionalGeneration, # 将文档字符串添加到 FlaxT5ForConditionalGeneration 类中的方法
output_type=FlaxSeq2SeqLMOutput, # 指定输出类型为 FlaxSeq2SeqLMOutput
config_class=_CONFIG_FOR_DOC # 使用 _CONFIG_FOR_DOC 配置类
)
.\models\t5\modeling_t5.py
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""
import copy
import math
import os
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google-t5/t5-small",
"google-t5/t5-base",
"google-t5/t5-large",
"google-t5/t5-3b",
"google-t5/t5-11b",
# See all T5 models at https://huggingface.co/models?filter=t5
]
####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
# 从 TensorFlow 模型加载权重
# 获取 TensorFlow 模型中所有变量的名称和形状
init_vars = tf.train.list_variables(tf_path)
# 初始化空列表,用于存储变量名称
names = []
# 初始化空字典,用于存储 TensorFlow 权重数组
tf_weights = {}
# 遍历 TensorFlow 模型中的每个变量名和形状
for name, shape in init_vars:
# 记录日志,显示当前加载的 TensorFlow 权重的名称和形状
logger.info(f"Loading TF weight {name} with shape {shape}")
# 使用 TensorFlow API 加载指定名称的变量数据
array = tf.train.load_variable(tf_path, name)
# 将当前变量名称添加到名称列表中
names.append(name)
# 将加载的 TensorFlow 权重数据存储到字典中,以变量名称作为键
tf_weights[name] = array
# 记录日志,显示未复制到 PyTorch 模型的 TensorFlow 权重的名称列表
logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
# 返回 PyTorch 模型对象
return model
####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (it-self a sub-class of nn.Module)
####################################################
# 定义了一个原始字符串常量,用于并行处理和取消并行处理模型时的文档字符串说明
PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice.
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
it will evenly distribute blocks across all devices.
Args:
device_map (`Dict[int, list]`, optional, defaults to None):
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
automatically mapped to the first device (for esoteric reasons). That means that the first device should
have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
following number of attention modules:
- google-t5/t5-small: 6
- google-t5/t5-base: 12
- google-t5/t5-large: 24
- google-t5/t5-3b: 24
- google-t5/t5-11b: 24
Example:
```
# Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules:
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
device_map = {
0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23],
}
model.parallelize(device_map)
```
"""
# 定义了一个原始字符串常量,用于取消模型并行处理时的文档字符串说明
DEPARALLELIZE_DOCSTRING = r"""
Moves the model to cpu from a model parallel state.
Example:
```
# On a 4 GPU machine with google-t5/t5-3b:
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
device_map = {
0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23],
}
model.parallelize(device_map) # Splits the model across several devices
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
```
"""
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) # 初始化权重参数为1,用于层归一化
self.variance_epsilon = eps # 初始化方差 epsilon 参数
def forward(self, hidden_states):
# 计算隐藏状态的方差,转换为 float32 类型,然后沿着最后一个维度计算均值
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
# 使用 rsqrt 函数计算标准差的倒数,对隐藏状态进行 layer normalization
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# 如果权重的数据类型是半精度浮点数(float16 或 bfloat16),则将隐藏状态转换为相同的数据类型
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
# 返回经过权重调整后的隐藏状态
return self.weight * hidden_states
try:
# 尝试导入来自apex.normalization的FusedRMSNorm模块
from apex.normalization import FusedRMSNorm
# 将FusedRMSNorm赋值给T5LayerNorm,并禁止flake8检查
T5LayerNorm = FusedRMSNorm # noqa
# 打印信息日志,表明发现了apex.normalization.FusedRMSNorm,将使用它代替T5LayerNorm
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
# 如果导入失败,则使用普通的T5LayerNorm
pass
except Exception:
# 如果导入过程中出现任何异常,则记录警告日志
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
pass
# 将T5LayerNorm添加到ALL_LAYERNORM_LAYERS列表中
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
# 初始化权重为config.d_model到config.d_ff的线性层,没有偏置
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
# 初始化权重为config.d_ff到config.d_model的线性层,没有偏置
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
# 根据config.dropout_rate初始化Dropout层
self.dropout = nn.Dropout(config.dropout_rate)
# 根据配置选择激活函数,存储在self.act中
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
# 输入hidden_states经过self.wi线性层
hidden_states = self.wi(hidden_states)
# 使用self.act激活函数处理hidden_states
hidden_states = self.act(hidden_states)
# 对hidden_states应用Dropout
hidden_states = self.dropout(hidden_states)
# 如果self.wo.weight是Tensor类型,并且hidden_states的dtype不等于self.wo.weight的dtype,并且self.wo.weight的dtype不是torch.int8
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
# 将hidden_states转换到self.wo.weight的dtype
hidden_states = hidden_states.to(self.wo.weight.dtype)
# 输入hidden_states经过self.wo线性层
hidden_states = self.wo(hidden_states)
# 返回处理后的hidden_states
return hidden_states
class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
# 初始化两个权重为config.d_model到config.d_ff的线性层,没有偏置
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
# 初始化权重为config.d_ff到config.d_model的线性层,没有偏置
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
# 根据config.dropout_rate初始化Dropout层
self.dropout = nn.Dropout(config.dropout_rate)
# 根据配置选择激活函数,存储在self.act中
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
# 输入hidden_states经过self.wi_0线性层后使用self.act激活函数处理
hidden_gelu = self.act(self.wi_0(hidden_states))
# 输入hidden_states经过self.wi_1线性层
hidden_linear = self.wi_1(hidden_states)
# 将hidden_gelu和hidden_linear相乘得到hidden_states
hidden_states = hidden_gelu * hidden_linear
# 对hidden_states应用Dropout
hidden_states = self.dropout(hidden_states)
# 为了使8位量化在google/flan-t5-xxl中起作用,self.wo被保持为float32。
# 参见https://github.com/huggingface/transformers/issues/20287
# 同时确保权重不是`int8`,以防用户强制将`_keep_in_fp32_modules`设为`None`
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
# 将hidden_states转换到self.wo.weight的dtype
hidden_states = hidden_states.to(self.wo.weight.dtype)
# 输入hidden_states经过self.wo线性层
hidden_states = self.wo(hidden_states)
# 返回处理后的hidden_states
return hidden_states
class T5LayerFF(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
# 如果config.is_gated_act为True,则使用T5DenseGatedActDense,否则使用T5DenseActDense
if config.is_gated_act:
self.DenseReluDense = T5DenseGatedActDense(config)
else:
self.DenseReluDense = T5DenseActDense(config)
# 初始化Layer Norm层,参数为config.d_model和config.layer_norm_epsilon
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 根据config.dropout_rate初始化Dropout层
self.dropout = nn.Dropout(config.dropout_rate)
#`
# 定义一个前向传播函数,接受隐藏状态作为输入
def forward(self, hidden_states):
# 对输入的隐藏状态进行层归一化处理
forwarded_states = self.layer_norm(hidden_states)
# 将归一化后的隐藏状态输入到一个全连接层+ReLU激活函数+全连接层的组合中
forwarded_states = self.DenseReluDense(forwarded_states)
# 对第二个全连接层的输出进行dropout处理,并将结果加回到原始的隐藏状态中
hidden_states = hidden_states + self.dropout(forwarded_states)
# 返回更新后的隐藏状态作为输出
return hidden_states
# 定义一个名为 T5Attention 的类,继承自 nn.Module,表示它是一个PyTorch模型组件
class T5Attention(nn.Module):
# 构造方法,初始化注意力机制的各种参数和组件
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
# 是否为解码器
self.is_decoder = config.is_decoder
# 是否包含相对注意力偏置
self.has_relative_attention_bias = has_relative_attention_bias
# 相对注意力的桶数
self.relative_attention_num_buckets = config.relative_attention_num_buckets
# 相对注意力的最大距离
self.relative_attention_max_distance = config.relative_attention_max_distance
# 模型的维度
self.d_model = config.d_model
# 键值投影的维度
self.key_value_proj_dim = config.d_kv
# 注意力头的数量
self.n_heads = config.num_heads
# Dropout率
self.dropout = config.dropout_rate
# 内部维度,即注意力头数乘以键值投影维度
self.inner_dim = self.n_heads * self.key_value_proj_dim
# 初始化注意力计算的线性层,用于查询、键、值和输出
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
# 如果有相对注意力偏置,则初始化相对注意力偏置的嵌入层
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
# 初始化一个集合,用于存储被剪枝的注意力头的索引
self.pruned_heads = set()
# 是否启用梯度检查点
self.gradient_checkpointing = False
# 方法:剪枝指定的注意力头
def prune_heads(self, heads):
# 如果没有要剪枝的头,则直接返回
if len(heads) == 0:
return
# 找到可剪枝的注意力头和对应的索引
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# 剪枝线性层
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
# 更新超参数
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
# 将剪枝的头添加到集合中
self.pruned_heads = self.pruned_heads.union(heads)
# 静态方法,用于其它辅助功能或算法的实现,这里没有具体实现给出
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor - the difference in positions between memory and query
bidirectional: a boolean - whether the attention is bidirectional or not
num_buckets: an integer - number of buckets to categorize relative positions into
max_distance: an integer - maximum distance for categorizing relative positions
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
# Initialize relative_buckets to 0
relative_buckets = 0
# Adjust num_buckets if bidirectional is True
if bidirectional:
num_buckets //= 2
# Calculate relative_buckets based on whether relative_position > 0
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
# Take absolute value of relative_position
relative_position = torch.abs(relative_position)
else:
# Set relative_position to negative of its minimum value or 0
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# Determine if relative_position is small (less than max_exact)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# Calculate relative_position_if_large for larger relative positions
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
# Clamp relative_position_if_large to num_buckets - 1
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
# Determine final relative_buckets using conditional assignment based on is_small
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
# 如果未提供设备,则使用相对注意力偏置权重的设备
if device is None:
device = self.relative_attention_bias.weight.device
# 创建一个张量,表示查询序列的位置索引,形状为 (query_length, 1)
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
# 创建一个张量,表示键序列的位置索引,形状为 (1, key_length)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
# 计算相对位置,形状为 (query_length, key_length)
relative_position = memory_position - context_position
# 将相对位置映射到相对位置桶中,形状仍为 (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
# 使用相对位置桶来获取相对注意力偏置值,形状为 (query_length, key_length, num_heads)
values = self.relative_attention_bias(relative_position_bucket)
# 对值张量进行维度置换和扩展,形状为 (1, num_heads, query_length, key_length)
values = values.permute([2, 0, 1]).unsqueeze(0)
# 返回计算得到的相对位置偏置值
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
# 定义 T5 模型的自注意力层
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
# 初始化自注意力机制
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
# 初始化层归一化
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 初始化 dropout
self.dropout = nn.Dropout(config.dropout_rate)
# 前向传播函数,接受一些参数和张量 hidden_states
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
# 对输入的 hidden_states 进行层归一化处理
normed_hidden_states = self.layer_norm(hidden_states)
# 将归一化后的 hidden_states 输入到 SelfAttention 层
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# 将原始 hidden_states 和经过 dropout 处理后的注意力输出相加
hidden_states = hidden_states + self.dropout(attention_output[0])
# 输出包括更新后的 hidden_states 和额外的注意力信息(如果有的话)
outputs = (hidden_states,) + attention_output[1:] # 如果需要,添加注意力信息
return outputs
# 定义 T5 模型的跨注意力层
class T5LayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
# 初始化编码-解码注意力机制
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
# 初始化层归一化
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 初始化 dropout
self.dropout = nn.Dropout(config.dropout_rate)
# 前向传播函数,接受一些参数和张量 hidden_states,key_value_states
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
# 对输入的 hidden_states 进行层归一化处理
normed_hidden_states = self.layer_norm(hidden_states)
# 将归一化后的 hidden_states 输入到 EncDecAttention 层
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
# 将原始 hidden_states 和经过 dropout 处理后的注意力输出相加
layer_output = hidden_states + self.dropout(attention_output[0])
# 输出包括更新后的 hidden_states 和额外的注意力信息(如果有的话)
outputs = (layer_output,) + attention_output[1:] # 如果需要,添加注意力信息
return outputs
# 定义 T5 模型的块
class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
# 标记该块是否为解码器块
self.is_decoder = config.is_decoder
# 初始化层列表
self.layer = nn.ModuleList()
# 添加自注意力层到层列表
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
# 如果是解码器,添加编码-解码注意力层到层列表
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config))
# 添加前馈神经网络层到层列表
self.layer.append(T5LayerFF(config))
# 定义 Transformer 模型的前向传播函数,接受多个参数:
# - hidden_states: 输入的隐藏状态
# - attention_mask: 可选参数,用于屏蔽不需要关注的位置
# - position_bias: 可选参数,用于位置偏置
# - encoder_hidden_states: 可选参数,编码器的隐藏状态
# - encoder_attention_mask: 可选参数,编码器的注意力屏蔽
# - encoder_decoder_position_bias: 可选参数,编码器到解码器的位置偏置
# - layer_head_mask: 可选参数,用于层头的屏蔽
# - cross_attn_layer_head_mask: 可选参数,用于交叉注意力的层头屏蔽
# - past_key_value: 可选参数,过去的键值对,用于生成缓存
# - use_cache: 是否使用缓存,默认为 False
# - output_attentions: 是否输出注意力权重,默认为 False
# - return_dict: 是否返回结果字典,默认为 True
class T5ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: T5Config):
super().__init__()
# 定义一个全连接层,输入维度为config.d_model,输出维度为config.d_model
self.dense = nn.Linear(config.d_model, config.d_model)
# 定义一个Dropout层,概率为config.classifier_dropout
self.dropout = nn.Dropout(p=config.classifier_dropout)
# 定义一个全连接层,输入维度为config.d_model,输出维度为config.num_labels
self.out_proj = nn.Linear(config.d_model, config.num_labels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 对输入的hidden_states进行dropout处理
hidden_states = self.dropout(hidden_states)
# 通过全连接层self.dense进行线性变换
hidden_states = self.dense(hidden_states)
# 对线性变换的结果进行tanh激活函数处理
hidden_states = torch.tanh(hidden_states)
# 再次对处理后的hidden_states进行dropout处理
hidden_states = self.dropout(hidden_states)
# 通过全连接层self.out_proj进行线性变换,得到最终的输出
hidden_states = self.out_proj(hidden_states)
return hidden_states
class T5PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = T5Config
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"]
@property
def dummy_inputs(self):
# 创建一个包含虚拟输入数据的字典
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
def _shift_right(self, input_ids):
# 获取decoder起始标记的ID和pad标记的ID
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
"See T5 docs for more information."
)
# 将输入向右移动一位
if is_torch_fx_proxy(input_ids):
# 对于代理对象,不支持原生的项目赋值操作
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
# 创建一个与input_ids形状相同的全零张量
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
# 将input_ids的内容向右移动一位,并将decoder起始标记填充到第一位
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# 将标签中可能存在的-100值替换为pad_token_id
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
class T5Stack(T5PreTrainedModel):
pass
# 使用给定的配置和嵌入令牌(如果提供),初始化一个 T5Stack 对象
def __init__(self, config, embed_tokens=None):
# 调用父类的初始化方法
super().__init__(config)
# 将嵌入令牌保存到对象属性中
self.embed_tokens = embed_tokens
# 检查配置中是否设置了解码器标志,并保存到对象属性中
self.is_decoder = config.is_decoder
# 创建包含多个 T5Block 的模块列表(每个 T5Block 对象表示 T5 模型的一个层)
self.block = nn.ModuleList(
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
)
# 初始化最终层归一化对象,用于处理模型的输出
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
# 创建一个 dropout 层,用于随机失活
self.dropout = nn.Dropout(config.dropout_rate)
# 初始化权重并进行最终的处理
self.post_init()
# Model parallel (模型并行设置)
self.model_parallel = False # 默认情况下不使用模型并行
self.device_map = None # 设备映射初始化为 None
self.gradient_checkpointing = False # 梯度检查点设置为 False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
# 并行化方法,用于将模型放置到多个设备上
def parallelize(self, device_map=None):
# 发出警告,表明此方法即将被移除
warnings.warn(
"`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
# 检查设备映射的有效性
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
)
# 确保设备映射与层的数量匹配
assert_device_map(self.device_map, len(self.block))
# 标记模型已启用模型并行
self.model_parallel = True
# 确定第一个设备和最后一个设备的名称
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
# 将每个层移动到对应的设备
for k, v in self.device_map.items():
for layer in v:
cuda_device = "cuda:" + str(k)
self.block[layer] = self.block[layer].to(cuda_device)
# 将嵌入令牌移到第一个设备
self.embed_tokens = self.embed_tokens.to(self.first_device)
# 将最终层归一化移到最后一个设备
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
# 反并行化方法,用于将模型从多个设备恢复到单设备
def deparallelize(self):
# 发出警告,表明此方法即将被移除
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
# 将模型并行标志设置为 False
self.model_parallel = False
# 设备映射置为 None
self.device_map = None
# 将第一个设备和最后一个设备都设置为 "cpu"
self.first_device = "cpu"
self.last_device = "cpu"
# 将每个层移回到 CPU
for i in range(len(self.block)):
self.block[i] = self.block[i].to("cpu")
# 将嵌入令牌和最终层归一化层移回到 CPU
self.embed_tokens = self.embed_tokens.to("cpu")
self.final_layer_norm = self.final_layer_norm.to("cpu")
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 获取输入嵌入层对象
def get_input_embeddings(self):
return self.embed_tokens
# 设置输入嵌入层对象
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
# 定义模型的前向传播方法,接受多个输入参数,用于处理输入序列的各种信息
def forward(
self,
input_ids=None, # 输入的 token IDs,用于表示输入序列
attention_mask=None, # 注意力遮罩,指定哪些位置需要参与注意力计算
encoder_hidden_states=None, # 编码器隐藏状态,用于某些模型的特定任务
encoder_attention_mask=None, # 编码器注意力遮罩,指定哪些编码器隐藏状态需要注意
inputs_embeds=None, # 输入的嵌入表示,代替 input_ids 使用
head_mask=None, # 头部遮罩,用于指定哪些注意力头部需要被屏蔽
cross_attn_head_mask=None, # 跨注意力头部遮罩,类似于 head_mask,但用于跨注意力机制
past_key_values=None, # 过去的键-值对,用于支持增量式生成的情况
use_cache=None, # 是否使用缓存,用于存储中间计算结果以加速推理
output_attentions=None, # 是否输出注意力权重
output_hidden_states=None, # 是否输出所有隐藏状态
return_dict=None, # 是否返回一个字典作为输出
# T5 模型的文档字符串,用于说明该模型的提出背景和特性
T5_START_DOCSTRING = r"""
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
text-to-text denoising generative setting.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`T5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# T5 模型的输入文档字符串,暂未提供具体内容,保留空字符串
T5_INPUTS_DOCSTRING = r"""
"""
# T5 编码器输入的文档字符串,暂未提供具体内容,保留空字符串
T5_ENCODER_INPUTS_DOCSTRING = r"""
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
# 输入序列的标记索引,在词汇表中。对于 T5 模型,相对位置嵌入使得可以在输入的左右两侧进行填充。
# 可以使用 `AutoTokenizer` 获取这些索引。参见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__`。
# 如何准备 `input_ids` 进行预训练,请查看[T5 Training](./t5#training)。
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
# 遮罩,用于避免对填充的标记索引执行注意力操作。遮罩的取值范围为 `[0, 1]`:
# - 1 表示对应的标记**未被遮罩**,
# - 0 表示对应的标记**被遮罩**。
# [什么是注意力遮罩?](../glossary#attention-mask)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
# 用于将自注意力模块中的部分头部置零的遮罩。遮罩的取值范围为 `[0, 1]`:
# - 1 表示该头部**未被遮罩**,
# - 0 表示该头部**被遮罩**。
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
# 可选参数,可以直接传递嵌入表示而不是 `input_ids`。如果希望对如何将 `input_ids` 索引转换为相关联向量有更多控制权,那么这很有用。
# 这对于超越模型内部嵌入查找矩阵有更多控制的情况很有用。
output_attentions (`bool`, *optional*):
# 是否返回所有注意力层的注意力张量。查看返回的张量中的 `attentions` 以获取更多细节。
output_hidden_states (`bool`, *optional*):
# 是否返回所有层的隐藏状态。查看返回的张量中的 `hidden_states` 以获取更多细节。
return_dict (`bool`, *optional*):
# 是否返回一个 `~utils.ModelOutput` 而不是一个普通的元组。
"""
# 警告消息,用于将来的警告:head_mask 参数已分成两个输入参数 - head_mask 和 decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
# 在模型加载时忽略的意外键列表
_keys_to_ignore_on_load_unexpected = [
"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
# 共享权重的键列表
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
# 创建共享的嵌入层,用于处理词汇大小和模型维度的embedding
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制配置并设置编码器
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)
# 复制配置并设置解码器
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
self.decoder = T5Stack(decoder_config, self.shared)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行计算标志
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# 发出警告,提醒方法即将被弃用
warnings.warn(
"`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
" 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
# 根据传入的设备映射或自动生成平衡设备映射
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# 检查设备映射的有效性
assert_device_map(self.device_map, len(self.encoder.block))
# 对编码器和解码器进行并行化处理
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
# 发出一个关于函数过时的警告,提示使用者此功能将在 Transformers 的 v5 版本中移除
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
# 调用编码器对象的 deparallelize 方法,取消并行化设置
self.encoder.deparallelize()
# 调用解码器对象的 deparallelize 方法,取消并行化设置
self.decoder.deparallelize()
# 将编码器移动到 CPU 上执行
self.encoder = self.encoder.to("cpu")
# 将解码器移动到 CPU 上执行
self.decoder = self.decoder.to("cpu")
# 禁用模型并行化设置
self.model_parallel = False
# 将设备映射设置为空
self.device_map = None
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 获取输入嵌入层对象的方法
def get_input_embeddings(self):
return self.shared
# 设置输入嵌入层对象的方法
def set_input_embeddings(self, new_embeddings):
# 更新共享的嵌入层对象
self.shared = new_embeddings
# 更新编码器的输入嵌入层对象
self.encoder.set_input_embeddings(new_embeddings)
# 更新解码器的输入嵌入层对象
self.decoder.set_input_embeddings(new_embeddings)
# 内部方法,用于绑定权重(如果配置允许)
def _tie_weights(self):
if self.config.tie_word_embeddings:
# 绑定或克隆编码器的词嵌入权重与共享的嵌入层对象
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
# 绑定或克隆解码器的词嵌入权重与共享的嵌入层对象
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# 获取编码器对象的方法
def get_encoder(self):
return self.encoder
# 获取解码器对象的方法
def get_decoder(self):
return self.decoder
# 内部方法,用于剪枝模型的注意力头
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# 遍历需要剪枝的层及其对应的头信息
for layer, heads in heads_to_prune.items():
# 对编码器的某一层的注意力模型进行头剪枝操作
self.encoder.layer[layer].attention.prune_heads(heads)
# 此函数用于模型的前向传播
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
# 忽略加载时不期望的键列表
_keys_to_ignore_on_load_unexpected = [
"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
# 被绑定权重的键列表
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
# 模型维度
self.model_dim = config.d_model
# 共享的嵌入层,用于输入词汇表大小和模型维度
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制编码器配置,并设置为非解码器模式
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 创建编码器实例
self.encoder = T5Stack(encoder_config, self.shared)
# 复制解码器配置,并设置为解码器模式
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
# 创建解码器实例
self.decoder = T5Stack(decoder_config, self.shared)
# 线性层,用于语言模型的输出,输入维度为模型维度,输出维度为词汇表大小,无偏置
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行化
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# 发出警告,此方法即将弃用
warnings.warn(
"`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
" should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
" provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
" {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
# 获取设备映射,如果未提供则使用均衡映射
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# 验证设备映射的有效性
assert_device_map(self.device_map, len(self.encoder.block))
# 将编码器并行化
self.encoder.parallelize(self.device_map)
# 将解码器并行化
self.decoder.parallelize(self.device_map)
# 将语言模型头移到解码器的第一个设备上
self.lm_head = self.lm_head.to(self.decoder.first_device)
# 设置模型为模型并行化状态
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
# 发出警告,此方法即将弃用
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
# 取消编码器的并行化
self.encoder.deparallelize()
# 取消解码器的并行化
self.decoder.deparallelize()
# 将编码器移到CPU
self.encoder = self.encoder.to("cpu")
# 将解码器移到CPU
self.decoder = self.decoder.to("cpu")
# 将语言模型头移到CPU
self.lm_head = self.lm_head.to("cpu")
# 设置模型为非模型并行化状态
self.model_parallel = False
self.device_map = None
# 清空CUDA缓存
torch.cuda.empty_cache()
def get_input_embeddings(self):
# 返回共享的嵌入层
return self.shared
# 设置模型的输入词嵌入
def set_input_embeddings(self, new_embeddings):
# 将新的词嵌入赋给共享的嵌入层
self.shared = new_embeddings
# 设置编码器的输入词嵌入
self.encoder.set_input_embeddings(new_embeddings)
# 设置解码器的输入词嵌入
self.decoder.set_input_embeddings(new_embeddings)
# 绑定权重(或克隆)以确保编码器和解码器共享相同的词嵌入
def _tie_weights(self):
if self.config.tie_word_embeddings:
# 绑定或克隆编码器的嵌入层与共享的嵌入层
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
# 绑定或克隆解码器的嵌入层与共享的嵌入层
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# 设置模型的输出词嵌入
def set_output_embeddings(self, new_embeddings):
# 设置语言模型头部的新词嵌入
self.lm_head = new_embeddings
# 返回模型的输出词嵌入
def get_output_embeddings(self):
return self.lm_head
# 返回编码器实例
def get_encoder(self):
return self.encoder
# 返回解码器实例
def get_decoder(self):
return self.decoder
# 模型前向传播方法,用于执行T5模型的输入到输出的转换
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# 省略T5模型的前向传播逻辑,由装饰器管理
# 准备用于生成的输入,这里主要用于生成文本
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# 省略准备生成文本输入的逻辑,可以传递各种参数给模型
):
# 如果使用了过去的键值(past_key_values),则裁剪decoder_input_ids
if past_key_values is not None:
# 获取过去键值的长度
past_length = past_key_values[0][0].shape[2]
# 一些生成方法已经只传递最后一个输入ID
if input_ids.shape[1] > past_length:
# 如果输入的input_ids长度大于过去的长度,裁剪掉前面的部分
remove_prefix_length = past_length
else:
# 默认旧的行为:保留最后一个ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids, # 返回裁剪后的decoder_input_ids
"past_key_values": past_key_values, # 返回过去的键值
"encoder_outputs": encoder_outputs, # 返回编码器的输出
"attention_mask": attention_mask, # 返回注意力掩码
"head_mask": head_mask, # 返回头部掩码
"decoder_head_mask": decoder_head_mask, # 返回解码器头部掩码
"decoder_attention_mask": decoder_attention_mask, # 返回解码器的注意力掩码
"cross_attn_head_mask": cross_attn_head_mask, # 返回交叉注意力头部掩码
"use_cache": use_cache, # 返回是否使用缓存的标志
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
# 根据标签准备解码器的输入ids,将标签向右移动一位
return self._shift_right(labels)
def _reorder_cache(self, past_key_values, beam_idx):
# 如果解码器的过去状态未包含在输出中
# 快速解码被禁用,无需重新排序
if past_key_values is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
# 从层过去状态中获取正确的批次索引
# past的批次维度在第二个位置
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# 需要为每个四个键/值状态设置正确的past
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
)
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
# 添加模型的文档字符串,描述了这个类是一个 T5 编码器模型,输出编码器的原始隐藏状态而不带任何特定的头部
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
# 在加载模型时需要保持权重一致的键列表
_tied_weights_keys = ["encoder.embed_tokens.weight"]
# 加载时需要忽略的意外键列表,这里排除了包含"decoder"的键
_keys_to_ignore_on_load_unexpected = [r"decoder"]
def __init__(self, config: T5Config):
super().__init__(config)
# 共享的嵌入层,根据配置创建一个词汇表大小为config.vocab_size,维度为config.d_model的嵌入层
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制配置以配置编码器,并设置一些属性
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 创建 T5 堆栈编码器,使用共享的嵌入层
self.encoder = T5Stack(encoder_config, self.shared)
# 初始化权重并应用最终处理
self.post_init()
# 模型并行处理相关属性初始化
self.model_parallel = False
self.device_map = None
# 添加模型并行化的文档字符串,警告此方法在后续版本中将被移除
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
# 获取设备映射,如果未提供设备映射,则默认使用均衡的设备映射
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# 检查设备映射的合法性
assert_device_map(self.device_map, len(self.encoder.block))
# 调用编码器的并行化方法,设置模型并行为真
self.encoder.parallelize(self.device_map)
self.model_parallel = True
# 添加反并行化的文档字符串,警告此方法在后续版本中将被移除
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
# 调用编码器的反并行化方法,将编码器转移到 CPU 上,并设置模型并行为假
self.encoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.model_parallel = False
self.device_map = None
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 返回共享的嵌入层
def get_input_embeddings(self):
return self.shared
# 设置新的输入嵌入层,并更新编码器的输入嵌入层
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
# 绑定权重的私有方法,如果配置中要求绑定词嵌入权重,则绑定编码器的嵌入词汇表和共享的嵌入层
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
# 返回编码器对象
def get_encoder(self):
return self.encoder
# 剪枝模型中的注意力头,heads_to_prune 是一个字典,表示需要在每层剪枝的注意力头的列表
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
# 调用编码器堆栈中每层的自注意力模块的剪枝头方法
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
# 添加 T5 编码器模型前向传播的文档字符串
@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
# 将函数返回值的文档字符串中的输出类型替换为BaseModelOutput,配置类为_CONFIG_FOR_DOC
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
# 前向传播函数,接受多个输入参数,并返回Union类型的torch.FloatTensor或BaseModelOutput
def forward(
self,
input_ids: Optional[torch.LongTensor] = None, # 输入的token IDs,类型为可选的长整型张量
attention_mask: Optional[torch.FloatTensor] = None, # 注意力遮罩张量,类型为可选的浮点数张量
head_mask: Optional[torch.FloatTensor] = None, # 头部遮罩张量,类型为可选的浮点数张量
inputs_embeds: Optional[torch.FloatTensor] = None, # 嵌入输入张量,类型为可选的浮点数张量
output_attentions: Optional[bool] = None, # 是否输出注意力张量,类型为可选的布尔值
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态张量,类型为可选的布尔值
return_dict: Optional[bool] = None, # 是否返回字典格式的输出,类型为可选的布尔值
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
返回值:
如果return_dict不为None,则返回return_dict;否则返回self.config.use_return_dict。
示例:
```
>>> from transformers import AutoTokenizer, T5EncoderModel
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 使用编码器模型处理输入,并获取编码器的输出结果
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 返回编码器的输出结果
return encoder_outputs
"""
T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
"""
# 定义 T5 序列分类模型,其顶部有一个线性层(位于汇聚输出之上),用于例如 GLUE 任务
@add_start_docstrings(
"""
T5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
T5_START_DOCSTRING,
)
class T5ForTokenClassification(T5PreTrainedModel):
# 指定权重共享的关键键列表
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
# 初始化 T5 配置
self.num_labels = config.num_labels
# 创建 T5 编码器模型
self.transformer = T5EncoderModel(config)
# 添加一个丢弃层
self.dropout = nn.Dropout(config.classifier_dropout)
# 添加一个线性分类器层
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# 初始化权重并执行最终处理
self.post_init()
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None, # 输入的token IDs张量,可选
attention_mask: Optional[torch.Tensor] = None, # 注意力掩码张量,可选
head_mask: Optional[torch.Tensor] = None, # 头部掩码张量,可选
inputs_embeds: Optional[torch.Tensor] = None, # 输入的嵌入张量,可选
labels: Optional[torch.Tensor] = None, # 用于计算标记分类损失的标签张量,可选
output_attentions: Optional[bool] = None, # 是否输出注意力权重,可选
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态,可选
return_dict: Optional[bool] = None, # 是否返回字典格式的输出,可选
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
# 确定是否使用配置中的返回字典设置
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 将输入传递给transformer模型,并获取输出
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 从输出中获取隐藏状态并应用dropout
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
# 将隐藏状态传递给分类器,获取预测的逻辑回归输出
logits = self.classifier(hidden_states)
# 如果提供了标签,则计算损失
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# 根据return_dict标志,返回不同的输出格式
if not return_dict:
output = (logits, outputs[2:-1]) # 仅在不返回字典时输出隐藏状态
return ((loss,) + output) if loss is not None else output
# 返回TokenClassifierOutput对象,包含损失、预测的逻辑回归、隐藏状态和注意力权重
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# 使用特定的文档字符串初始化模型,该模型包含一个用于提取问答任务(如SQuAD)的跨度分类头部(在隐藏状态输出之上的线性层,用于计算“跨度起始logits”和“跨度结束logits”)。
@add_start_docstrings(
"""
T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
T5_START_DOCSTRING,
)
class T5ForQuestionAnswering(T5PreTrainedModel):
# 在加载时忽略的键列表,遇到不期待的键
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
# 权重共享的键列表
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
# 初始化方法,接受一个T5Config对象作为参数
def __init__(self, config: T5Config):
super().__init__(config)
# 模型维度设为配置文件中的d_model值
self.model_dim = config.d_model
# 共享的词嵌入层,使用配置文件中的vocab_size和d_model创建
self.shared = nn.Embedding(config.vocab_size, config.d_model)
# 复制配置以创建编码器
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# 使用T5Stack模块和共享的词嵌入层创建编码器
self.encoder = T5Stack(encoder_config, self.shared)
# 复制配置以创建解码器
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
# 使用T5Stack模块和共享的词嵌入层创建解码器
self.decoder = T5Stack(decoder_config, self.shared)
# 输出的标签数量为配置文件中的num_labels
self.num_labels = config.num_labels
# 线性层,将隐藏大小映射到标签数量
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# 初始化权重并进行最终处理
self.post_init()
# 模型并行设置为False
self.model_parallel = False
# 返回共享的词嵌入层对象
def get_input_embeddings(self):
return self.shared
# 设置新的输入词嵌入层
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
# 同时更新编码器和解码器的输入词嵌入层
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
# 绑定权重,如果配置中设置了tie_word_embeddings为True
def _tie_weights(self):
if self.config.tie_word_embeddings:
# 绑定或克隆权重以使编码器和解码器共享词嵌入层的权重
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
# 返回编码器对象
def get_encoder(self):
return self.encoder
# 返回解码器对象
def get_decoder(self):
return self.decoder
# 用于模型前向传播方法的装饰器,添加了T5输入文档字符串
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
# 替换返回值文档字符串为Seq2SeqQuestionAnsweringModelOutput类型,使用配置类_CONFIG_FOR_DOC
@replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
# 定义神经网络模型的前向传播函数,接受多个可选的输入参数,类型为 PyTorch 的张量或布尔值
def forward(
self,
input_ids: Optional[torch.LongTensor] = None, # 输入文本的词编号张量,可选
attention_mask: Optional[torch.FloatTensor] = None, # 输入文本的注意力掩码张量,可选
decoder_input_ids: Optional[torch.LongTensor] = None, # 解码器输入的词编号张量,可选
decoder_attention_mask: Optional[torch.BoolTensor] = None, # 解码器的注意力掩码张量,可选
head_mask: Optional[torch.FloatTensor] = None, # 多头注意力的头部掩码张量,可选
decoder_head_mask: Optional[torch.FloatTensor] = None, # 解码器多头注意力的头部掩码张量,可选
cross_attn_head_mask: Optional[torch.Tensor] = None, # 交叉注意力的头部掩码张量,可选
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, # 编码器的输出,包含多个张量的元组,可选
start_positions: Optional[torch.LongTensor] = None, # 起始位置张量,用于损失计算,可选
end_positions: Optional[torch.LongTensor] = None, # 结束位置张量,用于损失计算,可选
inputs_embeds: Optional[torch.FloatTensor] = None, # 输入的嵌入向量张量,可选
decoder_inputs_embeds: Optional[torch.FloatTensor] = None, # 解码器输入的嵌入向量张量,可选
use_cache: Optional[bool] = None, # 是否使用缓存,用于解码器的 Transformer 模型,可选
output_attentions: Optional[bool] = None, # 是否输出注意力权重,可选
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态,可选
return_dict: Optional[bool] = None, # 是否以字典形式返回输出,可选