Transformers-源码解析-一百二十-

Transformers 源码解析(一百二十)

.\models\vit_msn\__init__.py

# 导入所需模块和函数
from typing import TYPE_CHECKING
# 从当前项目的utils模块中导入异常类和LazyModule类,还有is_torch_available函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块的导入结构,包含了configuration_vit_msn的两个对象
_import_structure = {"configuration_vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"]}

# 检查是否存在torch库,若不存在则引发OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果torch可用,则在_import_structure中添加modeling_vit_msn的四个对象
    _import_structure["modeling_vit_msn"] = [
        "VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViTMSNModel",
        "ViTMSNForImageClassification",
        "ViTMSNPreTrainedModel",
    ]

# 如果是类型检查模式,导入具体的配置和模型类
if TYPE_CHECKING:
    # 从当前模块的configuration_vit_msn中导入两个对象
    from .configuration_vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig

    # 再次检查是否存在torch库,若不存在则引发OptionalDependencyNotAvailable异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果torch可用,则从当前模块的modeling_vit_msn中导入四个对象
        from .modeling_vit_msn import (
            VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViTMSNForImageClassification,
            ViTMSNModel,
            ViTMSNPreTrainedModel,
        )

# 如果不是类型检查模式,将当前模块设置为_LazyModule的实例
else:
    import sys

    # 将当前模块设为_LazyModule的实例,用于惰性加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\vivit\configuration_vivit.py

# coding=utf-8
# 定义模块的版权信息和编码格式

# 导入必要的模块和类
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义预训练模型配置文件的映射字典,将模型名称映射到其配置文件的 URL
VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/vivit-b-16x2-kinetics400": (
        "https://huggingface.co/google/vivit-b-16x2-kinetics400/resolve/main/config.json"
    ),
    # 可在此处查看所有 ViViT 模型:https://huggingface.co/models?filter=vivit
}


class VivitConfig(PretrainedConfig):
    r"""
    这是用于存储 [`VivitModel`] 配置的配置类。根据指定的参数实例化配置对象,定义模型架构。
    使用默认参数实例化配置对象将产生类似于 ViViT [google/vivit-b-16x2-kinetics400]
    (https://huggingface.co/google/vivit-b-16x2-kinetics400) 架构的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型的输出。阅读 [`PretrainedConfig`] 的文档以获取更多信息。
    """
    pass
    # 定义模型类型为 "vivit"
    model_type = "vivit"
    
    # 初始化函数,设置模型的各项配置参数
    def __init__(
        self,
        image_size=224,  # 图像尺寸,默认为 224
        num_frames=32,  # 每个视频的帧数,默认为 32
        tubelet_size=[2, 16, 16],  # 每个 tubelet 的尺寸,默认为 [2, 16, 16]
        num_channels=3,  # 输入通道数,默认为 3
        hidden_size=768,  # 编码器层和池化层的维度,默认为 768
        num_hidden_layers=12,  # Transformer 编码器中的隐藏层层数,默认为 12
        num_attention_heads=12,  # 每个注意力层中的注意力头数,默认为 12
        intermediate_size=3072,  # Transformer 编码器中“中间”(即前馈)层的维度,默认为 3072
        hidden_act="gelu_fast",  # 编码器和池化器中的非线性激活函数,默认为 "gelu_fast"
        hidden_dropout_prob=0.0,  # 嵌入层、编码器和池化器中全连接层的 dropout 概率,默认为 0.0
        attention_probs_dropout_prob=0.0,  # 注意力概率的 dropout 比率,默认为 0.0
        initializer_range=0.02,  # 初始化所有权重矩阵的截断正态分布的标准差,默认为 0.02
        layer_norm_eps=1e-06,  # 层归一化层使用的 epsilon,默认为 1e-06
        qkv_bias=True,  # 是否为查询、键和值添加偏置,默认为 True
        **kwargs,  # 其他关键字参数
    ):
        ):
        # 初始化模型参数
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # 初始化视频特征提取器的参数
        self.image_size = image_size  # 图像大小
        self.num_frames = num_frames  # 视频帧数
        self.tubelet_size = tubelet_size  # 视频片段大小
        self.num_channels = num_channels  # 视频通道数
        self.qkv_bias = qkv_bias  # 查询、键、值的偏置项

        # 调用父类的初始化方法
        super().__init__(**kwargs)

.\models\vivit\convert_vivit_flax_to_pytorch.py

# coding=utf-8
# 声明编码格式为 UTF-8

# Copyright 2023 The HuggingFace Inc. team.
# 版权声明,版权归 The HuggingFace Inc. 团队所有

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache 许可证版本 2.0 授权使用此代码

# you may not use this file except in compliance with the License.
# 除非遵循许可证,否则不得使用此文件

# You may obtain a copy of the License at
# 可以在以下网址获取许可证的副本

#     http://www.apache.org/licenses/LICENSE-2.0
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非法律要求或书面同意,否则按"原样"分发本软件,无论是明示的还是隐含的保证,包括但不限于,适销性、特定用途的适用性和非侵权性。

# See the License for the specific language governing permissions and
# limitations under the License.
# 请查阅许可证以获取权限和限制的详细信息

"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL:
https://github.com/google-research/scenic/tree/main/scenic/projects/vivit
"""
# 转换来自原始存储库的 Flax ViViT 检查点到 PyTorch 格式的工具
# 原始存储库 URL: https://github.com/google-research/scenic/tree/main/scenic/projects/vivit

import argparse
# 导入用于命令行解析的模块

import json
# 导入处理 JSON 的模块

import os.path
# 导入处理文件路径的模块

from collections import OrderedDict
# 导入有序字典的模块

import numpy as np
# 导入处理数值计算的模块

import requests
# 导入发送 HTTP 请求的模块

import torch
# 导入 PyTorch 深度学习框架

from flax.training.checkpoints import restore_checkpoint
# 从 Flax 框架中导入恢复检查点的功能

from huggingface_hub import hf_hub_download
# 从 Hugging Face Hub 导入下载函数

from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor
# 从 Transformers 库导入 ViViT 相关组件

from transformers.image_utils import PILImageResampling
# 从 Transformers 库导入图像处理的模块


def download_checkpoint(path):
    # 定义下载检查点文件的函数,参数为保存路径 `path`

    url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint"
    # 指定检查点文件的下载 URL

    with open(path, "wb") as f:
        # 以二进制写入模式打开文件 `path`

        with requests.get(url, stream=True) as req:
            # 发起带有流式传输的 GET 请求

            for chunk in req.iter_content(chunk_size=2048):
                # 遍历请求的数据块,每次处理大小为 2048 字节

                f.write(chunk)
                # 将数据块写入文件


def get_vivit_config() -> VivitConfig:
    # 定义获取 ViViT 配置的函数,返回类型为 VivitConfig

    config = VivitConfig()
    # 创建 ViViT 配置对象

    config.num_labels = 400
    # 设置标签数量为 400

    repo_id = "huggingface/label-files"
    # 定义标签文件的存储库 ID
    filename = "kinetics400-id2label.json"
    # 定义标签文件的名称

    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    # 从 Hugging Face Hub 下载标签文件,加载为字典格式
    id2label = {int(k): v for k, v in id2label.items()}
    # 将标签的键转换为整数类型
    config.id2label = id2label
    # 将 id2label 字典赋值给配置对象的 id2label 属性

    config.label2id = {v: k for k, v in id2label.items()}
    # 构建标签到 id 的反向映射字典

    return config
    # 返回配置对象


# We will verify our results on a video of eating spaghetti
# Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117,
# 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174]
def prepare_video():
    # 定义准备视频数据的函数

    file = hf_hub_download(
        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset"
    )
    # 从 Hugging Face Hub 下载具有吃意大利面视频帧的 NumPy 文件

    video = np.load(file)
    # 加载视频文件为 NumPy 数组

    return list(video)
    # 将视频数据转换为列表并返回


def transform_attention(current: np.ndarray):
    # 定义处理注意力张量的函数,参数为当前张量 `current`

    if np.ndim(current) == 2:
        # 如果张量是二维的

        return transform_attention_bias(current)
        # 调用处理注意力偏置的函数并返回结果

    elif np.ndim(current) == 3:
        # 如果张量是三维的

        return transform_attention_kernel(current)
        # 调用处理注意力核心的函数并返回结果

    else:
        # 如果张量维度不符合预期

        raise Exception(f"Invalid number of dimesions: {np.ndim(current)}")
        # 抛出异常,指示张量维度无效


def transform_attention_bias(current: np.ndarray):
    # 定义处理注意力偏置的函数,参数为当前偏置 `current`

    return current.flatten()
    # 将偏置展平并返回结果


def transform_attention_kernel(current: np.ndarray):
    # 定义处理注意力核心的函数,参数为当前核心 `current`

    return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T
    # 调整核心张量的形状并返回转置结果


def transform_attention_output_weight(current: np.ndarray):
    # 定义处理注意力输出权重的函数,参数为当前权重 `current`
    # 将当前数组 `current` 重新整形为二维数组,行数为原数组行数乘以列数,列数不变
    return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T
# 根据给定索引 i,从状态字典中获取 Transformer 模型的第 i 个编码器块的状态
def transform_state_encoder_block(state_dict, i):
    state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"]

    # 构建当前编码器块在模型状态字典中的前缀
    prefix = f"encoder.layer.{i}."

    # 创建新的状态字典,将原始状态字典中的数据按指定格式进行转换和重组
    new_state = {
        prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"],
        prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]),
        prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"],
        prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]),
        prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"],
        prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"],
        prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"],
        prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"],
        prefix + "attention.attention.query.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["query"]["bias"]
        ),
        prefix + "attention.attention.query.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["query"]["kernel"]
        ),
        prefix + "attention.attention.key.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["key"]["bias"]
        ),
        prefix + "attention.attention.key.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["key"]["kernel"]
        ),
        prefix + "attention.attention.value.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["value"]["bias"]
        ),
        prefix + "attention.attention.value.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["value"]["kernel"]
        ),
        prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"],
        prefix + "attention.output.dense.weight": transform_attention_output_weight(
            state["MultiHeadDotProductAttention_0"]["out"]["kernel"]
        ),
    }

    return new_state


# 获取给定状态字典中的 Transformer 模型的编码器块总数
def get_n_layers(state_dict):
    # 使用列表推导计算包含字符串 "encoderblock_" 的键的数量
    return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"].keys()])


# 转换整个状态字典,根据需要添加分类头部分
def transform_state(state_dict, classification_head=False):
    # 获取 Transformer 模型中的编码器块总数
    transformer_layers = get_n_layers(state_dict)

    # 创建一个有序字典用于存储新的状态数据
    new_state = OrderedDict()

    # 转换编码器归一化层的偏置和权重
    new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"]
    new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"]

    # 转换嵌入层的投影权重和偏置
    new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose(
        state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2)
    )
    new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"]

    # 转换分类标记的嵌入向量
    new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"]

    # 返回转换后的新状态字典
    return new_state
    # 将输入状态字典中的位置嵌入张量更新到新状态字典中的指定键
    new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][
        "pos_embedding"
    ]
    
    # 遍历每个Transformer层,更新新状态字典
    for i in range(transformer_layers):
        new_state.update(transform_state_encoder_block(state_dict, i))
    
    # 如果存在分类头部,调整新状态字典的键名并更新分类器权重和偏置
    if classification_head:
        # 更新新状态字典中的键名前缀为"vivit."
        new_state = {"vivit." + k: v for k, v in new_state.items()}
        # 转置并更新分类器权重
        new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"])
        # 转置并更新分类器偏置
        new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"])
    
    # 将新状态字典中的值转换为PyTorch张量并返回
    return {k: torch.tensor(v) for k, v in new_state.items()}
# 检查图像处理器设置与原始实现是否一致
# 原始实现可以在此链接中找到:https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py
# 数据集特定配置:
# https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py
def get_processor() -> VivitImageProcessor:
    # 创建 VivitImageProcessor 实例
    extractor = VivitImageProcessor()

    # 断言确保是否执行了图像大小调整
    assert extractor.do_resize is True
    # 断言确保调整后的最短边为256像素
    assert extractor.size == {"shortest_edge": 256}
    # 断言确保是否执行了中心裁剪
    assert extractor.do_center_crop is True
    # 断言确保裁剪尺寸为224x224像素
    assert extractor.crop_size == {"width": 224, "height": 224}
    # 断言确保使用双线性重采样
    assert extractor.resample == PILImageResampling.BILINEAR

    # 在这里参考:https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py
    # 可以看到 add_image 函数中 normalization_mean 和 normalization_std 的默认值分别设为 0 和 1
    # 这意味着没有进行归一化操作(而 ViViT 在调用此函数时也没有覆盖这些值)
    assert extractor.do_normalize is False
    # 断言确保是否执行了重新缩放
    assert extractor.do_rescale is True
    # 断言确保重新缩放因子为 1/255
    assert extractor.rescale_factor == 1 / 255

    # 断言确保是否执行了零中心化
    assert extractor.do_zero_centering is True

    # 返回图像处理器实例
    return extractor


def convert(output_path: str):
    # Flax 模型的路径
    flax_model_path = "checkpoint"

    # 如果 Flax 模型路径不存在,则下载检查点
    if not os.path.exists(flax_model_path):
        download_checkpoint(flax_model_path)

    # 恢复检查点的状态字典
    state_dict = restore_checkpoint(flax_model_path, None)
    # 对状态字典进行转换,包括分类头部的变换
    new_state = transform_state(state_dict, classification_head=True)

    # 获取 ViViT 的配置
    config = get_vivit_config()

    # 断言确保图像大小为 224
    assert config.image_size == 224
    # 断言确保帧数为 32
    assert config.num_frames == 32

    # 创建 ViViT 的视频分类模型
    model = VivitForVideoClassification(config)
    # 加载模型的状态字典
    model.load_state_dict(new_state)
    # 设为评估模式
    model.eval()

    # 获取图像处理器实例
    extractor = get_processor()

    # 准备视频数据
    video = prepare_video()
    # 使用图像处理器处理视频数据,返回 PyTorch 张量
    inputs = extractor(video, return_tensors="pt")

    # 对模型进行推理
    outputs = model(**inputs)

    # 期望的输出形状
    expected_shape = torch.Size([1, 400])
    # 期望的输出切片
    expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658])

    # 断言确保模型输出的 logits 的形状正确
    assert outputs.logits.shape == expected_shape
    # 断言确保前5个 logits 的值与期望值在指定的误差范围内
    assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5]

    # 将模型保存为预训练模型
    model.save_pretrained(output_path)
    # 保存图像处理器的预训练状态
    extractor.save_pretrained(output_path)


if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_model_name", "-o", type=str, help="输出转换后的 HuggingFace 模型的路径")

    args = parser.parse_args()
    # 调用 convert 函数,传入输出路径参数
    convert(args.output_model_name)

.\models\vivit\image_processing_vivit.py

# 指定编码为 UTF-8
# 版权声明,版权归 HuggingFace Inc. 团队所有,遵循 Apache License 2.0
# 详细版权信息可在 http://www.apache.org/licenses/LICENSE-2.0 获取
"""Vivit 的图像处理类。"""

# 引入必要的类型声明
from typing import Dict, List, Optional, Union

# 引入 numpy 库并重命名为 np
import numpy as np

# 从 transformers.utils 中导入 is_vision_available 函数
from transformers.utils import is_vision_available

# 从 transformers.utils.generic 中导入 TensorType 类型
from transformers.utils.generic import TensorType

# 导入自定义的图像处理工具和相关函数
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    get_resize_output_image_size,
    rescale,
    resize,
    to_channel_dimension_format,
)
# 导入与图像处理相关的工具函数和常量
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
# 导入 logging 模块
from ...utils import logging

# 如果 vision 可用,则导入 PIL 库
if is_vision_available():
    import PIL

# 获取 logger 对象
logger = logging.get_logger(__name__)


def make_batched(videos) -> List[List[ImageInput]]:
    """将视频列表批量化为 Vivit 需要的格式。

    Args:
        videos: 输入的视频数据,可以是单个视频或嵌套列表/元组的视频集合。

    Returns:
        List[List[ImageInput]]: 批量化后的视频列表,每个元素为一个视频帧列表。
    
    Raises:
        ValueError: 如果无法从给定的视频数据创建批量化视频。
    """
    # 如果 videos 是嵌套列表或元组,并且第一个元素是嵌套列表或元组,并且第一个视频帧是有效图像
    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
        return videos

    # 如果 videos 是列表或元组,并且第一个元素是有效图像
    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
        return [videos]

    # 如果 videos 是有效图像
    elif is_valid_image(videos):
        return [[videos]]

    # 如果无法从 videos 创建批量化视频,抛出 ValueError 异常
    raise ValueError(f"Could not make batched video from {videos}")


class VivitImageProcessor(BaseImageProcessor):
    r"""
    构建 Vivit 图像处理器。

    继承自 BaseImageProcessor 类。
    """
    def __init__(self):
        """初始化 Vivit 图像处理器。"""
        super().__init__()
    # 定义函数参数和默认值,用于图像预处理
    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            是否调整图像的高度和宽度尺寸到指定的 `size`。可以在 `preprocess` 方法中的 `do_resize` 参数中被覆盖。
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`):
            调整后的输出图像尺寸。图像的最短边将调整为 `size["shortest_edge"]`,同时保持原始图像的纵横比。可以在 `preprocess` 方法中的 `size` 参数中被覆盖。
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            调整图像尺寸时使用的重采样滤波器。可以在 `preprocess` 方法中的 `resample` 参数中被覆盖。
        do_center_crop (`bool`, *optional*, defaults to `True`):
            是否对图像进行中心裁剪到指定的 `crop_size`。可以在 `preprocess` 方法中的 `do_center_crop` 参数中被覆盖。
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            应用中心裁剪后的图像尺寸。可以在 `preprocess` 方法中的 `crop_size` 参数中被覆盖。
        do_rescale (`bool`, *optional*, defaults to `True`):
            是否按照指定的缩放因子 `rescale_factor` 进行图像缩放。可以在 `preprocess` 方法中的 `do_rescale` 参数中被覆盖。
        rescale_factor (`int` or `float`, *optional*, defaults to `1/127.5`):
            如果进行图像缩放,定义要使用的缩放因子。可以在 `preprocess` 方法中的 `rescale_factor` 参数中被覆盖。
        offset (`bool`, *optional*, defaults to `True`):
            是否在正负方向同时进行图像缩放。可以在 `preprocess` 方法中的 `offset` 参数中被覆盖。
        do_normalize (`bool`, *optional*, defaults to `True`):
            是否对图像进行归一化。可以在 `preprocess` 方法中的 `do_normalize` 参数中被覆盖。
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            如果归一化图像,定义要使用的均值。这是一个浮点数或长度等于图像通道数的浮点数列表。可以在 `preprocess` 方法中的 `image_mean` 参数中被覆盖。
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            如果归一化图像,定义要使用的标准差。这是一个浮点数或长度等于图像通道数的浮点数列表。可以在 `preprocess` 方法中的 `image_std` 参数中被覆盖。
    """
    
    model_input_names = ["pixel_values"]
    # 初始化函数,用于设置图像预处理的各项参数
    def __init__(
        self,
        do_resize: bool = True,  # 是否进行图像尺寸调整的标志
        size: Dict[str, int] = None,  # 图像尺寸的字典,包含最短边和可能的其他尺寸参数
        resample: PILImageResampling = PILImageResampling.BILINEAR,  # 图像调整大小时的重采样方法
        do_center_crop: bool = True,  # 是否进行中心裁剪的标志
        crop_size: Dict[str, int] = None,  # 裁剪后的图像尺寸的字典,包含高度和宽度
        do_rescale: bool = True,  # 是否进行图像像素值缩放的标志
        rescale_factor: Union[int, float] = 1 / 127.5,  # 图像像素值缩放的因子
        offset: bool = True,  # 是否进行图像像素值偏移的标志
        do_normalize: bool = True,  # 是否进行图像像素值标准化的标志
        image_mean: Optional[Union[float, List[float]]] = None,  # 图像像素值的均值
        image_std: Optional[Union[float, List[float]]] = None,  # 图像像素值的标准差
        **kwargs,  # 其他参数,以字典形式传入
    ) -> None:
        # 调用父类的初始化方法,传入额外的关键字参数
        super().__init__(**kwargs)
        
        # 如果 size 参数为 None,则设为默认值 {"shortest_edge": 256}
        size = size if size is not None else {"shortest_edge": 256}
        # 根据参数获取最终确定的图像尺寸字典,确保不是正方形
        size = get_size_dict(size, default_to_square=False)
        
        # 如果 crop_size 参数为 None,则设为默认值 {"height": 224, "width": 224}
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        # 根据参数获取最终确定的裁剪尺寸字典
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        # 将各参数值赋给对象的属性
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.offset = offset
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        
        # 图像处理器对象支持的键列表
        self._valid_processor_keys = [
            "videos",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "offset",
            "do_normalize",
            "image_mean",
            "image_std",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]
    # 重新调整图像大小的函数,基于输入参数对图像进行变换

    # size参数表示输出图像的尺寸,根据get_size_dict函数获取确切的尺寸字典
    size = get_size_dict(size, default_to_square=False)

    # 如果size字典中包含"shortest_edge"键,根据最短边长度调整输出图像尺寸
    if "shortest_edge" in size:
        # 调用get_resize_output_image_size函数计算调整后的图像尺寸
        output_size = get_resize_output_image_size(
            image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
        )
    # 如果size字典中同时包含"height"和"width"键,直接使用指定的高度和宽度
    elif "height" in size and "width" in size:
        output_size = (size["height"], size["width"])
    # 如果size字典中的键不符合要求,抛出数值错误异常
    else:
        raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")

    # 调用resize函数对图像进行实际的大小调整操作,返回调整后的图像
    return resize(
        image,
        size=output_size,
        resample=resample,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )
    ):
        """
        Rescale an image by a scale factor.

        If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
        1/127.5, the image is rescaled between [-1, 1].
            image = image * scale - 1

        If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
            image = image * scale

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            offset (`bool`, *optional*):
                Whether to scale the image in both negative and positive directions.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # 调用rescale函数,对图像进行重新缩放处理
        rescaled_image = rescale(
            image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

        # 如果offset为True,则对重新缩放后的图像进行偏移处理
        if offset:
            rescaled_image = rescaled_image - 1

        # 返回经过缩放和可能偏移处理后的图像
        return rescaled_image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        offset: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""

        # 验证预处理参数的有效性,确保所有参数都被正确设置
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if offset and not do_rescale:
            # 如果设置了 offset 但未设置 do_rescale,则抛出数值错误异常
            raise ValueError("For offset, do_rescale must also be set to True.")

        # 将输入的图像转换为 numpy 数组
        image = to_numpy_array(image)

        if is_scaled_image(image) and do_rescale:
            # 如果图像已经被缩放,并且需要进行重新缩放,则发出警告
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # 推断输入数据的通道维度格式
            input_data_format = infer_channel_dimension_format(image)

        if do_resize:
            # 如果需要进行 resize 操作,则调用 resize 方法进行处理
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            # 如果需要进行中心裁剪,则调用 center_crop 方法进行处理
            image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            # 如果需要进行缩放操作,则调用 rescale 方法进行处理
            image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)

        if do_normalize:
            # 如果需要进行归一化,则调用 normalize 方法进行处理
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        # 将图像数据转换为指定的通道维度格式
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def preprocess(
        self,
        videos: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        offset: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,

.\models\vivit\modeling_vivit.py

# 定义 VivitTubeletEmbeddings 类,用于构建 Vivit 模型的 Tubelet embeddings
class VivitTubeletEmbeddings(nn.Module):
    """
    Construct Vivit Tubelet embeddings.

    This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
    shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
    (width // tubelet_size[2]).
    """

    def __init__(self, config):
        super().__init__()
        # 初始化 Tubelet embeddings 相关参数
        self.num_frames = config.num_frames  # 视频帧数
        self.image_size = config.image_size  # 视频帧的尺寸
        self.patch_size = config.tubelet_size  # Tubelet 的尺寸
        # 计算 patches 的数量,用于 Transformer 编码器的输入长度
        self.num_patches = (
            (self.image_size // self.patch_size[2])
            * (self.image_size // self.patch_size[1])
            * (self.num_frames // self.patch_size[0])
        )
        self.embed_dim = config.hidden_size  # 嵌入向量的维度

        # 使用 3D 卷积层将视频帧转换为嵌入向量
        self.projection = nn.Conv3d(
            config.num_channels,  # 输入视频的通道数
            config.hidden_size,   # 输出嵌入向量的维度
            kernel_size=config.tubelet_size,  # 卷积核大小,即 Tubelet 的尺寸
            stride=config.tubelet_size  # 卷积的步长,与 Tubelet 尺寸相同
        )
    # 定义前向传播方法,接受像素值作为输入
    def forward(self, pixel_values):
        # 获取输入张量的批大小、帧数、通道数、高度和宽度
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        # 检查输入图像的高度和宽度是否与模型要求的大小相匹配
        if height != self.image_size or width != self.image_size:
            # 如果不匹配,抛出数值错误异常
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )

        # 将输入像素值重新排列为 (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)

        # 使用投影层处理重新排列后的像素值
        x = self.projection(pixel_values)
        # 对处理后的输出进行扁平化,并转置维度,以便得到期望的形状
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        # 返回处理后的张量作为前向传播的输出
        return x
# 定义 VivitEmbeddings 类,继承自 nn.Module,用于视频数据的嵌入处理
class VivitEmbeddings(nn.Module):
    """
    Vivit Embeddings.

    Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
    """

    def __init__(self, config):
        super().__init__()

        # 初始化一个可学习的 CLS token 参数,形状为 [1, 1, hidden_size]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        
        # 使用 VivitTubeletEmbeddings 类生成视频帧的嵌入表示
        self.patch_embeddings = VivitTubeletEmbeddings(config)

        # 初始化位置嵌入,形状为 [1, num_patches + 1, hidden_size]
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
        )
        
        # 定义一个 dropout 层,用于随机置零输入张量的部分元素
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 存储配置信息
        self.config = config

    def forward(self, pixel_values):
        # 获取输入张量的 batch size
        batch_size = pixel_values.shape[0]
        
        # 生成视频帧的嵌入表示
        embeddings = self.patch_embeddings(pixel_values)

        # 复制并添加 CLS token 到嵌入表示中,维度变为 [batch_size, num_patches + 1, hidden_size]
        cls_tokens = self.cls_token.repeat([batch_size, 1, 1])
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # 添加位置编码到每个 token 上
        embeddings = embeddings + self.position_embeddings

        # 对嵌入表示进行 dropout 处理
        embeddings = self.dropout(embeddings)

        return embeddings


# 从 transformers.models.vit.modeling_vit.ViTSelfAttention 复制并修改为 VivitSelfAttention
class VivitSelfAttention(nn.Module):
    def __init__(self, config: VivitConfig) -> None:
        super().__init__()
        
        # 检查 hidden_size 是否能被 num_attention_heads 整除,否则抛出异常
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 初始化注意力头的数量和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义查询、键、值的线性变换层
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # 定义 dropout 层,用于注意力概率的随机置零
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    # 将输入张量重塑为注意力分数的形状
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 定义自注意力层的前向传播函数
    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ):
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 使用 self.query 对隐藏状态进行查询,得到混合查询层
        mixed_query_layer = self.query(hidden_states)

        # 使用 self.key 对隐藏状态进行键的转换,并调整维度以备点积计算
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        
        # 使用 self.value 对隐藏状态进行值的转换,并调整维度以备点积计算
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        
        # 使用混合查询层和键层的转置进行点积操作,得到原始的注意力分数
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 根据注意力头的大小对注意力分数进行缩放
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # 对注意力分数进行归一化得到注意力概率
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 使用 dropout 对注意力概率进行随机丢弃处理
        attention_probs = self.dropout(attention_probs)

        # 如果有头部遮罩,则将注意力概率与头部遮罩相乘
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # 将注意力概率与值层进行加权求和,得到上下文层
        context_layer = torch.matmul(attention_probs, value_layer)

        # 调整上下文层的维度顺序,并确保连续的存储
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # 将调整后的上下文层重塑为指定的形状
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # 返回上下文层和可能的注意力概率,如果需要输出注意力信息
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
# 从 transformers.models.vit.modeling_vit.ViTSelfOutput 复制而来,进行了 ViT -> Vivit 的改名
class VivitSelfOutput(nn.Module):
    """
    The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VivitConfig) -> None:
        super().__init__()
        # 定义一个全连接层,输入和输出的维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 定义一个 dropout 层,以 config.hidden_dropout_prob 的概率随机置零输入张量的元素
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的 hidden_states 应用全连接层 self.dense
        hidden_states = self.dense(hidden_states)
        # 对应用全连接层后的 hidden_states 应用 dropout 层
        hidden_states = self.dropout(hidden_states)

        # 返回处理后的 hidden_states
        return hidden_states


# 从 transformers.models.vit.modeling_vit.ViTAttention 复制而来,进行了 ViT -> Vivit 的改名
class VivitAttention(nn.Module):
    def __init__(self, config: VivitConfig) -> None:
        super().__init__()
        # 实例化 VivitSelfAttention 类,传入配置 config,并赋值给 self.attention
        self.attention = VivitSelfAttention(config)
        # 实例化 VivitSelfOutput 类,传入配置 config,并赋值给 self.output
        self.output = VivitSelfOutput(config)
        # 初始化一个空集合,用于存储要剪枝的注意力头部
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # 找到要剪枝的头部和索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # 剪枝线性层
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储剪枝的头部
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 调用 self.attention 的 forward 方法,传入 hidden_states, head_mask 和 output_attentions
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # 将 self_outputs[0] 和 hidden_states 作为输入,调用 self.output 的 forward 方法
        attention_output = self.output(self_outputs[0], hidden_states)

        # 如果需要输出注意力,则将 attentions 添加到 outputs 中
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class VivitIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 定义一个全连接层,输入维度为 config.hidden_size,输出维度为 config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 定义一个 dropout 层,以 config.hidden_dropout_prob 的概率随机置零输入张量的元素
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 如果 config.hidden_act 是字符串类型,则选择相应的激活函数;否则直接使用 config.hidden_act
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    # 定义前向传播函数,接收隐藏状态作为输入
    def forward(self, hidden_states):
        # 使用全连接层对隐藏状态进行线性变换
        hidden_states = self.dense(hidden_states)
        # 应用激活函数对线性变换后的结果进行非线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 对非线性变换后的结果应用dropout操作,以减少过拟合风险
        hidden_states = self.dropout(hidden_states)

        # 返回处理后的隐藏状态作为输出
        return hidden_states
class VivitOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 定义一个全连接层,输入尺寸为config中的中间层大小,输出尺寸为config中的隐藏层大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 定义一个dropout层,使用config中的隐藏层dropout概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # 对输入的hidden_states进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的hidden_states进行dropout
        hidden_states = self.dropout(hidden_states)
        # 将dropout后的hidden_states与输入的input_tensor相加作为最终输出
        hidden_states = hidden_states + input_tensor
        return hidden_states


class VivitLayer(nn.Module):
    """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""

    def __init__(self, config):
        super().__init__()
        # 定义VivitLayer的属性,用于分块的前馈传播chunk_size_feed_forward,序列长度维度seq_len_dim为1
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # 定义VivitLayer中的注意力层、中间层、输出层,分别使用给定的config初始化
        self.attention = VivitAttention(config)
        self.intermediate = VivitIntermediate(config)
        self.output = VivitOutput(config)
        # 在self-attention之前应用层归一化
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 在self-attention之后应用层归一化
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        # 将输入的hidden_states和head_mask传递给self.attention,获取self-attention的输出
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # 在self-attention之前应用层归一化
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        # 如果需要输出注意力权重,则将其添加到outputs中
        outputs = self_attention_outputs[1:]

        # 第一个残差连接,将self-attention的输出加到原始的hidden_states上
        hidden_states = attention_output + hidden_states

        # 在self-attention之后应用层归一化
        layer_output = self.layernorm_after(hidden_states)
        # 将归一化后的输出传递给中间层进行处理
        layer_output = self.intermediate(layer_output)

        # 第二个残差连接,在输出层中应用处理后的layer_output,并与原始的hidden_states相加
        layer_output = self.output(layer_output, hidden_states)

        # 将最终的输出组装成outputs
        outputs = (layer_output,) + outputs

        return outputs


class VivitEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 创建一个由VivitLayer实例组成的列表,长度为config中指定的隐藏层数量num_hidden_layers
        self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
        # 梯度检查点标志设置为False
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则初始化一个空元组,否则置为None
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 根据头部掩码是否存在,决定是否应用头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 如果启用了梯度检查点且处于训练模式,则使用梯度检查点函数
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 否则直接调用层模块进行前向传播
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则添加当前层输出的注意力权重到元组中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终的隐藏状态添加到隐藏状态元组中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不使用返回字典格式,则返回一个元组,过滤掉None值
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # 使用BaseModelOutput格式返回结果
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
        """
        # VivitPooler 类,用于汇集模型隐藏状态的第一个令牌的隐藏状态
        An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
        models.
        """

    config_class = VivitConfig
    base_model_prefix = "vivit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        # 初始化模型权重
        Initialize the weights
        """
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            # 略有不同于 TF 版本,使用正态分布初始化权重
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=self.config.initializer_range)
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`VivitImageProcessor`]. See
            [`VivitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare ViViT Transformer model outputting raw hidden-states without any specific head on top.",
    VIVIT_START_DOCSTRING,
)
"""
class VivitModel(VivitPreTrainedModel):
    """
    ViViT Transformer model for raw hidden-states.

    Args:
        config: ViViT model configuration instance.
        add_pooling_layer: Whether to add a pooling layer on top of the encoder.

    Attributes:
        embeddings: ViViT embeddings module.
        encoder: ViViT encoder module.
        layernorm: Layer normalization module.
        pooler: Optional pooling layer for final representation.

    Methods:
        get_input_embeddings(): Retrieve the patch embeddings from the model.
        _prune_heads(heads_to_prune): Prune attention heads in the model.
        forward(...): Forward pass of the ViViT model.

    """
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # Initialize ViViT components
        self.embeddings = VivitEmbeddings(config)
        self.encoder = VivitEncoder(config)

        # Layer normalization and optional pooling layer
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = VivitPooler(config) if add_pooling_layer else None

        # Initialize weights and final processing steps
        self.post_init()

    def get_input_embeddings(self):
        """
        Retrieve the patch embeddings from the ViViT model.

        Returns:
            embeddings: Patch embeddings used by the model.
        """
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes specific attention heads in the ViViT model.

        Args:
            heads_to_prune (dict): Dictionary mapping layer numbers to lists of heads to prune.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the ViViT model.

        Args:
            pixel_values: Pixel values of the input video frames.
            head_mask: Mask to exclude certain attention heads.
            output_attentions: Whether to output attentions weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a dictionary.

        Returns:
            BaseModelOutputWithPooling or torch.Tensor: Model outputs.
        """
        # Forward pass logic goes here
        pass

"""
@add_start_docstrings(
    """ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.""",
    VIVIT_START_DOCSTRING,
)
"""
class VivitForVideoClassification(VivitPreTrainedModel):
    """
    ViViT Transformer model with a video classification head.

    Args:
        config: ViViT model configuration instance.

    Attributes:
        num_labels: Number of classification labels.
        vivit: ViViT base model.
        classifier: Linear classification layer.

    Methods:
        forward(...): Forward pass of the model for video classification.

    """
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vivit = VivitModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and final processing steps
        self.post_init()

    @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the ViViT model for video classification.

        Args:
            pixel_values: Pixel values of the input video frames.
            head_mask: Mask to exclude certain attention heads.
            labels: Labels for classification.
            output_attentions: Whether to output attentions weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a dictionary.

        Returns:
            ImageClassifierOutput or torch.Tensor: Model outputs.
        """
        # Forward pass logic goes here
        pass

.\models\vivit\__init__.py

# flake8: noqa
# 忽略 flake8 检查,因为这里没有办法仅忽略 "F401 '...' imported but unused" 警告而保留其它警告。
# 这样做是为了确保保留其它警告,而不对本模块进行检查。

# Copyright 2023 The HuggingFace Team. All rights reserved.
# 版权声明,版权归 HuggingFace Team 所有。

# Licensed under the Apache License, Version 2.0 (the "License");
# 授权协议,使用 Apache License, Version 2.0 版本。

# you may not use this file except in compliance with the License.
# 只有在遵守协议的情况下才可以使用此文件。

# You may obtain a copy of the License at
# 您可以在以下网址获取协议的副本

#     http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非适用法律要求或书面同意,否则按“原样”分发软件,
# 没有任何明示或暗示的担保或条件。

# See the License for the specific language governing permissions and
# limitations under the License.
# 详细信息请参阅许可证,以了解权限限制和特定语言的许可。

from typing import TYPE_CHECKING

# 依赖于 isort 来合并导入

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义模块的导入结构
_import_structure = {
    "configuration_vivit": ["VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VivitConfig"],
}

# 尝试导入视觉功能,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_vivit"] = ["VivitImageProcessor"]

# 尝试导入 Torch,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vivit"] = [
        "VIVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "VivitModel",
        "VivitPreTrainedModel",
        "VivitForVideoClassification",
    ]

# 如果 TYPE_CHECKING 为真,则从相应模块导入特定内容
if TYPE_CHECKING:
    from .configuration_vivit import VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, VivitConfig

    # 尝试导入视觉功能,如果不可用则忽略
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vivit import VivitImageProcessor

    # 尝试导入 Torch,如果不可用则忽略
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vivit import (
            VIVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            VivitForVideoClassification,
            VivitModel,
            VivitPreTrainedModel,
        )

# 如果不在 TYPE_CHECKING 模式下,则导入 LazyModule 动态生成模块
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\wav2vec2\configuration_wav2vec2.py

# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wav2Vec2 model configuration
"""

import functools  # 导入functools模块,提供高阶函数操作工具
import operator   # 导入operator模块,提供函数形式的操作符接口

from ...configuration_utils import PretrainedConfig  # 从全局导入PretrainedConfig类
from ...utils import logging  # 从全局导入logging工具


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

# 定义Wav2Vec2预训练模型配置文件映射字典,包含模型名称和对应的配置文件下载链接
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
    # 查看所有Wav2Vec2模型,链接见https://huggingface.co/models?filter=wav2vec2
}


class Wav2Vec2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
    Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Wav2Vec2
    [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import Wav2Vec2Config, Wav2Vec2Model

    >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
    >>> configuration = Wav2Vec2Config()

    >>> # Initializing a model (with random weights) from the facebook/wav2vec2-base-960h style configuration
    >>> model = Wav2Vec2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "wav2vec2"  # 指定模型类型为"wav2vec2"
    # 初始化函数,用于创建一个 Transformer 模型的实例
    def __init__(
        self,
        vocab_size=32,  # 词汇表大小,默认为32
        hidden_size=768,  # 隐藏层大小,默认为768
        num_hidden_layers=12,  # Transformer 的隐藏层层数,默认为12
        num_attention_heads=12,  # 注意力头的数量,默认为12
        intermediate_size=3072,  # Feedforward 层的中间层大小,默认为3072
        hidden_act="gelu",  # 隐藏层激活函数,默认为GELU
        hidden_dropout=0.1,  # 隐藏层的Dropout率,默认为0.1
        activation_dropout=0.1,  # 激活函数的Dropout率,默认为0.1
        attention_dropout=0.1,  # 注意力层的Dropout率,默认为0.1
        feat_proj_dropout=0.0,  # 特征投影层的Dropout率,默认为0.0
        feat_quantizer_dropout=0.0,  # 特征量化器的Dropout率,默认为0.0
        final_dropout=0.1,  # 最终输出层的Dropout率,默认为0.1
        layerdrop=0.1,  # 层级Dropout率,默认为0.1
        initializer_range=0.02,  # 参数初始化范围,默认为0.02
        layer_norm_eps=1e-5,  # Layer Norm 的 epsilon 值,默认为1e-5
        feat_extract_norm="group",  # 特征提取的归一化方式,默认为"group"
        feat_extract_activation="gelu",  # 特征提取的激活函数,默认为GELU
        conv_dim=(512, 512, 512, 512, 512, 512, 512),  # 卷积层的维度设置,默认为一个元组
        conv_stride=(5, 2, 2, 2, 2, 2, 2),  # 卷积层的步幅设置,默认为一个元组
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),  # 卷积层的核大小设置,默认为一个元组
        conv_bias=False,  # 是否使用卷积层的偏置,默认为False
        num_conv_pos_embeddings=128,  # 卷积位置嵌入的数量,默认为128
        num_conv_pos_embedding_groups=16,  # 卷积位置嵌入的分组数量,默认为16
        do_stable_layer_norm=False,  # 是否使用稳定的Layer Norm,默认为False
        apply_spec_augment=True,  # 是否应用频谱增强,默认为True
        mask_time_prob=0.05,  # 时间掩码的概率,默认为0.05
        mask_time_length=10,  # 时间掩码的长度,默认为10
        mask_time_min_masks=2,  # 时间掩码的最小数量,默认为2
        mask_feature_prob=0.0,  # 特征掩码的概率,默认为0.0
        mask_feature_length=10,  # 特征掩码的长度,默认为10
        mask_feature_min_masks=0,  # 特征掩码的最小数量,默认为0
        num_codevectors_per_group=320,  # 每组编码向量的数量,默认为320
        num_codevector_groups=2,  # 编码向量组的数量,默认为2
        contrastive_logits_temperature=0.1,  # 对比日志的温度参数,默认为0.1
        num_negatives=100,  # 负样本的数量,默认为100
        codevector_dim=256,  # 编码向量的维度,默认为256
        proj_codevector_dim=256,  # 投影编码向量的维度,默认为256
        diversity_loss_weight=0.1,  # 多样性损失的权重,默认为0.1
        ctc_loss_reduction="sum",  # CTC损失的减少方式,默认为"sum"
        ctc_zero_infinity=False,  # CTC损失中是否使用无穷大,默认为False
        use_weighted_layer_sum=False,  # 是否使用加权层求和,默认为False
        classifier_proj_size=256,  # 分类器投影的大小,默认为256
        tdnn_dim=(512, 512, 512, 512, 1500),  # TDNN 层的维度设置,默认为一个元组
        tdnn_kernel=(5, 3, 3, 1, 1),  # TDNN 层的核大小设置,默认为一个元组
        tdnn_dilation=(1, 2, 3, 1, 1),  # TDNN 层的膨胀率设置,默认为一个元组
        xvector_output_dim=512,  # x-vector 的输出维度,默认为512
        pad_token_id=0,  # 填充标记的ID,默认为0
        bos_token_id=1,  # 起始标记的ID,默认为1
        eos_token_id=2,  # 结束标记的ID,默认为2
        add_adapter=False,  # 是否添加适配器,默认为False
        adapter_kernel_size=3,  # 适配器的核大小,默认为3
        adapter_stride=2,  # 适配器的步幅,默认为2
        num_adapter_layers=3,  # 适配器的层数,默认为3
        output_hidden_size=None,  # 输出隐藏层的大小,默认为None
        adapter_attn_dim=None,  # 适配器的注意力维度,默认为None
        **kwargs,  # 其他关键字参数
    ):
        # 计算输入到 Logits 的比例,使用 functools.reduce 和 operator.mul 函数
        @property
        def inputs_to_logits_ratio(self):
            return functools.reduce(operator.mul, self.conv_stride, 1)

.\models\wav2vec2\convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import json  # 用于处理JSON格式的数据
import os  # 用于与操作系统进行交互

import fairseq  # 导入fairseq库
import torch  # 导入PyTorch库
from fairseq.data import Dictionary  # 导入fairseq库中的Dictionary类

# 导入transformers库中的各个组件和模型类
from transformers import (
    Wav2Vec2Config,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2ForPreTraining,
    Wav2Vec2Processor,
    logging,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification

# 设置日志记录的详细程度为INFO级别
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# 定义一个映射字典,用于将旧模型的参数映射到新模型的参数
MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "adapter_layer": "encoder.layers.*.adapter_layer",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "quantizer.weight_proj": "quantizer.weight_proj",
    "quantizer.vars": "quantizer.codevectors",
    "project_q": "project_q",
    "final_proj": "project_hid",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
    "pooling_layer.linear": "projector",
    "pooling_layer.projection": "classifier",
}

# 定义顶层键列表,列出需要映射的最高层级参数
TOP_LEVEL_KEYS = [
    "lm_head",
    "quantizer.weight_proj",
    "quantizer.codevectors",
    "project_q",
    "project_hid",
    "projector",
    "classifier",
]

# 定义一个函数,从文本文件中读取内容并存储为字典形式
def read_txt_into_dict(filename):
    result = {}
    with open(filename, "r") as file:
        for line_number, line in enumerate(file):
            line = line.strip()
            if line:
                words = line.split()
                key = line_number
                value = words[0]
                result[key] = value
    return result

# 定义一个递归设置函数,用于根据指定的键路径设置值到相应的属性上
def set_recursively(key, value, full_name, weight_type, hf_pointer):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    hf_param_name = None  # 暂时未使用的参数名变量
    # 遍历参数映射字典中的所有键
    for param_key in PARAM_MAPPING.keys():
        # 检查完整名称是否以当前参数键结尾
        if full_name.endswith(param_key):
            # 根据参数映射字典获取对应的参数名
            hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
            # 设置权重类型为参数类型
            weight_type = "param"

    # 如果权重类型不为空且不为参数类型
    if weight_type is not None and weight_type != "param":
        # 获取指定权重类型属性的形状
        hf_shape = getattr(hf_pointer, weight_type).shape
    # 如果权重类型不为空且为参数类型
    elif weight_type is not None and weight_type == "param":
        # 逐级获取参数名对应的形状指针
        shape_pointer = hf_pointer
        for attribute in hf_param_name.split("."):
            shape_pointer = getattr(shape_pointer, attribute)
        # 获取最终的形状
        hf_shape = shape_pointer.shape

        # 缩减维度,仅保留第一个元素
        value = value[0]
    # 如果以上条件都不满足,获取当前指针的形状
    else:
        hf_shape = hf_pointer.shape

    # 检查获取的形状与值的形状是否相等,如果不相等则抛出异常
    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    # 根据权重类型将值赋给相应的属性
    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "param":
        # 逐级获取参数名对应的属性指针,并赋值
        for attribute in hf_param_name.split("."):
            hf_pointer = getattr(hf_pointer, attribute)
        hf_pointer.data = value
    else:
        # 若权重类型为空,则直接将值赋给指针
        hf_pointer.data = value

    # 记录日志,标明哪个权重或参数从哪里初始化
    logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
# 定义一个函数,用于根据一组映射规则重命名字典中的键,并更新权重类型
def rename_dict(key, value, full_name, weight_type, hf_dict):
    # 初始化变量hf_param_name
    hf_param_name = None
    # 遍历PARAM_MAPPING字典中的所有键
    for param_key in PARAM_MAPPING.keys():
        # 如果full_name以param_key结尾,确定hf_param_name的值为PARAM_MAPPING中对应的值
        if full_name.endswith(param_key):
            hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
            weight_type = "param"

    # 如果weight_type不为None且不等于"param"
    if weight_type is not None and weight_type != "param":
        # 构建完整的键名full_key,格式为key.weight_type
        full_key = ".".join([key, weight_type])
    # 如果weight_type不为None且等于"param"
    elif weight_type is not None and weight_type == "param":
        # 构建完整的键名full_key,格式为key.hf_param_name
        full_key = ".".join([key, hf_param_name])
    else:
        # 否则直接使用key作为完整的键名full_key
        full_key = key

    # 将键值对(full_key, value)添加到hf_dict字典中,如果full_key中包含"lm_head"则只取value的第一个元素
    hf_dict[full_key] = value if "lm_head" in full_key else value[0]


# PARAM_MAPPING字典,用于存储特定键的重命名规则
PARAM_MAPPING = {
    "W_a": "linear_1.weight",
    "W_b": "linear_2.weight",
    "b_a": "linear_1.bias",
    "b_b": "linear_2.bias",
    "ln_W": "norm.weight",
    "ln_b": "norm.bias",
}


# 加载wav2vec2模型的特定层的权重数据
def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
    # 标志变量,指示是否使用了这个权重数据
    is_used = False
    # 遍历MAPPING字典中的所有键值对
    for key, mapped_key in MAPPING.items():
        # 将mapped_key设置为"wav2vec2." + mapped_key,如果mapped_key不在TOP_LEVEL_KEYS中
        mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
        # 如果name中包含key或者name按"."分割的第一个部分等于key
        if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
            # 表示这个权重数据被使用了
            is_used = True
            # 如果mapped_key中包含"*",则替换为name按key分割后的倒数第二部分
            if "*" in mapped_key:
                layer_index = name.split(key)[0].split(".")[-2]
                mapped_key = mapped_key.replace("*", layer_index)
            # 根据name的内容确定权重类型
            if "weight_g" in name:
                weight_type = "weight_g"
            elif "weight_v" in name:
                weight_type = "weight_v"
            elif "bias" in name:
                weight_type = "bias"
            elif "weight" in name:
                # TODO: 不匹配quantizer.weight_proj
                weight_type = "weight"
            else:
                weight_type = None
            # 如果hf_dict不为None,则调用rename_dict函数重命名mapped_key,并将权重数据value存储到hf_dict中
            if hf_dict is not None:
                rename_dict(mapped_key, value, name, weight_type, hf_dict)
            else:
                # 否则调用set_recursively函数递归地设置mapped_key的权重数据value到hf_model中
                set_recursively(mapped_key, value, name, weight_type, hf_model)
            # 返回is_used,表示权重数据被使用了
            return is_used
    # 如果没有使用这个权重数据,则返回False
    return is_used


# 递归加载fairseq模型的权重数据到hf_model中
def recursively_load_weights(fairseq_model, hf_model, is_headless):
    # 未使用的权重数据列表
    unused_weights = []
    # 获取fairseq模型的状态字典
    fairseq_dict = fairseq_model.state_dict()

    # 获取hf_model中的特征提取器
    feature_extractor = hf_model.wav2vec2.feature_extractor

    # 遍历fairseq_dict中的所有权重数据项
    for name, value in fairseq_dict.items():
        # 标志变量,指示这个权重数据是否被使用
        is_used = False
        # 如果name中包含"conv_layers"
        if "conv_layers" in name:
            # 调用load_conv_layer函数加载卷积层的权重数据
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
                hf_model.config.feat_extract_norm == "group",
            )
            # 标记这个权重数据被使用了
            is_used = True
        else:
            # 否则调用load_wav2vec2_layer函数加载wav2vec2模型的权重数据
            is_used = load_wav2vec2_layer(name, value, hf_model)
        # 如果这个权重数据没有被使用,则将其添加到未使用的权重数据列表中
        if not is_used:
            unused_weights.append(name)

    # 记录未使用的权重数据到日志中
    logger.warning(f"Unused weights: {unused_weights}")


# 加载卷积层的权重数据到特征提取器中
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
    # 将full_name按"conv_layers."分割,获取卷积层的名称name
    name = full_name.split("conv_layers.")[-1]
    # 将name按"."分割,获取层ID和类型ID
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])
    # 如果权重类型为0(偏置):
    if type_id == 0:
        # 如果名称中包含"bias":
        if "bias" in name:
            # 检查传入值的形状是否与对应卷积层的偏置数据形状相同,若不同则引发数值错误异常
            if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
                raise ValueError(
                    f"{full_name} has size {value.shape}, but"
                    f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
                )
            # 将传入的值赋给对应卷积层的偏置数据
            feature_extractor.conv_layers[layer_id].conv.bias.data = value
            # 记录日志,表示特征提取器卷积层的偏置数据已从指定来源初始化
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
        # 如果名称中包含"weight":
        elif "weight" in name:
            # 检查传入值的形状是否与对应卷积层的权重数据形状相同,若不同则引发数值错误异常
            if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
                raise ValueError(
                    f"{full_name} has size {value.shape}, but"
                    f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
                )
            # 将传入的值赋给对应卷积层的权重数据
            feature_extractor.conv_layers[layer_id].conv.weight.data = value
            # 记录日志,表示特征提取器卷积层的权重数据已从指定来源初始化
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    
    # 如果权重类型为2且不使用分组归一化,或者权重类型为2且为第一层且使用分组归一化:
    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
        # 如果名称中包含"bias":
        if "bias" in name:
            # 检查传入值的形状是否与对应卷积层的分组归一化偏置数据形状相同,若不同则引发数值错误异常
            if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
                raise ValueError(
                    f"{full_name} has size {value.shape}, but"
                    f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
                )
            # 将传入的值赋给对应卷积层的分组归一化偏置数据
            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
            # 记录日志,表示特征提取器卷积层的分组归一化偏置数据已从指定来源初始化
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
        # 如果名称中包含"weight":
        elif "weight" in name:
            # 检查传入值的形状是否与对应卷积层的分组归一化权重数据形状相同,若不同则引发数值错误异常
            if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
                raise ValueError(
                    f"{full_name} has size {value.shape}, but"
                    f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
                )
            # 将传入的值赋给对应卷积层的分组归一化权重数据
            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
            # 记录日志,表示特征提取器卷积层的分组归一化权重数据已从指定来源初始化
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
    
    # 如果以上条件都不满足:
    else:
        # 将未使用的权重名称添加到未使用权重列表中
        unused_weights.append(full_name)
@torch.no_grad()
def convert_wav2vec2_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    # 如果提供了配置文件路径,则从预训练配置加载配置信息
    if config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        # 否则,使用默认配置
        config = Wav2Vec2Config()

    # 如果是序列分类任务
    if is_seq_class:
        # 从文本文件加载字典
        id2label = read_txt_into_dict(dict_path)
        # 将 id 到标签的映射加入配置
        config.id2label = id2label
        # 创建序列分类的 Wav2Vec2 模型
        hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
        # 创建特征提取器对象
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0,
            do_normalize=True,
            return_attention_mask=True,
        )
        # 保存特征提取器配置到指定路径
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    # 如果是微调模型
    elif is_finetuned:
        if dict_path:
            # 加载目标字典
            target_dict = Dictionary.load(dict_path)

            # 调整配置中的特殊 token id,因为 CTC 符号是 <pad> 而不是 fairseq 中的 <s>
            config.bos_token_id = target_dict.pad_index
            config.pad_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.vocab_size = len(target_dict.symbols)
            # 创建词汇表 JSON 文件的路径
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
            # 检查目标路径是否是目录
            if not os.path.isdir(pytorch_dump_folder_path):
                logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
                return
            # 创建目录(如果不存在)
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)
            # 创建词汇表字典
            vocab_dict = target_dict.indices

            # fairseq 中的 <pad> 和 <s> 需要交换
            vocab_dict["<pad>"] = 0
            vocab_dict["<s>"] = 1
            # 将词汇表字典写入 JSON 文件
            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(vocab_dict, vocab_handle)
            # 创建 Wav2Vec2CTC tokenizer 对象
            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            # 根据配置选择是否返回注意力掩码
            return_attention_mask = True if config.feat_extract_norm == "layer" else False
            # 创建特征提取器对象
            feature_extractor = Wav2Vec2FeatureExtractor(
                feature_size=1,
                sampling_rate=16000,
                padding_value=0,
                do_normalize=True,
                return_attention_mask=return_attention_mask,
            )
            # 创建处理器对象,包括特征提取器和 tokenizer
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
            # 保存处理器配置到指定路径
            processor.save_pretrained(pytorch_dump_folder_path)

        # 创建 Wav2Vec2ForCTC 模型对象
        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        # 创建预训练模型对象
        hf_wav2vec = Wav2Vec2ForPreTraining(config)
    # 如果模型已经进行了微调或者是用于序列分类,则执行以下操作
    if is_finetuned or is_seq_class:
        # 载入模型集合和任务信息,同时设置数据路径参数
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
        )
    else:
        # 否则,设置音频预训练任务参数
        task_arg = argparse.Namespace(task="audio_pretraining")
        task = fairseq.tasks.setup_task(task_arg)

        # 载入模型集合和任务信息,同时设置任务参数
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)

    # 将模型切换到评估模式
    model = model[0].eval()

    # 递归加载权重到模型,使用 hf_wav2vec 的权重,若非微调则反向加载
    recursively_load_weights(model, hf_wav2vec, not is_finetuned)

    # 将 PyTorch 模型保存到指定的转储文件夹路径
    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
# 如果脚本作为主程序运行,则执行以下代码块
if __name__ == "__main__":
    # 创建参数解析器对象
    parser = argparse.ArgumentParser()
    # 添加参数:输出 PyTorch 模型的路径
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    # 添加参数:fairseq 模型的检查点路径
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    # 添加参数:微调模型的字典路径
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    # 添加参数:待转换模型的 hf config.json 路径
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    # 添加参数:指示待转换模型是否为微调模型的标志
    parser.add_argument(
        "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
    )
    # 添加参数:指示待转换模型是否为序列分类模型的标志
    parser.add_argument(
        "--is_seq_class",
        action="store_true",
        help="Whether the model to convert is a fine-tuned sequence classification model or not",
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 根据参数判断待转换模型是否为微调模型
    is_finetuned = not args.not_finetuned and not args.is_seq_class
    # 调用函数,将 wav2vec2 模型检查点转换为 PyTorch 模型
    convert_wav2vec2_checkpoint(
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.config_path,
        args.dict_path,
        is_finetuned,
        args.is_seq_class,
    )

.\models\wav2vec2\convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py

# 导入必要的库和模块
import argparse  # 用于解析命令行参数

import torch  # PyTorch库

from transformers import (  # 导入transformers库中的相关模块和类
    Wav2Vec2Config,  # Wav2Vec2模型的配置类
    Wav2Vec2FeatureExtractor,  # Wav2Vec2的特征提取器类
    Wav2Vec2ForAudioFrameClassification,  # 用于音频帧分类的Wav2Vec2模型类
    Wav2Vec2ForSequenceClassification,  # 用于序列分类的Wav2Vec2模型类
    Wav2Vec2ForXVector,  # 用于X向量生成的Wav2Vec2模型类
    logging,  # 日志记录模块
)

logging.set_verbosity_info()  # 设置日志记录级别为info
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def convert_classification(base_model_name, hf_config, downstream_dict):
    # 根据预训练的模型名称和配置hf_config创建序列分类的Wav2Vec2模型
    model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
    # 设置模型的投影层权重和偏置,从下游任务的字典中获取
    model.projector.weight.data = downstream_dict["projector.weight"]
    model.projector.bias.data = downstream_dict["projector.bias"]
    # 设置模型的分类器权重和偏置,从下游任务的字典中获取
    model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
    model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
    return model  # 返回转换后的模型


def convert_diarization(base_model_name, hf_config, downstream_dict):
    # 根据预训练的模型名称和配置hf_config创建音频帧分类的Wav2Vec2模型
    model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
    # 设置模型的分类器权重和偏置,从下游任务的字典中获取
    model.classifier.weight.data = downstream_dict["model.linear.weight"]
    model.classifier.bias.data = downstream_dict["model.linear.bias"]
    return model  # 返回转换后的模型


def convert_xvector(base_model_name, hf_config, downstream_dict):
    # 根据预训练的模型名称和配置hf_config创建X向量生成的Wav2Vec2模型
    model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config)
    # 设置模型的投影层权重和偏置,从下游任务的字典中获取
    model.projector.weight.data = downstream_dict["connector.weight"]
    model.projector.bias.data = downstream_dict["connector.bias"]
    
    # 遍历并设置每个TDNN层的卷积核权重和偏置,从下游任务的字典中获取
    for i, kernel_size in enumerate(hf_config.tdnn_kernel):
        model.tdnn[i].kernel.weight.data = downstream_dict[
            f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
        ]
        model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]

    # 设置特征提取器的权重和偏置,从下游任务的字典中获取
    model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
    model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
    # 设置分类器的权重和偏置,从下游任务的字典中获取
    model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
    model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
    # 设置目标函数的权重,从下游任务的字典中获取
    model.objective.weight.data = downstream_dict["objective.W"]
    return model  # 返回转换后的模型


@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
    # 用于将S3PRL模型检查点转换为其他格式的函数,使用torch.no_grad()进行装饰
    """
    将模型的权重复制/粘贴/调整到transformers设计中。
    """
    # 使用torch加载检查点文件中的模型权重,指定CPU作为目标设备
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    
    # 从检查点中获取下游任务相关的信息
    downstream_dict = checkpoint["Downstream"]
    
    # 根据指定的配置路径创建Wav2Vec2的配置对象
    hf_config = Wav2Vec2Config.from_pretrained(config_path)
    
    # 根据预训练模型名称创建Wav2Vec2的特征提取器对象
    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False
    )
    
    # 获取模型的架构信息,并检查其类型以确定使用哪种转换方法
    arch = hf_config.architectures[0]
    if arch.endswith("ForSequenceClassification"):
        # 如果模型架构适用于序列分类任务,则进行相应的转换
        hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForAudioFrameClassification"):
        # 如果模型架构适用于音频帧分类任务,则进行相应的转换
        hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForXVector"):
        # 如果模型架构适用于X向量任务,则进行相应的转换
        hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
    else:
        # 如果架构类型未知或不支持,则抛出未实现错误
        raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
    
    # 如果配置指定使用加权层求和,则加载权重信息到模型中
    if hf_config.use_weighted_layer_sum:
        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
    
    # 将特征提取器的配置保存到指定路径
    hf_feature_extractor.save_pretrained(model_dump_path)
    
    # 将转换后的模型保存到指定路径
    hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
    # 如果脚本作为主程序执行,则进入条件判断
    parser = argparse.ArgumentParser()
    # 创建参数解析器对象
    parser.add_argument(
        "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
    )
    # 添加命令行参数:预训练模型的名称
    parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
    # 添加命令行参数:分类器配置文件的路径
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
    # 添加命令行参数:s3prl 检查点文件的路径
    parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
    # 添加命令行参数:最终转换模型的输出路径
    args = parser.parse_args()
    # 解析命令行参数,并将其存储在 args 对象中
    convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
    # 调用函数 convert_s3prl_checkpoint,传入命令行参数中的相关路径信息作为参数

.\models\wav2vec2\feature_extraction_wav2vec2.py

# coding=utf-8
# 版权所有 2021 年 HuggingFace Inc. 团队。
#
# 根据 Apache 许可证 2.0 版本("许可证")许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则软件根据"原样"分发,
# 没有任何明示或暗示的保证或条件。
# 有关特定语言的权限,请参阅许可证。
"""
Wav2Vec2 的特征提取器类
"""

from typing import List, Optional, Union

import numpy as np

# 导入序列特征提取器和批处理特征
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging

# 获取日志记录器
logger = logging.get_logger(__name__)


class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
    r"""
    构建一个 Wav2Vec2 特征提取器。

    此特征提取器继承自 [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`],其中包含大部分主要方法。
    用户应参考此超类以获取关于这些方法的更多信息。

    Args:
        feature_size (`int`, 默认为 1):
            提取特征的特征维度。
        sampling_rate (`int`, 默认为 16000):
            音频文件的数字化采样率,以赫兹(Hz)表示。
        padding_value (`float`, 默认为 0.0):
            用于填充填充值的值。
        do_normalize (`bool`, *可选*, 默认为 `True`):
            是否对输入进行零均值单位方差归一化。归一化可以显著提高某些模型的性能,
            例如 [wav2vec2-lv60](https://huggingface.co/models?search=lv60)。
        return_attention_mask (`bool`, *可选*, 默认为 `False`):
            是否 [`~Wav2Vec2FeatureExtractor.__call__`] 应返回 `attention_mask`。

            <Tip>

            对于设置了 `config.feat_extract_norm == "group"` 的 Wav2Vec2 模型,例如
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h),**没有** 使用
            `attention_mask` 进行训练。对于这样的模型,`input_values` 应仅用 0 填充,不应传递 `attention_mask`。

            对于设置了 `config.feat_extract_norm == "layer"` 的 Wav2Vec2 模型,例如
            [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self),应传递 `attention_mask`
            以进行批处理推断。

            </Tip>
    """

    model_input_names = ["input_values", "attention_mask"]
    # 初始化方法,设置特征大小、采样率、填充值等参数,并调用父类的初始化方法
    def __init__(
        self,
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        return_attention_mask=False,
        do_normalize=True,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        # 是否返回注意力掩码
        self.return_attention_mask = return_attention_mask
        # 是否进行归一化处理
        self.do_normalize = do_normalize

    @staticmethod
    def zero_mean_unit_var_norm(
        input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
    ) -> List[np.ndarray]:
        """
        每个数组都被归一化为零均值和单位方差
        """
        if attention_mask is not None:
            # 将注意力掩码转换为numpy数组类型
            attention_mask = np.array(attention_mask, np.int32)
            normed_input_values = []

            # 对于输入值列表中的每个向量和对应的长度进行循环
            for vector, length in zip(input_values, attention_mask.sum(-1)):
                # 计算切片的归一化值,确保长度外的部分使用填充值
                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value

                # 将归一化后的切片添加到结果列表中
                normed_input_values.append(normed_slice)
        else:
            # 对于没有注意力掩码的情况,直接对每个输入值进行归一化处理
            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

        # 返回归一化后的输入值列表
        return normed_input_values

    # 调用方法,接收原始语音数据及相关参数,并进行相应的数据处理和转换
    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,

.\models\wav2vec2\modeling_flax_wav2vec2.py

# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flax Wav2Vec2 model.
"""

# 导入必要的模块和库
from functools import partial  # 导入 partial 函数,用于创建偏函数
from typing import Optional, Tuple, Union  # 导入类型提示所需的类型

import flax  # 导入 Flax 模块
import flax.linen as nn  # 导入 Flax 的线性层模块
import jax  # 导入 JAX 模块
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口
import numpy as np  # 导入 NumPy 库
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 的 FrozenDict 相关函数
from flax.linen.attention import dot_product_attention_weights  # 导入注意力权重计算函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入字典扁平化和反扁平化函数
from jax import lax  # 导入 JAX 的 lax 模块

# 导入相关输出、工具类和配置
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput  # 导入输出类
from ...modeling_flax_utils import (  # 导入工具函数和基类
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging  # 导入实用工具和日志模块
from .configuration_wav2vec2 import Wav2Vec2Config  # 导入 Wav2Vec2 的配置类

# 获取日志记录器
logger = logging.get_logger(__name__)
    # 定义变量 `last_hidden_state`,用于存储 JAX NumPy 数组(jnp.ndarray),初始值为 None
    last_hidden_state: jnp.ndarray = None
    # 定义变量 `extract_features`,用于存储 JAX NumPy 数组(jnp.ndarray),初始值为 None
    extract_features: jnp.ndarray = None
    # 定义变量 `hidden_states`,用于存储一个元组,其中元素是 JAX NumPy 数组(jnp.ndarray),可选类型(Optional)表示可以为 None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 定义变量 `attentions`,用于存储一个元组,其中元素是 JAX NumPy 数组(jnp.ndarray),可选类型(Optional)表示可以为 None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# 定义一个数据类,用于存储 FlaxWav2Vec2 模型预训练的输出结果,继承自 ModelOutput
@flax.struct.dataclass
class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput):
    """
    Output type of [`FlaxWav2Vec2ForPreTrainingOutput`], with potential hidden states and attentions.

    Args:
        loss (*optional*, returned when model is in train mode, `jnp.ndarray` of shape `(1,)`):
            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
            paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
        projected_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
            projected quantized states.
        projected_quantized_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
            target vectors for contrastive loss.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义属性:模型预测的状态向量,形状为 jnp.ndarray 或者 None
    projected_states: jnp.ndarray = None
    # 定义属性:量化后的状态向量,形状为 jnp.ndarray 或者 None
    projected_quantized_states: jnp.ndarray = None
    # 定义属性:码本的困惑度,形状为 jnp.ndarray 或者 None
    codevector_perplexity: jnp.ndarray = None
    # 定义属性:隐藏状态的元组,包含 jnp.ndarray 或者 None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 定义属性:注意力的元组,包含 jnp.ndarray 或者 None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# 定义一个函数,用于计算给定形状的随机掩码段落,用于实现 SpecAugment 数据增强方法,参考了 ASR 领域的论文
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[np.ndarray] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.
    """
    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        mask_prob:
            probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: size of the mask
        min_masks: minimum number of masked spans

    """
    # 解包形状参数,batch_size 为批次大小,sequence_length 为时间步长
    batch_size, sequence_length = shape

    # 如果 mask_length 小于 1,则引发值错误
    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    # 如果 mask_length 大于 sequence_length,则引发值错误
    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
            f" `sequence_length`: {sequence_length}`"
        )

    # 计算每批次中需要掩蔽的区间数目
    num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item())
    # 确保 num_masked_spans 不小于 min_masks
    num_masked_spans = max(num_masked_spans, min_masks)

    # 确保掩蔽的索引数不超过 sequence_length
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    # 初始化一个形状为 (batch_size, sequence_length) 的布尔类型的掩蔽数组
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)

    # 随机生成要掩蔽的起始索引
    spec_aug_mask_idxs = np.array(
        [
            np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False)
            for _ in range(batch_size)
        ]
    )

    # 将掩蔽的索引扩展为掩蔽的区间
    spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length))
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length)

    # 创建一个偏移数组以便扩展掩蔽的区间
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape(
        batch_size, num_masked_spans * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # 在掩蔽数组中填充掩蔽的索引
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    # 如果存在 attention_mask,则确保填充的输入 ID 不能被掩蔽
    if attention_mask is not None:
        spec_aug_mask = np.where(attention_mask, spec_aug_mask, False)

    # 返回生成的掩蔽数组
    return spec_aug_mask
def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None):
    """
    Sample `num_negatives` vectors from feature vectors.
    """
    # 解析输入参数的形状信息
    batch_size, sequence_length, hidden_size = features_shape

    # 检查序列长度是否小于等于1,如果是则引发异常
    if sequence_length <= 1:
        raise ValueError(
            "`features should have `sequence_length` > 1, but are of shape "
            f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
        )

    # 从同一个语句中随机选择 `num_negatives` 个向量索引
    sampled_negative_indices = []
    for batch_idx in range(batch_size):
        # 根据注意力掩码确定可用索引的上限,或者使用序列长度的上限
        high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
        # 随机抽样索引,数量为 `num_negatives * sequence_length`
        sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,))
        sampled_negative_indices.append(sampled_indices_slice)

    sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32)

    # 生成正向量的索引,将其重复 `num_negatives` 次
    feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()

    # 避免抽样到相同的正向量索引,同时保持均匀分布
    sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1

    # 调整索引以匹配批次大小
    for batch_idx in range(1, batch_size):
        sampled_negative_indices[batch_idx] += batch_idx * sequence_length

    return sampled_negative_indices


WAV_2_VEC_2_START_DOCSTRING = r"""
    Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)


"""
    Parameters:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).
            
            This specifies the data type used for computations, allowing for mixed-precision training or
            half-precision inference on GPUs or TPUs. If specified, all computations within the model will be
            performed with the specified `dtype`.

            **Note that this setting affects only the computation dtype and not the dtype of model parameters.**

            To change the dtype of model parameters, refer to [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
定义一个类 `FlaxWav2Vec2LayerNormConvLayer`,继承自 `nn.Module`,用于实现基于 Flax 的 Wav2Vec2 模型的一层。
"""
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    # 设置类属性 `config` 为 `Wav2Vec2Config` 类型,用于配置模型参数
    config: Wav2Vec2Config
    # 设置类属性 `layer_id` 为整数,表示当前层的标识,默认为 0
    layer_id: int = 0
    # 设置类属性 `dtype` 为 `jnp.float32`,表示数据类型为 32 位浮点数
    dtype: jnp.dtype = jnp.float32
    # 设置函数,用于初始化网络层参数
    def setup(self):
        # 如果当前层不是第一层,设置输入卷积维度为指定的卷积维度列表中对应层的值,否则设为1
        self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
        # 设置输出卷积维度为指定的卷积维度列表中对应层的值
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        # 初始化卷积层
        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],  # 卷积层输出特征维度
            kernel_size=(self.config.conv_kernel[self.layer_id],),  # 卷积核大小
            strides=(self.config.conv_stride[self.layer_id],),  # 卷积步长
            use_bias=self.config.conv_bias,  # 是否使用偏置
            kernel_init=jax.nn.initializers.he_normal(),  # 卷积核初始化方法
            padding="VALID",  # 卷积填充方式
            dtype=self.dtype,  # 数据类型
        )
        # 初始化层归一化层
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化激活函数,根据配置选择相应的激活函数
        self.activation = ACT2FN[self.config.feat_extract_activation]

    # 定义调用函数,用于前向传播计算
    def __call__(self, hidden_states):
        # 卷积操作,计算特征提取后的隐藏状态
        hidden_states = self.conv(hidden_states)
        # 层归一化操作,对卷积输出进行归一化处理
        hidden_states = self.layer_norm(hidden_states)
        # 激活函数操作,对归一化后的输出应用激活函数
        hidden_states = self.activation(hidden_states)
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个自定义的 Flax 模块,用于卷积操作并包含权重归一化
class FlaxConvWithWeightNorm(nn.Module):
    # 配置信息,指定为 Wav2Vec2Config 类型
    config: Wav2Vec2Config
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块设置方法,用于初始化模块的各个部分
    def setup(self):
        # 创建卷积层,设置特征数为 hidden_size,卷积核大小为 num_conv_pos_embeddings
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        # 定义权重形状,与卷积层特征数及分组数有关
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        # 初始化并定义权重 v 作为模型参数,使用 he_normal 初始化器
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        # 计算权重 v 的 L2 范数,并初始化权重 g 作为模型参数
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        # 初始化偏置参数,特征数与卷积层相同
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        # 计算用于填充输入的前置填充数
        self.prev_padding = self.conv.kernel_size[0] // 2

    # 内部方法,用于获取归一化后的权重
    def _get_normed_weights(self):
        # 计算权重 v 的归一化形式
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        # 计算归一化后的卷积核
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    # 模块的调用方法,执行卷积操作并返回结果
    def __call__(self, hidden_states):
        # 获取归一化后的卷积核
        kernel = self._get_normed_weights()
        # 对输入进行前置填充,保证卷积输出尺寸与输入相同
        hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
        # 应用卷积操作到输入上,使用归一化后的卷积核和偏置
        hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
        return hidden_states


# 定义一个 Flax 模块,用于处理位置卷积嵌入
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    # 配置信息,指定为 Wav2Vec2Config 类型
    config: Wav2Vec2Config
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块设置方法,用于初始化模块的各个部分
    def setup(self):
        # 创建包含权重归一化的卷积层模块
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        # 设置激活函数为配置文件中指定的函数
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # 根据卷积核大小决定需要移除的填充数量
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    # 模块的调用方法,执行位置卷积嵌入操作并返回结果
    def __call__(self, hidden_states):
        # 调整输入张量的维度顺序
        hidden_states = hidden_states.transpose((0, 1, 2))
        # 应用包含权重归一化的卷积操作到输入上
        hidden_states = self.conv(hidden_states)
        # 根据需要移除的填充数量截取卷积输出
        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        # 应用激活函数到卷积输出上
        hidden_states = self.activation(hidden_states)
        # 恢复张量的原始维度顺序并返回结果
        hidden_states = hidden_states.transpose((0, 1, 2))
        return hidden_states


# 定义一个 Flax 模块,用于包含一系列卷积层的集合
class FlaxConvLayersCollection(nn.Module):
    # 配置信息,指定为 Wav2Vec2Config 类型
    config: Wav2Vec2Config
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 初始化方法,用于设置对象的初始状态
    def setup(self):
        # 如果配置要求特征提取的归一化方式为 "layer"
        if self.config.feat_extract_norm == "layer":
            # 创建一系列 FlaxWav2Vec2LayerNormConvLayer 对象作为 self.layers 列表的元素,
            # 每个对象对应一个特征提取层
            self.layers = [
                FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        # 如果配置要求特征提取的归一化方式为 "group",暂时不支持这种方式
        elif self.config.feat_extract_norm == "group":
            # 抛出 NotImplementedError 异常,提醒暂时只支持 "layer" 形式的特征提取归一化
            raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
        # 如果配置的特征提取归一化方式既不是 "layer" 也不是 "group",则抛出 ValueError 异常
        else:
            # 抛出 ValueError 异常,指明配置中的 feat_extract_norm 值不合法
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
                " 'layer']"
            )

    # 对象被调用时执行的方法,用于处理输入的隐藏状态数据
    def __call__(self, hidden_states):
        # 遍历 self.layers 中的每个 conv_layer,依次对 hidden_states 进行处理
        for i, conv_layer in enumerate(self.layers):
            hidden_states = conv_layer(hidden_states)  # 调用 conv_layer 对象处理 hidden_states
        # 返回处理后的 hidden_states
        return hidden_states
class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """从原始音频波形中构建特征"""

    config: Wav2Vec2Config  # 引用Wav2Vec2Config配置对象
    dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型,默认为单精度浮点数

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
        # 初始化卷积层集合,使用配置对象和指定数据类型

    def __call__(self, input_values, freeze_feature_encoder=False):
        hidden_states = input_values[:, :, None]
        # 在最后添加一个维度,将形状从[batch_size, seq_len]变为[batch_size, seq_len, 1]
        hidden_states = self.conv_layers(hidden_states)
        # 经过卷积层处理,处理后形状为[batch_size, seq_len, hidden_size]
        if freeze_feature_encoder:
            hidden_states = jax.lax.stop_gradient(hidden_states)
            # 如果需要冻结特征编码器,则停止梯度传播
        return hidden_states


class FlaxWav2Vec2FeatureProjection(nn.Module):
    config: Wav2Vec2Config  # 引用Wav2Vec2Config配置对象
    dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型,默认为单精度浮点数

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化层归一化,使用指定的epsilon值和数据类型
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化全连接层,设置隐藏大小、权重初始化方法和数据类型
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
        # 初始化dropout层,设置丢弃率为配置中的特征投影dropout率

    def __call__(self, hidden_states, deterministic=True):
        norm_hidden_states = self.layer_norm(hidden_states)
        # 对隐藏状态进行层归一化处理
        hidden_states = self.projection(norm_hidden_states)
        # 应用投影层
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 应用dropout,如果确定性为True,则使用确定性dropout
        return hidden_states, norm_hidden_states


class FlaxWav2Vec2Attention(nn.Module):
    config: Wav2Vec2Config  # 引用Wav2Vec2Config配置对象
    embed_dim: int  # 嵌入维度
    num_heads: int  # 头的数量
    dropout: float = 0.0  # dropout率,默认为0.0
    bias: bool = True  # 是否使用偏置,默认为True
    dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型,默认为单精度浮点数

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        # 计算每个头的维度
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
            # 检查embed_dim必须能够被num_heads整除的条件,否则引发错误

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # 创建一个部分应用了参数的全连接层函数

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        # 使用dense函数初始化查询、键、值投影层
        self.out_proj = dense()
        # 使用dense函数初始化输出投影层

        self.dropout_layer = nn.Dropout(rate=self.dropout)
        # 初始化dropout层,设置丢弃率为配置中的dropout率

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
        # 将隐藏状态切分成多个头

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
        # 合并多个头的隐藏状态

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        # 定义Attention层的调用方式,包括隐藏状态、键值状态、注意力掩码和确定性
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""
        
        # 获取查询投影
        query_states = self.q_proj(hidden_states)

        # 获取键投影和值投影
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # 将查询投影、键投影和值投影按照头的数量进行分割
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # 如果存在注意力掩码,则扩展维度以匹配张量形状
        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # 将布尔类型的注意力掩码转换为注意力偏置
        if attention_mask is not None:
            # 注意力掩码转换为注意力偏置
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # 如果不是确定性计算且具有非零的 dropout 率,则创建 dropout 随机数生成器
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # 计算注意力权重
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # 计算注意力输出,使用 einsum 实现批量矩阵乘法
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)  # 合并注意力头
        attn_output = self.out_proj(attn_output)  # 输出投影

        return attn_output, attn_weights
# 定义一个名为 FlaxWav2Vec2FeedForward 的自定义神经网络模块,继承自 nn.Module
class FlaxWav2Vec2FeedForward(nn.Module):
    # 类属性:配置信息,类型为 Wav2Vec2Config
    config: Wav2Vec2Config
    # 类属性:数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法,设置网络结构
    def setup(self):
        # 定义中间层的 dropout 操作,使用配置中的激活函数的 dropout 率
        self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)

        # 定义中间层的全连接层,输入大小为配置中的 intermediate_size
        # 初始化方式为正态分布,范围为配置中的 initializer_range
        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

        # 根据配置选择激活函数,如果是字符串则从预定义的映射中获取,否则直接使用配置中的激活函数
        if isinstance(self.config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
        else:
            self.intermediate_act_fn = self.config.hidden_act

        # 定义输出层的全连接层,输出大小为配置中的 hidden_size
        # 初始化方式为正态分布,范围为配置中的 initializer_range
        self.output_dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

        # 定义输出层的 dropout 操作,使用配置中的隐藏层 dropout 率
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    # 前向传播方法,接收隐藏状态和是否确定性的标志,返回最终的隐藏状态
    def __call__(self, hidden_states, deterministic=True):
        # 中间层的全连接操作
        hidden_states = self.intermediate_dense(hidden_states)
        # 中间层的激活函数
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 中间层的 dropout 操作
        hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)

        # 输出层的全连接操作
        hidden_states = self.output_dense(hidden_states)
        # 输出层的 dropout 操作
        hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
        # 返回最终的隐藏状态
        return hidden_states


# 定义一个名为 FlaxWav2Vec2EncoderLayerStableLayerNorm 的自定义神经网络模块,继承自 nn.Module
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    # 类属性:配置信息,类型为 Wav2Vec2Config
    config: Wav2Vec2Config
    # 类属性:数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法,设置网络结构
    def setup(self):
        # 定义注意力层
        self.attention = FlaxWav2Vec2Attention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # 定义隐藏层的 dropout 操作
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        # 定义层归一化操作,使用配置中的 epsilon
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 定义前馈网络层
        self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
        # 定义最终的层归一化操作
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    # 前向传播方法,接收隐藏状态、注意力掩码、是否确定性的标志和是否输出注意力权重的标志,返回输出
    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        # 记录注意力残差连接
        attn_residual = hidden_states
        # 应用层归一化操作
        hidden_states = self.layer_norm(hidden_states)
        # 注意力层的前向传播
        hidden_states, attn_weights = self.attention(
            hidden_states, attention_mask=attention_mask, deterministic=deterministic
        )
        # 应用隐藏层的 dropout 操作
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 加上注意力残差连接
        hidden_states = attn_residual + hidden_states
        # 应用前馈网络层
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        # 输出结果
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,则添加到输出中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# 定义一个名为 FlaxWav2Vec2EncoderLayerStableLayerNormCollection 的自定义神经网络模块,继承自 nn.Module
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    # 类属性:配置信息,类型为 Wav2Vec2Config
    config: Wav2Vec2Config
    # 定义数据类型为 jnp.float32,默认为浮点数类型
    dtype: jnp.dtype = jnp.float32
    
    # 定义初始化方法,创建多个编码层对象并存储在列表 self.layers 中
    def setup(self):
        self.layers = [
            # 使用 FlaxWav2Vec2EncoderLayerStableLayerNorm 类创建编码层对象,编号从 '0' 到 str(num_hidden_layers-1),并指定数据类型为 self.dtype
            FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]
    
    # 定义调用方法,接受输入 hidden_states 和多个可选参数,并根据参数返回结果
    def __call__(
        self,
        hidden_states,  # 输入的隐藏状态张量
        attention_mask=None,  # 可选的注意力掩码张量,默认为 None
        deterministic: bool = True,  # 是否确定性推断,默认为 True
        output_attentions: bool = False,  # 是否输出注意力张量,默认为 False
        output_hidden_states: bool = False,  # 是否输出所有隐藏状态,默认为 False
        return_dict: bool = True,  # 是否以字典形式返回结果,默认为 True
    ):
        # 初始化空的元组变量 all_attentions 和 all_hidden_states,根据参数决定是否存储相应的输出
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
    
        # 遍历 self.layers 中的编码层,并依次处理隐藏状态
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                # 如果需要输出隐藏状态,则将当前隐藏状态存入 all_hidden_states 中
                all_hidden_states += (hidden_states,)
    
            # 调用当前层的 __call__ 方法,处理隐藏状态和注意力掩码,根据参数确定是否输出注意力张量
            layer_outputs = layer(
                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
            )
    
            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]
    
            if output_attentions:
                # 如果需要输出注意力张量,则将当前层的注意力张量存入 all_attentions 中
                all_attentions += (layer_outputs[1],)
    
        if output_hidden_states:
            # 如果需要输出隐藏状态,则将最终的隐藏状态存入 all_hidden_states 中
            all_hidden_states += (hidden_states,)
    
        # 按照设定的返回方式构建输出元组 outputs
        outputs = (hidden_states, all_hidden_states, all_attentions)
    
        if not return_dict:
            # 如果不需要以字典形式返回,则返回一个去除 None 值后的元组
            return tuple(v for v in outputs if v is not None)
    
        # 如果需要以字典形式返回,则返回一个包含最终隐藏状态、所有隐藏状态和所有注意力张量的 FlaxBaseModelOutput 对象
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    # Wav2Vec2Config类型的配置对象
    config: Wav2Vec2Config
    # 数据类型,默认为32位浮点数
    dtype: jnp.dtype = jnp.float32

    # 模块设置方法,初始化各个子模块
    def setup(self):
        # 位置卷积嵌入层对象,使用Wav2Vec2Config配置和指定数据类型
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        # 层归一化对象,使用指定的epsilon值和数据类型
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 丢弃层对象,使用指定的丢弃率
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        # 编码器层集合对象,使用Wav2Vec2Config配置和指定数据类型
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    # 对象调用方法,实现编码器的前向计算
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # 如果存在注意力掩码,则确保填充的令牌不被注意到
        if attention_mask is not None:
            hidden_states = jnp.where(
                # 根据注意力掩码扩展到hidden_states的形状,将未被掩盖的位置置为0
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        # 计算位置嵌入
        position_embeddings = self.pos_conv_embed(hidden_states)

        # 将位置嵌入加到hidden_states中
        hidden_states = hidden_states + position_embeddings
        # 对加了位置嵌入的hidden_states进行丢弃操作
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # 调用编码器层集合对象进行编码器层的前向计算
        outputs = self.layers(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 对编码器输出的最后一个隐藏状态进行层归一化处理
        last_hidden_state = self.layer_norm(outputs[0])

        # 如果需要返回隐藏状态历史,更新最后一个`hidden_states`元素
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        # 如果不返回字典格式的结果,则展开outputs并返回非空值
        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        # 返回FlaxBaseModelOutput对象,包括最后的隐藏状态、隐藏状态历史和注意力信息
        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )


class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    """

    # Wav2Vec2Config类型的配置对象
    config: Wav2Vec2Config
    # 数据类型,默认为32位浮点数
    dtype: jnp.dtype = jnp.float32
    # 在设置方法中初始化类的一些属性
    def setup(self):
        # 将配置中的参数赋值给实例属性
        self.num_groups = self.config.num_codevector_groups
        self.num_vars = self.config.num_codevectors_per_group

        # 检查是否能够均匀分割 codevector_dim
        if self.config.codevector_dim % self.num_groups != 0:
            # 如果不能整除,抛出数值错误异常
            raise ValueError(
                f"`config.codevector_dim {self.config.codevector_dim} must be divisible by"
                f" `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # 为存储码书变量(码字)预留空间
        self.codevectors = self.param(
            "codevectors",
            jax.nn.initializers.uniform(),
            (1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups),
        )
        
        # 设置权重投影层
        self.weight_proj = nn.Dense(
            self.num_groups * self.num_vars,
            kernel_init=jax.nn.initializers.normal(1.0),
            dtype=self.dtype,
        )

    # 静态方法:计算困惑度
    @staticmethod
    def _compute_perplexity(probs, mask=None):
        # 如果有掩码,扩展掩码并应用到概率矩阵上
        if mask is not None:
            mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape)
            probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs))
            marginal_probs = probs.sum(axis=0) / mask.sum()
        else:
            # 否则,计算概率矩阵的平均值
            marginal_probs = probs.mean(axis=0)

        # 计算困惑度
        perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()
        return perplexity
    def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # 将隐藏状态投影到代码向量维度
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)

        if not deterministic:
            # 使用古贝尔分布在可区分的方式中采样代码向量概率
            gumbel_rng = self.make_rng("gumbel")
            gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
            codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)

            # 计算困惑度
            codevector_soft_dist = nn.softmax(
                hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # 以非可区分的方式取 argmax
            # 计算硬代码向量分布(one-hot)
            codevector_idx = hidden_states.argmax(axis=-1)
            codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
            codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
        # 使用概率值检索代码向量
        codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
        codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)

        return codevectors, perplexity
class FlaxWav2Vec2Adapter(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # hidden_states require down-projection if feature dims don't match
        # 如果特征维度不匹配,则需要对隐藏状态进行降维投影
        if self.config.output_hidden_size != self.config.hidden_size:
            # Initialize a Dense layer for projection with normal distribution initialization
            # 初始化一个用于投影的稠密层,使用正态分布进行初始化
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            # Layer normalization for the projection layer
            # 投影层的层归一化
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = self.proj_layer_norm = None

        # Initialize the collection of adapter layers
        # 初始化适配器层集合
        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # down-project hidden_states if required
        # 如果需要,则对隐藏状态进行降维投影
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj(hidden_states)
            hidden_states = self.proj_layer_norm(hidden_states)

        # Pass hidden_states through adapter layers
        # 通过适配器层处理隐藏状态
        hidden_states = self.layers(hidden_states)

        return hidden_states


class FlaxWav2Vec2AdapterLayer(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Initialize a convolutional layer for the adapter layer
        # 初始化适配器层的卷积层
        self.conv = nn.Conv(
            features=2 * self.config.output_hidden_size,
            kernel_size=(self.config.adapter_kernel_size,),
            strides=(self.config.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Apply convolution to hidden_states
        # 将卷积应用于隐藏状态
        hidden_states = self.conv(hidden_states)
        # Apply gated linear unit (GLU) activation along axis 2
        # 沿着轴 2 应用门控线性单元(GLU)激活函数
        hidden_states = nn.glu(hidden_states, axis=2)

        return hidden_states


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Initialize a list of adapter layers
        # 初始化适配器层的列表
        self.layers = [
            FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        # Iterate through each adapter layer and apply it to hidden_states
        # 遍历每个适配器层,并将其应用于隐藏状态
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)

        return hidden_states


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用配置和数据类型初始化模块对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类初始化方法,传递配置、模块对象、输入形状、随机种子、数据类型和是否执行初始化的标志
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_values = jnp.zeros(input_shape, dtype="i4")
        # 创建一个与输入值形状相同的全1张量作为注意力掩码
        attention_mask = jnp.ones_like(input_values)
        # 拆分随机数生成器为两部分,一个用于参数,一个用于dropout
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用模块的初始化方法初始化参数,返回参数字典
        random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]

        if params is not None:
            # 如果传入了额外的参数,将随机生成的参数与传入的参数合并
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            # 否则直接返回随机生成的参数
            return random_params

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        # 如果输出注意力没有明确指定,则使用配置中的设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果输出隐藏状态没有明确指定,则使用配置中的设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果返回字典没有明确指定,则使用配置中的设置
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 获取输入数据的批量大小和序列长度
        batch_size, sequence_length = input_values.shape

        # 如果没有提供注意力掩码,则创建一个全1的注意力掩码
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # 处理可能存在的随机数生成器
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # 构建输入参数字典,如果未提供params则使用self.params
        inputs = {"params": params or self.params}

        # 调用模块的应用方法,执行模型前向传播
        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            not train,
            output_attentions,
            output_hidden_states,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        # 调用模块的特征提取方法,获取输出长度
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
# 定义一个名为 FlaxWav2Vec2Module 的 PyTorch 模块
class FlaxWav2Vec2Module(nn.Module):
    # 类型注解:配置信息为 Wav2Vec2Config 类型
    config: Wav2Vec2Config
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块初始化方法
    def setup(self):
        # 初始化特征提取器,使用配置信息和指定数据类型
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        # 初始化特征投影器,使用配置信息和指定数据类型
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        # 初始化掩码后的谱图嵌入参数,形状为 (hidden_size,)
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        # 如果配置指定使用稳定层归一化
        if self.config.do_stable_layer_norm:
            # 初始化编码器,使用配置信息和指定数据类型
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            # 抛出错误,暂不支持稳定层归一化未启用的情况
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        # 如果配置指定添加适配器,初始化适配器
        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    # 模块的调用方法,用于执行模型前向传播
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        # 提取特征向量
        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # 如果有注意力掩码
        if attention_mask is not None:
            # 计算对应于特征向量的减少注意力掩码
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        # 特征投影
        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        
        # 如果有时间轴索引的掩码
        if mask_time_indices is not None:
            # 在时间轴上应用 SpecAugment,并使用给定的索引
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        # 编码器的输出
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 编码器的隐藏状态
        hidden_states = encoder_outputs[0]

        # 如果有适配器,应用适配器
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        # 如果不返回字典形式的结果
        if not return_dict:
            # 返回元组形式的结果:(隐藏状态, 提取的特征) + 编码器输出中的其余部分
            return (hidden_states, extract_features) + encoder_outputs[1:]

        # 返回 FlaxWav2Vec2BaseModelOutput 类的实例,包括最后的隐藏状态、提取的特征、隐藏状态和注意力权重
        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    # 辅助方法:获取特征提取器的输出长度
    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        计算卷积层的输出长度
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            # 1D卷积层输出长度的计算公式,参考自 https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        # 实际上是 attention_mask.sum(-1),但不是原地操作,以便在推断模式下运行。
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations makes sure that all values
        # before the output lengths indices are attended to
        # 这两个操作确保所有输出长度索引之前的值都被关注到
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
# 添加函数文档字符串和装饰器,描述此类作为没有特定输出头部的裸Wav2Vec2模型转换器
@add_start_docstrings(
    "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
    WAV_2_VEC_2_START_DOCSTRING,
)
# 定义 FlaxWav2Vec2Model 类,继承自 FlaxWav2Vec2PreTrainedModel 类
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2Module  # 设置模块类为 FlaxWav2Vec2Module


# 定义 FLAX_WAV2VEC2_MODEL_DOCSTRING 作为模型的文档字符串,描述返回值和示例用法
FLAX_WAV2VEC2_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoProcessor, FlaxWav2Vec2Model
    >>> from datasets import load_dataset
    >>> import soundfile as sf

    >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60")
    >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")


    >>> def map_to_array(batch):
    ...     speech, _ = sf.read(batch["file"])
    ...     batch["speech"] = speech
    ...     return batch


    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.map(map_to_array)

    >>> input_values = processor(
    ...     ds["speech"][0], sampling_rate=16_000, return_tensors="np"
    ... ).input_values  # Batch size 1
    >>> hidden_states = model(input_values).last_hidden_state
    ```
"""

# 调用 overwrite_call_docstring 函数,将输入的文档字符串添加到 FlaxWav2Vec2Model 类的文档字符串中
overwrite_call_docstring(
    FlaxWav2Vec2Model,
    WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)

# 调用 append_replace_return_docstrings 函数,为 FlaxWav2Vec2Model 类添加返回值文档字符串
append_replace_return_docstrings(
    FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
)


# 定义 FlaxWav2Vec2ForCTCModule 类,继承自 nn.Module
class FlaxWav2Vec2ForCTCModule(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    # 初始化函数,设置模块及其成员
    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)  # 初始化 wav2vec2 模块
        self.dropout = nn.Dropout(rate=self.config.final_dropout)  # 初始化 dropout 层
        self.lm_head = nn.Dense(  # 初始化语言模型头部 Dense 层
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    # 调用函数,定义模型的前向传播逻辑
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        # 调用 wav2vec2 模块进行前向传播,获取输出
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # 获取隐藏状态
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)  # 应用 dropout

        logits = self.lm_head(hidden_states)  # 计算 logits

        if not return_dict:
            return (logits,) + outputs[2:]  # 返回 logits 和其他输出

        # 返回包含 logits、隐藏状态和注意力的 FlaxCausalLMOutput 对象
        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
    def _get_feat_extract_output_lengths(
        self,
        input_lengths: Union[jnp.ndarray, int],
        add_adapter: Optional[bool] = None,
    ):
        """
        Computes the output length of the convolutional layers
        """

        # 如果 add_adapter 未提供,则使用配置中的默认值
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 从 https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html 获取的
            # 1维卷积层输出长度计算公式
            return (input_length - kernel_size) // stride + 1

        # 遍历每个卷积核大小和步长,并计算每层卷积的输出长度
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        # 如果需要添加适配器层,根据配置中的适配器层数量和步长进行计算
        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        # 返回最终计算得到的输出长度
        return input_lengths
# 使用装饰器为 FlaxWav2Vec2ForCTC 类添加文档字符串,描述其为在 Connectionist Temporal Classification (CTC) 上加有语言建模头部的 Wav2Vec2 模型
@add_start_docstrings(
    "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
    WAV_2_VEC_2_START_DOCSTRING,
)
# 定义 FlaxWav2Vec2ForCTC 类,继承自 FlaxWav2Vec2PreTrainedModel 类
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    # 将 module_class 属性指定为 FlaxWav2Vec2ForCTCModule
    module_class = FlaxWav2Vec2ForCTCModule


# FLAX_WAV2VEC2_FOR_CTC_DOCSTRING 是一个长字符串,描述了 FlaxWav2Vec2ForCTC 类的返回值和示例用法

# 调用 overwrite_call_docstring 函数,为 FlaxWav2Vec2ForCTC 类的文档字符串添加输入参数文档和 FLAX_WAV2VEC2_FOR_CTC_DOCSTRING 内容
overwrite_call_docstring(
    FlaxWav2Vec2ForCTC,
    WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)

# 调用 append_replace_return_docstrings 函数,为 FlaxWav2Vec2ForCTC 类添加输出类型文档,并指定 output_type 和 config_class 参数
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)


# 定义 FlaxWav2Vec2ForPreTrainingModule 类,继承自 nn.Module 类
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
    # 设置 config 属性为 Wav2Vec2Config 类型,dtype 属性默认为 jnp.float32
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    # 定义 setup 方法,初始化模块
    def setup(self):
        # 实例化 FlaxWav2Vec2Module 类,并存储在 self.wav2vec2 属性中
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        # 使用 self.config.feat_quantizer_dropout 参数初始化 nn.Dropout 类,存储在 self.dropout_features 属性中
        self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout)

        # 实例化 FlaxWav2Vec2GumbelVectorQuantizer 类,并存储在 self.quantizer 属性中
        self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
        # 使用 self.config.proj_codevector_dim 参数初始化 nn.Dense 类,存储在 self.project_q 属性中
        self.project_q = nn.Dense(
            self.config.proj_codevector_dim,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 使用 self.config.proj_codevector_dim 参数初始化 nn.Dense 类,存储在 self.project_hid 属性中
        self.project_hid = nn.Dense(
            self.config.proj_codevector_dim,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    # 定义 __call__ 方法,实现对象的可调用性
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        gumbel_temperature: int = 1,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
        # 函数参数的注释可以在文档字符串中找到
        **kwargs,
    ):
        # 省略方法内部的具体实现,不在注释范围内
        ):
        r"""
        Returns:

        Example:

        ```

        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用给定的参数调用wav2vec2模型,获取输出
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        # 将所有转换后的特征(包括被掩码的)投影到最终的向量量化维度
        transformer_features = self.project_hid(outputs[0])

        # 量化所有(未被掩码的)提取特征并投影到最终的向量量化维度
        extract_features = self.dropout_features(outputs[1], deterministic=deterministic)
        quantized_features, codevector_perplexity = self.quantizer(
            extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature
        )
        quantized_features = self.project_q(quantized_features)

        # 如果不使用返回字典,则返回元组形式的输出
        if not return_dict:
            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]

        # 使用FlaxWav2Vec2ForPreTrainingOutput类封装输出,包括所有相关信息
        return FlaxWav2Vec2ForPreTrainingOutput(
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        计算卷积层的输出长度
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 从 https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html 中得到的一维卷积层输出长度公式
            return (input_length - kernel_size) // stride + 1

        # 遍历配置的卷积核大小和步幅,计算每一层卷积层的输出长度
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        # 如果需要添加适配器层,则计算适配器层的输出长度
        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2ForPreTrainingModule

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    # 覆盖原始定义,添加了 `gumbel_temperature` 输入参数
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        gumbel_temperature: int = 1,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        gumbel_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_values.shape

        # 如果未提供注意力掩码,则创建一个全为1的注意力掩码
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # 处理可能需要的任何伪随机数生成器
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        if gumbel_rng is not None:
            rngs["gumbel"] = gumbel_rng

        # 准备模型输入
        inputs = {"params": params or self.params}

        # 调用模块的前向方法
        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            gumbel_temperature,
            not train,
            output_attentions,
            output_hidden_states,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )


FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> import optax
    >>> import numpy as np
    >>> import jax.numpy as jnp
    >>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining
    >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
    >>> from datasets import load_dataset
    >>> import soundfile as sf

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
    >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")


    >>> def map_to_array(batch):
    ...     speech, _ = sf.read(batch["file"])
    ...     batch["speech"] = speech
    ...     return batch


    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.map(map_to_array)
    >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values  # 获取输入特征向量值,批大小为1

    >>> # 计算掩码索引
    >>> batch_size, raw_sequence_length = input_values.shape  # 获取批大小和原始序列长度
    >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)  # 根据模型获取特征提取后的序列长度
    >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)  # 计算掩码时间点的索引

    >>> outputs = model(input_values, mask_time_indices=mask_time_indices)  # 使用模型进行推理,传入掩码时间点索引

    >>> # 计算预测状态(outputs.projected_states)与目标状态(outputs.projected_quantized_states)之间的余弦相似度
    >>> cosine_sim = optax.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states)

    >>> # 确保余弦相似度在掩码时间点上的平均值高于0.5
    >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
"""
为 `FlaxWav2Vec2ForPreTraining` 类的 `__call__` 方法覆盖文档字符串,
使用 `WAV_2_VEC_2_INPUTS_DOCSTRING` 和 `FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING` 进行替换。
"""
overwrite_call_docstring(
    FlaxWav2Vec2ForPreTraining,
    WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)

"""
为 `FlaxWav2Vec2ForPreTraining` 类附加和替换返回值文档字符串,
使用 `FlaxWav2Vec2ForPreTrainingOutput` 作为输出类型,`Wav2Vec2Config` 作为配置类。
"""
append_replace_return_docstrings(
    FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
)

.\models\wav2vec2\modeling_tf_wav2vec2.py

# 设定代码文件的字符编码为 UTF-8
# 版权声明和许可信息,表明此代码的使用受 Apache 许可证 2.0 版本的约束
#
# 警告:此文件涉及 Fairseq 作者和 HuggingFace Inc. 团队的版权,保留所有权利。

""" TensorFlow Wav2Vec2 模型。"""

from __future__ import annotations  # 允许在类型注解中使用字符串以及类型本身的声明

import warnings  # 引入警告模块
from dataclasses import dataclass  # 导入 dataclass 用于数据类的定义
from typing import Any, Optional, Tuple, Union  # 引入类型提示的模块

import numpy as np  # 引入 NumPy 库
import tensorflow as tf  # 导入 TensorFlow 库

from ...activations_tf import get_tf_activation  # 从本地相对路径导入 TensorFlow 激活函数
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput  # 导入 TensorFlow 模型输出类
from ...modeling_tf_utils import (  # 导入 TensorFlow 模型工具函数
    TFPreTrainedModel,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax  # 从 TensorFlow 实用工具模块导入函数
from ...utils import (  # 从通用工具模块导入多个实用函数和类
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_wav2vec2 import Wav2Vec2Config  # 从本地相对路径导入 Wav2Vec2 的配置类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

_HIDDEN_STATES_START_POSITION = 2  # 设置隐藏状态的起始位置索引为2

_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"  # 预训练模型的检查点名称,用于文档
_CONFIG_FOR_DOC = "Wav2Vec2Config"  # Wav2Vec2 配置文件的名称,用于文档

# 预训练模型存档列表,包含多个预训练模型的名称
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/wav2vec2-base-960h",
    "facebook/wav2vec2-large-960h",
    "facebook/wav2vec2-large-960h-lv60",
    "facebook/wav2vec2-large-960h-lv60-self",
    # 查看所有 Wav2Vec2 模型:https://huggingface.co/models?filter=wav2vec2
]

LARGE_NEGATIVE = -1e8  # 定义一个较大的负数常量,用于特定目的

@dataclass
class TFWav2Vec2BaseModelOutput(ModelOutput):
    """
    [`TFWav2Vec2BaseModelOutput`] 的输出类型,包含潜在的隐藏状态和注意力。
    继承自 ModelOutput 类。
    """
    """
    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层输出的隐藏状态序列。
        extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
            模型最后一个卷积层提取的特征向量序列。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            包含模型每一层输出的隐藏状态的元组。形状为 `(batch_size, sequence_length, hidden_size)`。

            模型每一层的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            包含注意力权重的元组。形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    last_hidden_state: tf.Tensor = None  # 初始化最后一层隐藏状态为 None
    extract_features: tf.Tensor = None  # 初始化提取的特征向量为 None
    hidden_states: Tuple[tf.Tensor] | None = None  # 初始化隐藏状态元组为 None
    attentions: Tuple[tf.Tensor] | None = None  # 初始化注意力权重元组为 None
def _sample_without_replacement(distribution, num_samples):
    """
    Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
    https://github.com/tensorflow/tensorflow/issues/9260 for more info
    """
    # 使用负数对数的随机数作为采样分布
    z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
    # 对分布加上 gumbel-max 技巧后,取前 num_samples 个最高分布的索引
    _, indices = tf.nn.top_k(distribution + z, num_samples)
    return indices


def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
    """
    Scatter function as in PyTorch with indices in format (batch_dim, indixes)
    """
    # 获取 batch_indices 的形状
    indices_shape = shape_list(batch_indices)
    # 扩展 batch 维度到 indices_shape 形状
    broad_casted_batch_dims = tf.reshape(
        tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
    )
    # 将 batch_indices 转换为成对的 indices
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
    # 将 values 根据 pair_indices 散布到指定的 output_shape 上
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    min_masks: int = 0,
) -> tf.Tensor:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob:
            probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: size of the mask
        min_masks: minimum number of masked spans

    Adapted from [fairseq's
    data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376).
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    tf.debugging.assert_less(
        mask_length,
        sequence_length,
        message=(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
            f" `sequence_length`: {sequence_length}`"
        ),
    )

    # 计算批次中的被遮罩索引数目
    num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
    num_masked_spans = tf.maximum(num_masked_spans, min_masks)
    num_masked_spans = tf.cast(num_masked_spans, tf.int32)

    # 确保被遮罩的索引数目不超过 sequence_length
    # 计算允许的最大掩码数量,确保不超过序列长度的最大掩码数量
    num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
    # 去除可能存在的多余维度,确保得到一个标量值
    num_masked_spans = tf.squeeze(num_masked_spans)

    # 创建一个全零的张量作为 SpecAugment 掩码的初始模板
    spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)

    # 创建一个均匀分布的张量,用于采样掩码的起始索引,确保采样的索引不超过序列长度
    uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1)))

    # 获取随机的索引位置,用于创建掩码
    spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans)

    # 将掩码的索引扩展到掩码跨度
    spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1)
    spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length))
    spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length))

    # 创建偏移量,用于将掩码的索引扩展到每个掩码的具体位置
    offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :]
    offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1))
    offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length))

    # 将偏移量加到掩码的索引上,得到最终的掩码位置
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # 将掩码应用到 spec_aug_mask 上,使用 _scatter_values_on_batch_indices 函数
    spec_aug_mask = _scatter_values_on_batch_indices(
        tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
    )

    # 返回生成的 SpecAugment 掩码
    return spec_aug_mask
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    # 获取输入张量的第二个维度,即序列长度
    src_len = shape_list(mask)[1]
    # 如果没有提供目标长度,使用源长度作为目标长度
    tgt_len = tgt_len if tgt_len is not None else src_len
    # 创建常数张量,值为1.0,数据类型与输入张量相同
    one_cst = tf.constant(1.0)
    # 将输入张量转换为浮点数类型
    mask = tf.cast(mask, dtype=one_cst.dtype)
    # 在第二维度上复制输入张量,使其形状变为 [bsz, 1, tgt_seq_len, src_seq_len]
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    # 返回扩展后的注意力掩码,乘以一个大负数,用于模型中的无效化处理
    return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TFWav2Vec2GroupNorm(keras.layers.Layer):
    """
    From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
    """

    def __init__(
        self,
        groups: int = 32,
        axis: int = -1,
        epsilon: float = 1e-3,
        center: bool = True,
        scale: bool = True,
        beta_initializer: keras.initializers.Initializer = "zeros",
        gamma_initializer: keras.initializers.Initializer = "ones",
        beta_regularizer: keras.regularizers.Regularizer = None,
        gamma_regularizer: keras.regularizers.Regularizer = None,
        beta_constraint: keras.constraints.Constraint = None,
        gamma_constraint: keras.constraints.Constraint = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        # 分组数
        self.groups = groups
        # 归一化的轴
        self.axis = axis
        # 小数项,防止分母为零
        self.epsilon = epsilon
        # 是否包含中心参数
        self.center = center
        # 是否包含缩放参数
        self.scale = scale
        # beta 初始化器
        self.beta_initializer = keras.initializers.get(beta_initializer)
        # gamma 初始化器
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        # beta 正则化器
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        # gamma 正则化器
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        # beta 约束条件
        self.beta_constraint = keras.constraints.get(beta_constraint)
        # gamma 约束条件
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        # 检查归一化轴
        self._check_axis()

    def build(self, input_shape):
        # 检查输入形状是否为 None
        self._check_if_input_shape_is_none(input_shape)
        # 设置实例归一化的组数
        self._set_number_of_groups_for_instance_norm(input_shape)
        # 检查维度大小
        self._check_size_of_dimensions(input_shape)
        # 创建输入规范
        self._create_input_spec(input_shape)

        # 添加 gamma 权重
        self._add_gamma_weight(input_shape)
        # 添加 beta 权重
        self._add_beta_weight(input_shape)
        self.built = True
        super().build(input_shape)

    def call(self, inputs):
        # 获取输入张量的形状
        input_shape = keras.backend.int_shape(inputs)
        # 获取输入张量的 TensorFlow 形状
        tensor_input_shape = tf.shape(inputs)

        # 重塑输入张量为分组形状,返回重塑后的张量及其形状
        reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)

        # 应用归一化操作
        normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)

        # 如果是实例归一化,将张量展平为原始形状
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            outputs = tf.reshape(normalized_inputs, tensor_input_shape)
        else:
            outputs = normalized_inputs

        # 返回归一化后的输出张量
        return outputs
    # 获取配置信息的方法,返回一个包含当前层配置信息的字典
    def get_config(self):
        # 构建配置字典,包括各种属性和超参数
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": keras.initializers.serialize(self.beta_initializer),
            "gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
            "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": keras.constraints.serialize(self.beta_constraint),
            "gamma_constraint": keras.constraints.serialize(self.gamma_constraint),
        }
        # 调用父类的获取配置方法,获取基础配置信息
        base_config = super().get_config()
        # 合并基础配置和当前层配置,返回完整的配置字典
        return {**base_config, **config}

    # 计算输出形状的方法,直接返回输入的形状
    def compute_output_shape(self, input_shape):
        return input_shape

    # 将输入重塑为分组的形状的方法
    def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
        # 计算分组的形状
        group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
        # 检查是否为实例归一化
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            # 如果不是实例归一化,重新设置分组的轴
            group_shape[self.axis] = input_shape[self.axis] // self.groups
            group_shape.insert(self.axis, self.groups)
            group_shape = tf.stack(group_shape)
            # 重新形状化输入数据
            reshaped_inputs = tf.reshape(inputs, group_shape)
            return reshaped_inputs, group_shape
        else:
            # 如果是实例归一化,直接返回输入数据和分组的形状
            return inputs, group_shape

    # 应用归一化操作的方法
    def _apply_normalization(self, reshaped_inputs, input_shape):
        # 获取重塑后输入数据的形状
        group_shape = keras.backend.int_shape(reshaped_inputs)
        # 计算需要归并的轴
        group_reduction_axes = list(range(1, len(group_shape)))
        # 检查是否为实例归一化
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            # 如果不是实例归一化,确定归并的轴
            axis = -2 if self.axis == -1 else self.axis - 1
        else:
            # 如果是实例归一化,确定归并的轴
            axis = -1 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)

        # 计算均值和方差
        mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True)

        # 获取重塑后的权重 gamma 和 beta
        gamma, beta = self._get_reshaped_weights(input_shape)

        # 应用批归一化
        normalized_inputs = tf.nn.batch_normalization(
            reshaped_inputs,
            mean=mean,
            variance=variance,
            scale=gamma,
            offset=beta,
            variance_epsilon=self.epsilon,
        )
        # 返回归一化后的数据
        return normalized_inputs

    # 获取重塑后的权重 gamma 和 beta 的方法
    def _get_reshaped_weights(self, input_shape):
        # 创建广播形状
        broadcast_shape = self._create_broadcast_shape(input_shape)
        gamma = None
        beta = None
        # 如果开启了 scale 参数,重塑 gamma
        if self.scale:
            gamma = tf.reshape(self.gamma, broadcast_shape)

        # 如果开启了 center 参数,重塑 beta
        if self.center:
            beta = tf.reshape(self.beta, broadcast_shape)

        # 返回重塑后的 gamma 和 beta
        return gamma, beta
    # 检查输入形状中指定轴的维度是否为 None,如果是则引发 ValueError 异常
    def _check_if_input_shape_is_none(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError(
                "Axis "
                + str(self.axis)
                + " of input tensor should have a defined dimension but the layer received an input with shape "
                + str(input_shape)
                + "."
            )

    # 设置 InstanceNormalization 层的分组数目,若分组数为 -1,则设置为输入的维度数
    def _set_number_of_groups_for_instance_norm(self, input_shape):
        dim = input_shape[self.axis]

        if self.groups == -1:
            self.groups = dim

    # 检查维度大小,确保分组数不大于通道数,并且分组数必须是通道数的整数倍
    def _check_size_of_dimensions(self, input_shape):
        dim = input_shape[self.axis]
        if dim < self.groups:
            raise ValueError(
                "Number of groups ("
                + str(self.groups)
                + ") cannot be more than the number of channels ("
                + str(dim)
                + ")."
            )

        if dim % self.groups != 0:
            raise ValueError(
                "Number of groups ("
                + str(self.groups)
                + ") must be a multiple of the number of channels ("
                + str(dim)
                + ")."
            )

    # 检查轴的值,如果为 0,则引发 ValueError 异常,建议使用 tf.layer.batch_normalization
    def _check_axis(self):
        if self.axis == 0:
            raise ValueError(
                "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
            )

    # 创建输入规范(InputSpec),用于指定输入的维度和轴信息
    def _create_input_spec(self, input_shape):
        dim = input_shape[self.axis]
        self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim})

    # 添加 gamma 权重,如果启用 scale,则创建 gamma 权重变量,否则设为 None
    def _add_gamma_weight(self, input_shape):
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None

    # 添加 beta 权重,如果启用 center,则创建 beta 权重变量,否则设为 None
    def _add_beta_weight(self, input_shape):
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None

    # 创建广播形状,用于 InstanceNormalization 层的归一化操作
    def _create_broadcast_shape(self, input_shape):
        broadcast_shape = [1] * len(input_shape)
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
            broadcast_shape.insert(self.axis, self.groups)
        else:
            broadcast_shape[self.axis] = self.groups
        return broadcast_shape
class TFWav2Vec2WeightNormConv1D(keras.layers.Conv1D):
    """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm"""

    def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
        # 调用父类构造函数初始化卷积层
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            groups=groups,
            padding="valid",
            use_bias=True,
            bias_initializer="he_normal",
            **kwargs,
        )
        self.explicit_padding = explicit_padding  # 设置是否使用显式填充
        self.filter_axis = 2  # 卷积核在权重张量中的轴索引
        self.kernel_norm_axes = tf.constant([0, 1])  # 计算卷积核标准化时的轴索引

    def _init_norm(self):
        """Set the norm of the weight vector."""
        # 计算权重向量的范数,用于初始化权重标准化的参数
        kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes))
        self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis])  # 将计算得到的范数赋值给权重标准化的参数

    def _normalize_kernel(self):
        """Generate normalized weights."""
        # 标准化卷积核的权重
        kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g)
        self.kernel = tf.transpose(kernel)  # 转置得到标准化后的卷积核权重

    def build(self, input_shape):
        if not self.built:
            super().build(input_shape)

            # 初始化权重向量并赋值给self.weight_v
            self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
            self.weight_v = self.kernel

            # 添加权重参数weight_g,用于存储卷积核标准化的参数
            self.weight_g = self.add_weight(
                name="weight_g",
                shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1),
                initializer="ones",
                dtype=self.weight_v.dtype,
                trainable=True,
            )
            self._init_norm()  # 初始化权重标准化参数
            self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)

    def call(self, inputs):
        # 在call方法中标准化卷积核的权重
        self._normalize_kernel()

        # 对输入进行显式填充
        padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
        # 调用父类的call方法进行卷积操作
        output = super().call(padded_inputs)

        return output


class TFWav2Vec2NoLayerNormConvLayer(keras.layers.Layer):
    def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # 根据配置文件初始化输入和输出的卷积维度
        self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        # 初始化卷积层,根据配置文件中的参数设置
        self.conv = keras.layers.Conv1D(
            filters=self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            strides=config.conv_stride[layer_id],
            use_bias=config.conv_bias,
            name="conv",
        )
        # 根据配置文件获取激活函数,并赋值给self.activation
        self.activation = get_tf_activation(config.feat_extract_activation)
    # 定义一个方法用于调用卷积层和激活函数处理隐藏状态张量,并返回处理后的张量
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 使用定义好的卷积层处理隐藏状态张量
        hidden_states = self.conv(hidden_states)
        # 使用定义好的激活函数处理卷积后的张量
        hidden_states = self.activation(hidden_states)
        # 返回处理后的张量
        return hidden_states

    # 定义一个方法用于构建模型,初始化卷积层
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回
        if self.built:
            return
        # 标记模型已构建
        self.built = True
        # 如果存在卷积层,则在命名作用域内构建卷积层
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                # 构建卷积层,指定输入形状为 [None, None, self.in_conv_dim]
                self.conv.build([None, None, self.in_conv_dim])
class TFWav2Vec2LayerNormConvLayer(keras.layers.Layer):
    # 初始化函数,设置层的参数和配置
    def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # 根据层 ID 设置输入和输出的卷积维度
        self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        # 创建卷积层对象
        self.conv = keras.layers.Conv1D(
            filters=self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            strides=config.conv_stride[layer_id],
            use_bias=config.conv_bias,
            name="conv",
        )
        # 创建层归一化对象
        self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps)
        # 获取激活函数对象
        self.activation = get_tf_activation(config.feat_extract_activation)

    # 前向传播函数,定义层的计算逻辑
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 卷积操作
        hidden_states = self.conv(hidden_states)
        # 层归一化操作
        hidden_states = self.layer_norm(hidden_states)
        # 激活函数操作
        hidden_states = self.activation(hidden_states)
        return hidden_states

    # 构建函数,用于构建层内的各个子层
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建卷积层
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                self.conv.build([None, None, self.in_conv_dim])
        # 构建层归一化层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.out_conv_dim])


class TFWav2Vec2GroupNormConvLayer(keras.layers.Layer):
    # 初始化函数,设置层的参数和配置
    def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # 根据层 ID 设置输入和输出的卷积维度
        self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        # 创建卷积层对象
        self.conv = keras.layers.Conv1D(
            filters=self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            strides=config.conv_stride[layer_id],
            use_bias=config.conv_bias,
            name="conv",
        )
        # 获取激活函数对象
        self.activation = get_tf_activation(config.feat_extract_activation)
        # 创建分组归一化层对象
        self.layer_norm = TFWav2Vec2GroupNorm(
            groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm"
        )

    # 前向传播函数,定义层的计算逻辑
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 卷积操作
        hidden_states = self.conv(hidden_states)
        # 分组归一化操作
        hidden_states = self.layer_norm(hidden_states)
        # 激活函数操作
        hidden_states = self.activation(hidden_states)
        return hidden_states

    # 构建函数,用于构建层内的各个子层
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建卷积层
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                self.conv.build([None, None, self.in_conv_dim])
        # 构建分组归一化层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.out_conv_dim])
class TFWav2Vec2PositionalConvEmbedding(keras.layers.Layer):
    # 定义 TF Wav2Vec2 的位置卷积嵌入层,继承自 Keras 的层
    def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # 初始化函数,接收配置对象和其他关键字参数
        self.conv = TFWav2Vec2WeightNormConv1D(
            filters=config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            groups=config.num_conv_pos_embedding_groups,
            explicit_padding=config.num_conv_pos_embeddings // 2,
            name="conv",
        )
        # 设置卷积层,使用权重归一化的 TF Wav2Vec2 卷积层
        self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
        # 设置填充层,用于保持卷积输出的长度
        self.activation = get_tf_activation(config.feat_extract_activation)
        # 获取激活函数并设置为实例的属性
        self.config = config
        # 保存配置对象到实例的属性中

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 定义调用函数,处理输入的隐藏状态张量并返回处理后的张量
        hidden_states = self.conv(hidden_states)
        # 经过卷积层处理
        hidden_states = self.padding(hidden_states)
        # 经过填充层处理
        hidden_states = self.activation(hidden_states)
        # 经过激活函数处理
        return hidden_states
        # 返回处理后的隐藏状态张量

    def build(self, input_shape=None):
        # 构建函数,在第一次调用时构建层的变量
        if self.built:
            return
        self.built = True
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                self.conv.build([None, None, self.config.hidden_size])
                # 使用配置的隐藏大小构建卷积层



class TFWav2Vec2SamePadLayer(keras.layers.Layer):
    # 定义 TF Wav2Vec2 的同填充层,继承自 Keras 的层
    def __init__(self, num_conv_pos_embeddings, **kwargs):
        super().__init__(**kwargs)
        # 初始化函数,接收卷积位置嵌入数目和其他关键字参数
        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
        # 计算需要移除的填充数目,根据卷积位置嵌入的奇偶性确定

    def call(self, hidden_states):
        # 定义调用函数,处理输入的隐藏状态张量并返回处理后的张量
        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
            # 如果需要移除填充,则在最后一个维度上移除相应数量的填充
        return hidden_states
        # 返回处理后的隐藏状态张量



class TFWav2Vec2FeatureEncoder(keras.layers.Layer):
    # 定义 TF Wav2Vec2 的特征编码器层,继承自 Keras 的层
    def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # 初始化函数,接收配置对象和其他关键字参数
        if config.feat_extract_norm == "group":
            # 如果特征提取归一化方式为 group
            conv_layers = [TFWav2Vec2GroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [
                TFWav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i+1}")
                for i in range(config.num_feat_extract_layers - 1)
            ]
            # 创建一组带有组归一化的卷积层
        elif config.feat_extract_norm == "layer":
            # 如果特征提取归一化方式为 layer
            conv_layers = [
                TFWav2Vec2LayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}")
                for i in range(config.num_feat_extract_layers)
            ]
            # 创建一组带有层归一化的卷积层
        else:
            # 如果特征提取归一化方式既不是 group 也不是 layer,则抛出异常
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        self.conv_layers = conv_layers
        # 保存创建的卷积层列表到实例的属性中

    def call(self, input_values):
        # 定义调用函数,处理输入值并返回处理后的张量
        hidden_states = tf.expand_dims(input_values, -1)
        # 在最后一个维度上扩展输入张量
        for conv_layer in self.conv_layers:
            hidden_states = conv_layer(hidden_states)
            # 通过每个卷积层处理隐藏状态
        return hidden_states
        # 返回处理后的隐藏状态张量
    # 定义神经网络层的构建方法,接收输入形状作为参数,如果已经构建过则直接返回
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,不再重复构建
        if self.built:
            return
        # 标记该神经网络层已经构建
        self.built = True
        # 如果存在卷积层列表,则逐个构建每个卷积层
        if getattr(self, "conv_layers", None) is not None:
            for conv_layer in self.conv_layers:
                # 使用 TensorFlow 的命名作用域,将当前卷积层的名称作为作用域名称
                with tf.name_scope(conv_layer.name):
                    # 调用卷积层的 build 方法来构建该层
                    conv_layer.build(None)
# 定义 TFWav2Vec2FeatureExtractor 类,继承自 TFWav2Vec2FeatureEncoder 类
class TFWav2Vec2FeatureExtractor(TFWav2Vec2FeatureEncoder):
    
    # 初始化方法,接受 config 和额外的关键字参数
    def __init__(self, config, **kwargs):
        # 调用父类 TFWav2Vec2FeatureEncoder 的初始化方法
        super().__init__(config, **kwargs)
        
        # 发出警告信息,提示该类即将被废弃
        warnings.warn(
            f"The class `{self.__class__.__name__}` has been depreciated "
            "and will be removed in Transformers v5. "
            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
            FutureWarning,
        )


# 定义 TFWav2Vec2FeatureProjection 类,继承自 keras 的 Layer 类
class TFWav2Vec2FeatureProjection(keras.layers.Layer):
    
    # 初始化方法,接受 config 参数和其他关键字参数
    def __init__(self, config: Wav2Vec2Config, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 使用 config 中的参数创建 LayerNormalization 层,设置 epsilon 和名称
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        
        # 创建 Dense 层作为投影层,设置单元数、初始化器和偏置初始化器
        self.projection = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="projection",
        )
        
        # 创建 Dropout 层,设置丢弃率
        self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout)
        
        # 保存 config 参数
        self.config = config

    # 调用方法,接受 hidden_states 和 training 参数,返回 Tensor
    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 对 hidden_states 进行 LayerNormalization 处理
        norm_hidden_states = self.layer_norm(hidden_states)
        
        # 将处理后的 hidden_states 投影到新的维度空间
        hidden_states = self.projection(norm_hidden_states)
        
        # 根据 training 参数应用 Dropout
        hidden_states = self.dropout(hidden_states, training=training)
        
        # 返回处理后的 hidden_states
        return hidden_states, norm_hidden_states

    # 构建方法,用于构建层结构
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        
        # 标记已经构建
        self.built = True
        
        # 如果存在 layer_norm 属性
        if getattr(self, "layer_norm", None) is not None:
            # 在 layer_norm 的名称空间下构建该层,传入形状参数
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.conv_dim[-1]])
        
        # 如果存在 projection 属性
        if getattr(self, "projection", None) is not None:
            # 在 projection 的名称空间下构建该层,传入形状参数
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, self.config.conv_dim[-1]])


# 从 transformers.models.bart.modeling_tf_bart.TFBartAttention 复制而来,修改为 TFWav2Vec2Attention
class TFWav2Vec2Attention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"""

    # 初始化方法,接受多个参数包括 embed_dim, num_heads, dropout 等
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
        # 调用父类初始化函数,传入指定的关键字参数
        super().__init__(**kwargs)
        # 初始化嵌入维度
        self.embed_dim = embed_dim

        # 初始化注意力头数
        self.num_heads = num_heads
        # 初始化dropout层
        self.dropout = keras.layers.Dropout(dropout)
        # 初始化头部维度
        self.head_dim = embed_dim // num_heads
        # 如果头部维度乘以注意力头数不等于嵌入维度,抛出数值错误异常
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 初始化缩放因子
        self.scaling = self.head_dim**-0.5
        # 初始化是否是解码器的标志
        self.is_decoder = is_decoder

        # 初始化k投影层
        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        # 初始化q投影层
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        # 初始化v投影层
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        # 初始化输出投影层
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    # 定义变形函数,接收张量、序列长度和批大小作为输入,返回变形后的张量
    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    # 定义调用函数,接收隐藏状态、键值状态、过去的键值对、注意力掩码、层头遮罩和训练标志作为输入
    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    # 定义构建函数,接收输入形状作为输入,并在已构建时返回
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在k_proj,则构建k_proj
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        # 如果存在q_proj,则构建q_proj
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        # 如果存在v_proj,则构建v_proj
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        # 如果存在out_proj,则构建out_proj
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
# 定义一个自定义的 Transformer 编码器层,基于 Keras 的 Layer 类
class TFWav2Vec2EncoderLayer(keras.layers.Layer):
    # 初始化方法,接收 Wav2Vec2Config 类型的配置参数和其他关键字参数
    def __init__(self, config: Wav2Vec2Config, **kwargs):
        super().__init__(**kwargs)
        
        # 初始化注意力机制模块,使用 TFWav2Vec2Attention 类
        self.attention = TFWav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            name="attention",
        )
        
        # Dropout 层,用于隐藏状态的随机失活
        self.dropout = keras.layers.Dropout(config.hidden_dropout)
        
        # LayerNormalization 层,用于归一化层输入,防止梯度爆炸
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        
        # 基于配置参数初始化前馈神经网络层
        self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
        
        # 最终的 LayerNormalization 层,用于归一化前馈神经网络的输出
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
        
        # 保存配置参数
        self.config = config

    # 前向传播方法,接收隐藏状态张量和训练标志作为输入,返回处理后的张量
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = False,
        training: bool = False,
    ) -> tf.Tensor:
        # 使用注意力机制进行处理
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            training=training,
        )
        
        # 使用 Dropout 对注意力机制输出进行随机失活
        hidden_states = self.dropout(hidden_states, training=training)
        
        # 应用 LayerNormalization 对随机失活后的隐藏状态进行归一化
        hidden_states = self.layer_norm(hidden_states)
        
        # 使用前馈神经网络处理归一化后的隐藏状态
        hidden_states = self.feed_forward(hidden_states, training=training)
        
        # 最终使用 LayerNormalization 对前馈神经网络的输出进行归一化
        hidden_states = self.final_layer_norm(hidden_states)
        
        # 返回处理后的张量作为编码器层的输出
        return hidden_states
    # 定义函数,该函数接受隐藏状态、注意力掩码和训练标志作为输入,返回包含注意力权重的元组
    ) -> Tuple[tf.Tensor]:
        # 将原始的隐藏状态保存到变量 attn_residual 中
        attn_residual = hidden_states
        # 调用 self.attention 对象进行注意力计算,返回新的隐藏状态、注意力权重和占位符
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, training=training
        )
        # 对隐藏状态应用 dropout 操作,用于正则化
        hidden_states = self.dropout(hidden_states, training=training)
        # 将原始的隐藏状态与新的隐藏状态相加,得到残差连接的结果
        hidden_states = attn_residual + hidden_states

        # 应用层归一化操作
        hidden_states = self.layer_norm(hidden_states)
        # 经过前馈神经网络处理隐藏状态
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        # 再次进行最终层的归一化操作
        hidden_states = self.final_layer_norm(hidden_states)

        # 构建输出元组,只包含隐藏状态
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,则将注意力权重加入输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        # 返回最终的输出元组
        return outputs

    # 定义 build 方法,用于构建模型的层次结构
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 设置标志位表示已经构建过
        self.built = True

        # 如果存在 self.attention 属性,则构建 attention 层
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        
        # 如果存在 self.layer_norm 属性,则构建 layer_norm 层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])
        
        # 如果存在 self.feed_forward 属性,则构建 feed_forward 层
        if getattr(self, "feed_forward", None) is not None:
            with tf.name_scope(self.feed_forward.name):
                self.feed_forward.build(None)
        
        # 如果存在 self.final_layer_norm 属性,则构建 final_layer_norm 层
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.config.hidden_size])
class TFWav2Vec2EncoderLayerStableLayerNorm(keras.layers.Layer):
    # TFWav2Vec2EncoderLayerStableLayerNorm 类,继承自 keras.layers.Layer

    def __init__(self, config: Wav2Vec2Config, **kwargs):
        # 初始化函数,接受一个 Wav2Vec2Config 类型的 config 对象和其他关键字参数

        super().__init__(**kwargs)
        # 调用父类的初始化函数

        # 创建注意力层对象
        self.attention = TFWav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            name="attention",
        )
        # 创建 Dropout 层
        self.dropout = keras.layers.Dropout(config.hidden_dropout)
        # 创建 LayerNormalization 层
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # 创建前馈网络层
        self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
        # 创建最终 LayerNormalization 层
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")

        # 保存配置对象
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 定义 call 方法用于前向传播

        # 保存注意力层的残差连接
        attn_residual = hidden_states
        # LayerNormalization 层
        hidden_states = self.layer_norm(hidden_states)
        # 注意力计算
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, training=training
        )
        # Dropout 层
        hidden_states = self.dropout(hidden_states, training=training)
        # 残差连接
        hidden_states = attn_residual + hidden_states
        # 前馈网络计算
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        # 返回输出结果
        outputs = (hidden_states,)

        # 如果需要输出注意力权重
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def build(self, input_shape=None):
        # 构建函数,用于构建层的参数

        if self.built:
            return

        self.built = True

        # 构建注意力层
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)

        # 构建 LayerNormalization 层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])

        # 构建前馈网络层
        if getattr(self, "feed_forward", None) is not None:
            with tf.name_scope(self.feed_forward.name):
                self.feed_forward.build(None)

        # 构建最终 LayerNormalization 层
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.config.hidden_size])


class TFWav2Vec2Encoder(keras.layers.Layer):
    # TFWav2Vec2Encoder 类,继承自 keras.layers.Layer

    def __init__(self, config: Wav2Vec2Config, **kwargs):
        # 初始化函数,接受一个 Wav2Vec2Config 类型的 config 对象和其他关键字参数

        super().__init__(**kwargs)
        # 调用父类的初始化函数

        # 保存配置对象
        self.config = config

        # 创建位置卷积嵌入层
        self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
        # 创建 LayerNormalization 层
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # 创建 Dropout 层
        self.dropout = keras.layers.Dropout(config.hidden_dropout)

        # 创建多层编码器层列表
        self.layer = [TFWav2Vec2EncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
    # 定义一个方法用于处理模型调用过程中的输入和输出
    def call(
        self,
        hidden_states: tf.Tensor,  # 输入的隐藏状态张量
        attention_mask: tf.Tensor | None = None,  # 注意力掩码张量,默认为None
        output_attentions: Optional[bool] = False,  # 是否输出注意力权重,默认为False
        output_hidden_states: Optional[bool] = False,  # 是否输出隐藏状态,默认为False
        return_dict: Optional[bool] = True,  # 是否返回字典格式的输出,默认为True
        training: Optional[bool] = False,  # 是否处于训练模式,默认为False
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None  # 如果需要输出隐藏状态,则初始化一个空元组,否则为None
        all_self_attentions = () if output_attentions else None  # 如果需要输出注意力权重,则初始化一个空元组,否则为None

        if attention_mask is not None:
            hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)  # 对隐藏状态应用注意力掩码
            attention_mask = _expand_mask(attention_mask)  # 扩展注意力掩码的维度
        else:
            attention_mask = None  # 如果没有提供注意力掩码,则置为None

        position_embeddings = self.pos_conv_embed(hidden_states)  # 使用位置卷积嵌入处理隐藏状态
        hidden_states = hidden_states + position_embeddings  # 加上位置嵌入的结果
        hidden_states = self.layer_norm(hidden_states)  # 使用层归一化处理隐藏状态
        hidden_states = self.dropout(hidden_states, training=training)  # 应用丢弃操作

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)  # 如果需要输出隐藏状态,则将当前隐藏状态添加到元组中

            # 添加层丢弃(详见 https://arxiv.org/abs/1909.11556)
            dropout_probability = np.random.uniform(0, 1)
            if training and (dropout_probability < self.config.layerdrop):  # 如果处于训练状态且随机数小于层丢弃概率,则跳过当前层
                continue

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]  # 更新隐藏状态为当前层的输出

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)  # 如果需要输出注意力权重,则将当前层的注意力权重添加到元组中

        # 添加最后一层的隐藏状态输出
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 根据return_dict的设置返回相应的输出格式
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    # 构建模型,初始化各个组件
    def build(self, input_shape=None):
        if self.built:
            return  # 如果模型已构建,则直接返回
        self.built = True  # 标记模型已构建
        if getattr(self, "pos_conv_embed", None) is not None:
            with tf.name_scope(self.pos_conv_embed.name):
                self.pos_conv_embed.build(None)  # 构建位置卷积嵌入层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])  # 构建层归一化层
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)  # 逐层构建模型的层
class TFWav2Vec2EncoderStableLayerNorm(keras.layers.Layer):
    # 初始化函数,接收配置参数 config 和其他关键字参数
    def __init__(self, config: Wav2Vec2Config, **kwargs):
        super().__init__(**kwargs)
        # 保存配置参数
        self.config = config
        # 创建位置编码卷积嵌入层对象,命名为 pos_conv_embed
        self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
        # 创建层归一化对象,使用配置中的 epsilon,命名为 layer_norm
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # 创建 dropout 层,使用配置中的隐藏层 dropout 率
        self.dropout = keras.layers.Dropout(config.hidden_dropout)
        # 创建多个编码器层,列表中包含 config.num_hidden_layers 个 TFWav2Vec2EncoderLayerStableLayerNorm 实例
        self.layer = [
            TFWav2Vec2EncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
        ]

    # 前向传播函数,接收隐藏状态、注意力掩码和其他控制参数,返回 TFBaseModelOutput 或元组 tf.Tensor
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # 如果需要输出隐藏状态,则初始化空元组 all_hidden_states,否则设为 None
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力权重,则初始化空元组 all_self_attentions,否则设为 None
        all_self_attentions = () if output_attentions else None

        # 如果存在 attention_mask,则将隐藏状态与 attention_mask 相乘,实现掩码效果,并使用 _expand_mask 对 attention_mask 进行扩展
        if attention_mask is not None:
            hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
            attention_mask = _expand_mask(attention_mask)
        else:
            attention_mask = None

        # 计算位置编码并添加到隐藏状态中
        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        # 对隐藏状态应用 dropout,根据训练状态进行区分
        hidden_states = self.dropout(hidden_states, training=training)

        # 遍历每个编码器层
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前隐藏状态加入 all_hidden_states
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 添加 LayerDrop 机制,根据配置中的 layerdrop 参数跳过某些层
            dropout_probability = np.random.uniform(0, 1)
            if training and (dropout_probability < self.config.layerdrop):  # 根据概率跳过该层
                continue

            # 调用当前编码器层进行前向传播
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                training=training,
            )
            # 更新隐藏状态为编码器层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则将当前层的注意力权重加入 all_self_attentions
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 对最终的隐藏状态进行层归一化
        hidden_states = self.layer_norm(hidden_states)

        # 如果需要输出隐藏状态,则将最终的隐藏状态加入 all_hidden_states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果 return_dict 为 False,则返回非 None 的结果元组
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # 否则,返回 TFBaseModelOutput 对象,包含最终隐藏状态、所有隐藏状态和所有注意力权重
        return TFBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
    # 定义神经网络层的构建方法,如果已经构建过则直接返回
    def build(self, input_shape=None):
        if self.built:
            return
        # 设置标志位,表示网络已经构建
        self.built = True

        # 如果存在位置卷积嵌入层,则构建该层
        if getattr(self, "pos_conv_embed", None) is not None:
            # 在 TensorFlow 中设置命名空间为位置卷积嵌入层的名称,并进行构建
            with tf.name_scope(self.pos_conv_embed.name):
                self.pos_conv_embed.build(None)

        # 如果存在层归一化层,则构建该层
        if getattr(self, "layer_norm", None) is not None:
            # 在 TensorFlow 中设置命名空间为层归一化层的名称,并进行构建
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])

        # 如果存在多个子层,则依次构建每个子层
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                # 在 TensorFlow 中设置命名空间为子层的名称,并进行构建
                with tf.name_scope(layer.name):
                    layer.build(None)
# 使用 Keras 序列化装饰器标记该类可以被序列化
@keras_serializable
class TFWav2Vec2MainLayer(keras.layers.Layer):
    # 指定配置类为 Wav2Vec2Config
    config_class = Wav2Vec2Config

    # 初始化函数,接受配置对象作为参数,初始化各个子层
    def __init__(self, config: Wav2Vec2Config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # 创建特征提取器对象,使用给定的配置对象,并命名为 "feature_extractor"
        self.feature_extractor = TFWav2Vec2FeatureEncoder(config, name="feature_extractor")
        # 创建特征投影对象,使用给定的配置对象,并命名为 "feature_projection"
        self.feature_projection = TFWav2Vec2FeatureProjection(config, name="feature_projection")

        # 根据配置选择稳定层归一化编码器或一般编码器
        if config.do_stable_layer_norm:
            self.encoder = TFWav2Vec2EncoderStableLayerNorm(config, name="encoder")
        else:
            self.encoder = TFWav2Vec2Encoder(config, name="encoder")

    # 构建函数,构建该层的权重和子层
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        self.built = True

        # 如果配置中设置了时间掩码或特征掩码的概率大于0,则添加用于掩码的权重
        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            self.masked_spec_embed = self.add_weight(
                shape=(self.config.hidden_size,),  # 形状为隐藏尺寸大小的一维向量
                initializer="uniform",  # 使用均匀分布初始化权重
                trainable=True,  # 可训练
                name="masked_spec_embed"  # 权重的名称为 "masked_spec_embed"
            )

        # 如果存在特征提取器对象,则构建特征提取器
        if getattr(self, "feature_extractor", None) is not None:
            with tf.name_scope(self.feature_extractor.name):
                self.feature_extractor.build(None)

        # 如果存在特征投影对象,则构建特征投影
        if getattr(self, "feature_projection", None) is not None:
            with tf.name_scope(self.feature_projection.name):
                self.feature_projection.build(None)

        # 如果存在编码器对象,则构建编码器
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)

    # 计算卷积层的输出长度
    def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D 卷积层的输出长度公式,参考自 https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        # 对于每个卷积核大小和步长,依次计算输出长度
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths
        # 掩盖隐藏状态,根据时间索引和/或特征索引掩盖提取的特征,根据[SpecAugment](https://arxiv.org/abs/1904.08779)进行操作
        def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None):
            """
            Masks extracted features along time axis and/or along feature axis according to
            [SpecAugment](https://arxiv.org/abs/1904.08779).
            """
            # 获取隐藏状态的形状
            batch_size, sequence_length, hidden_size = shape_list(hidden_states)

            # 如果config.apply_spec_augment设置为False,则不进行掩盖操作
            if not getattr(self.config, "apply_spec_augment", True):
                return hidden_states

            # 如果传入了mask_time_indices
            if mask_time_indices is not None:
                # 根据给定的mask_time_indices沿时间轴应用SpecAugment掩盖
                hidden_states = tf.where(
                    tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
                    self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
                    hidden_states,
                )

            # 如果未传入mask_time_indices,并且mask_time_prob大于0
            elif self.config.mask_time_prob > 0:
                # 生成索引并沿时间轴应用SpecAugment
                mask_time_indices = _compute_mask_indices(
                    (batch_size, sequence_length),
                    mask_prob=self.config.mask_time_prob,
                    mask_length=self.config.mask_time_length,
                    min_masks=2,
                )
                hidden_states = tf.where(
                    tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
                    self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
                    hidden_states,
                )

            # 沿特征轴应用SpecAugment
            if self.config.mask_feature_prob > 0:
                mask_feature_indices = _compute_mask_indices(
                    (batch_size, hidden_size),
                    mask_prob=self.config.mask_feature_prob,
                    mask_length=self.config.mask_feature_length,
                )
                hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0)

            return hidden_states

        # 解包输入参数并调用模型
        @unpack_inputs
        def call(
            self,
            input_values: tf.Tensor,
            attention_mask: tf.Tensor | None = None,
            token_type_ids: tf.Tensor | None = None,
            position_ids: tf.Tensor | None = None,
            head_mask: tf.Tensor | None = None,
            inputs_embeds: tf.Tensor | None = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            training: bool = False,
            **kwargs: Any,
        ):
        # 使用特征提取器从输入值中提取特征,返回特征张量
        extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
        # 如果需要,可以转置提取的特征张量的维度顺序
        # extract_features = tf.transpose(extract_features, perm=(0, 2, 1))

        if attention_mask is not None:
            # 根据卷积公式计算真实的输出长度
            output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))

            # 根据计算得到的长度创建注意力掩码
            attention_mask = tf.sequence_mask(
                output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
            )

        # 将提取的特征张量投影到隐藏状态空间中
        hidden_states, extract_features = self.feature_projection(extract_features, training=training)

        # 获取可选参数中的时间索引屏蔽信息
        mask_time_indices = kwargs.get("mask_time_indices", None)
        if training:
            # 如果处于训练模式,则对隐藏状态进行时间屏蔽处理
            hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        # 将隐藏状态输入到编码器中进行编码
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 从编码器输出中获取隐藏状态
        hidden_states = encoder_outputs[0]

        # 如果不返回字典形式的结果,则返回一个包含隐藏状态、提取的特征和其他编码器输出的元组
        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        # 返回一个包含 TF Wav2Vec2 模型输出的命名元组
        return TFWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 Wav2Vec2Config 类作为配置类
    config_class = Wav2Vec2Config
    # 基础模型的前缀字符串
    base_model_prefix = "wav2vec2"
    # 主输入的名称
    main_input_name = "input_values"

    @property
    def input_signature(self):
        # 定义模型输入的签名,包括 input_values 和 attention_mask
        return {
            "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
            "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
        }

    @property
    def dummy_inputs(self):
        # 返回一个示例的输入字典,包含随机生成的 input_values 和全为1的 attention_mask
        return {
            "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32),
            "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32),
        }

    def __init__(self, config, *inputs, **kwargs):
        # 调用父类的构造方法,并打印警告信息,指出CPU上不支持反向传播操作,需要使用GPU或TPU进行训练/微调
        super().__init__(config, *inputs, **kwargs)
        logger.warning(
            f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
            "to train/fine-tune this model, you need a GPU or a TPU"
        )

    def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
        """
        Computes the output length of the convolutional layers
        """
        # 如果 add_adapter 未提供,则使用配置中的 add_adapter 参数
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 计算卷积层的输出长度
            return tf.math.floordiv(input_length - kernel_size, stride) + 1

        # 对每个卷积核和步长进行迭代,更新输入长度
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        # 如果配置中启用了 adapter layers,则对每个 adapter layer 同样计算输出长度
        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
        )
            # 计算非填充长度,即每个样本序列的实际长度
            non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
            # 获取特征提取器输出的长度,考虑是否添加适配器
            output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
            output_lengths = tf.cast(output_lengths, tf.int32)
            batch_size = tf.shape(attention_mask)[0]
            # 检查设备位置
            attention_mask = tf.zeros(
                (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
            )  # 这两个操作确保输出长度之前的所有位置都被注意到
            ## 检查设备
            attention_mask = tf.tensor_scatter_nd_update(
                attention_mask,
                indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
                updates=tf.ones([batch_size], dtype=attention_mask.dtype),
            )
            attention_mask = tf.reverse(attention_mask, axis=[-1])
            attention_mask = tf.cumsum(attention_mask, axis=-1)
            attention_mask = tf.reverse(attention_mask, axis=[-1])
            attention_mask = tf.cast(attention_mask, tf.bool)
            return attention_mask
"""
This model inherits from `TFPreTrainedModel`. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a `keras.Model` subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0
documentation for all matters related to general usage and behavior.

<Tip>

TensorFlow models and layers in `transformers` accept two formats as input:

- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.

The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:

- a single Tensor with `input_values` only and nothing else: `model(input_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_values": input_values, "token_type_ids": token_type_ids})`

Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!

</Tip>

Args:
    config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the weights associated with the model, only the
        configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV_2_VEC_2_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_values` only and nothing else: `model(input_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_values": input_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

"""
"""

@add_start_docstrings(
    "The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.",
    WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
    """
    The bare TFWav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.

    This class inherits from `TFWav2Vec2PreTrainedModel` and includes additional documentation provided by
    `WAV_2_VEC_2_START_DOCSTRING`.

    Args:
        config (Wav2Vec2Config): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the `PreTrainedModel.from_pretrained` method to load the model weights.
    """

    def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
        """
        Initializes a TFWav2Vec2Model instance.

        Args:
            config (Wav2Vec2Config): Model configuration class with all the parameters of the model.
                Initializing with a config file does not load the weights associated with the model, only the
                configuration. Check out the PreTrainedModel.from_pretrained method to load the model weights.
            *inputs: Additional positional arguments to be passed.
            **kwargs: Additional keyword arguments to be passed.
        """
        super().__init__(config, *inputs, **kwargs)
        self.config = config
        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
    # 将模型的文档字符串添加到前向方法中,用于描述模型输入
    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    # 替换返回值的文档字符串,指定输出类型和配置类
    @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
    # 解包输入参数,使其作为独立参数传递给 call 方法
    @unpack_inputs
    def call(
        self,
        input_values: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """

        Returns:

        Example:

        ```
        >>> from transformers import AutoProcessor, TFWav2Vec2Model
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
        >>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""

        # 设置是否输出隐藏状态,默认使用配置类中的设定
        output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
        # 设置是否输出注意力权重,默认使用配置类中的设定
        output_attentions = output_attentions if output_attentions else self.config.output_attentions
        # 设置是否返回字典形式的输出,默认使用配置类中的设定
        return_dict = return_dict if return_dict else self.config.return_dict

        # 调用 wav2vec2 模型的前向计算方法,传递所有参数
        outputs = self.wav2vec2(
            input_values=input_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 返回模型输出
        return outputs

    def build(self, input_shape=None):
        # 如果已经构建,则直接返回
        if self.built:
            return
        # 标记模型已构建
        self.built = True
        # 如果存在 wav2vec2 模型,使用其名称为命名空间构建模型
        if getattr(self, "wav2vec2", None) is not None:
            with tf.name_scope(self.wav2vec2.name):
                self.wav2vec2.build(None)
@add_start_docstrings(
    """TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
    def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")  # 初始化 TF-Wav2Vec2 主层
        self.dropout = keras.layers.Dropout(config.final_dropout)  # 添加丢弃层,使用给定的丢弃率
        self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head")  # 初始化语言模型头部密集层
        self.output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )  # 设置输出隐藏尺寸为配置中的特定值或者隐藏尺寸

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()  # 警告过时方法,调用等效的特征编码器冻结方法

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor.trainable = False  # 冻结特征编码器,禁止在训练过程中更新其参数

    @unpack_inputs
    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_values: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        """
        Call function to perform forward pass of the model. This function integrates with the `transformers` library's
        `add_start_docstrings_to_model_forward` decorator to provide structured documentation for inputs and outputs.
        """
        # 实现模型的前向传播,结合 `transformers` 库的 `add_start_docstrings_to_model_forward` 装饰器以提供输入和输出的结构化文档

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "wav2vec2", None) is not None:
            with tf.name_scope(self.wav2vec2.name):
                self.wav2vec2.build(None)  # 构建 TF-Wav2Vec2 主层
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build([None, None, self.output_hidden_size])  # 构建语言模型头部密集层
    def __init__(self, config):
        super().__init__(config)
        # 初始化函数,调用父类构造函数,并初始化相关属性
        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
        # 创建一个名为wav2vec2的TFWav2Vec2MainLayer实例,并赋给self.wav2vec2
        self.num_layers = config.num_hidden_layers + 1
        # 设置self.num_layers为config.num_hidden_layers加一
        with tf.name_scope(self._name_scope()):
            # 使用当前对象的命名空间创建一个上下文管理器
            if config.use_weighted_layer_sum:
                # 如果配置中使用加权层求和
                self.layer_weights = self.add_weight(
                    shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
                )
                # 添加名为layer_weights的权重,形状为(self.num_layers,),初始化为全1,可训练
        self.config = config
        # 将配置对象保存在self.config中
        self.projector = keras.layers.Dense(units=config.classifier_proj_size, name="projector")
        # 创建一个全连接层,单元数为config.classifier_proj_size,名为projector
        self.classifier = keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")
        # 创建一个全连接层,单元数为config.num_labels,激活函数为None,名为classifier

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        # 弃用警告:freeze_feature_extractor方法将在Transformers v5中移除,请使用等效的freeze_feature_encoder方法
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        # 发出警告信息,提醒用户方法即将被移除
        self.freeze_feature_encoder()
        # 调用freeze_feature_encoder方法,禁用特征编码器的梯度计算,使其参数在训练过程中不更新

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        # 调用此函数将禁用特征编码器的梯度计算,使其参数在训练过程中不更新
        self.wav2vec2.feature_extractor.trainable = False

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        # 调用此函数将禁用基础模型的梯度计算,使其参数在训练过程中不更新,只有分类头将被更新
        for layer in self.wav2vec2.layers:
            # 遍历self.wav2vec2的所有层
            layer.trainable = False
            # 设置每一层的trainable属性为False,即不可训练状态

    @unpack_inputs
    def call(
        self,
        input_values: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: tf.Tensor | None = None,
        training: bool = False,
    ) -> TFSequenceClassifierOutput | Tuple[tf.Tensor]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果 return_dict 为 None,则使用 self.config.use_return_dict 的值
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
        # 如果 self.config.use_weighted_layer_sum 为 True,则设置 output_hidden_states 为 True

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 调用 self.wav2vec2 模型,传入参数并获取输出

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            # 获取权重层求和后的隐藏状态
            hidden_states = tf.stack(hidden_states, axis=1)
            # 在第二个维度上堆叠隐藏状态
            norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
            # 对权重进行 softmax 归一化
            hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
            # 使用归一化的权重对隐藏状态进行加权求和
        else:
            hidden_states = outputs[0]
            # 否则直接使用模型输出的第一个元素作为隐藏状态

        hidden_states = self.projector(hidden_states)
        # 将隐藏状态投影到指定维度

        if attention_mask is None:
            pooled_output = tf.reduce_mean(hidden_states, axis=1)
            # 如果注意力掩码为 None,则对隐藏状态进行平均池化
        else:
            padding_mask = self._get_feature_vector_attention_mask(shape_list(hidden_states)[1], attention_mask)
            # 获取特征向量注意力掩码
            padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
            # 将掩码转换为浮点类型
            hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
            # 使用掩码进行元素级乘法
            pooled_output = tf.divide(
                tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
            )
            # 使用掩码对隐藏状态进行加权求和并进行平均池化

        logits = self.classifier(pooled_output)
        # 使用分类器对池化输出进行分类预测

        loss = None
        if labels is not None:
            loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            # 使用稀疏分类交叉熵作为损失函数
            loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))
            # 计算损失值

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            # 构建输出元组
            return ((loss,) + output) if loss is not None else output
            # 返回损失和输出元组,如果没有损失则返回输出元组

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # 返回 TFSequenceClassifierOutput 对象,包括损失、预测 logits、隐藏状态和注意力
posted @ 2024-06-30 15:40  绝不原创的飞龙  阅读(92)  评论(0)    收藏  举报