.\models\whisper\generation_whisper.py
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy  # for copying objects
import math  # math utilities
import warnings  # warning handling
import zlib  # data compression (used for the compression-ratio check)
from typing import Callable, Iterator, List, Optional, Tuple, Union  # type hints
import numpy as np  # NumPy for numerical computation
import torch  # PyTorch
import torch.nn.functional as F  # functional neural-network ops
from torch import nn  # PyTorch neural-network modules
from ...generation.configuration_utils import GenerationConfig  # generation configuration class
from ...generation.logits_process import (
LogitsProcessorList,  # container for logits processors
SuppressTokensAtBeginLogitsProcessor,  # suppresses tokens at the start of generation
SuppressTokensLogitsProcessor,  # suppresses a fixed set of tokens
WhisperNoSpeechDetection,  # no-speech (silence) detection
WhisperTimeStampLogitsProcessor,  # timestamp logits processor
)
from ...generation.stopping_criteria import StoppingCriteriaList  # stopping-criteria container
from ...modeling_outputs import BaseModelOutput  # base model output class
from ...utils import logging  # logging utilities
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE  # task IDs and language-code mapping
logger = logging.get_logger(__name__)  # module-level logger
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
"""
Applies a median filter of width `filter_width` along the last dimension of the input.
The `inputs` tensor is assumed to be 3- or 4-dimensional.
"""
if filter_width <= 0 or filter_width % 2 != 1:
raise ValueError("`filter_width` should be an odd number")
pad_width = filter_width // 2
if inputs.shape[-1] <= pad_width:
return inputs
# Pad the left and right edges.
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
return result
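# Illustrative sketch (not part of the library source): `_median_filter` smooths attention
# weights along the frame axis when computing token-level timestamps. The helper below only
# demonstrates the call shape on random toy data.
def _demo_median_filter():
    # toy (batch, heads, tokens, frames) attention weights
    weights = torch.randn(1, 2, 4, 10)
    smoothed = _median_filter(weights, filter_width=7)
    # the filter is shape-preserving: each frame is replaced by the median of its 7-frame window
    assert smoothed.shape == weights.shape
    return smoothed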
def _dynamic_time_warping(matrix: np.ndarray):
"""
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
token-level timestamps.
"""
output_length, input_length = matrix.shape
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, input_length + 1):
for i in range(1, output_length + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = matrix[i - 1, j - 1] + c
trace[i, j] = t
# backtrace
# i and j start at the last row/column index of the trace matrix
i = trace.shape[0] - 1
j = trace.shape[1] - 1
# set every element of the first row of the trace matrix to 2
trace[0, :] = 2
# set every element of the first column of the trace matrix to 1
trace[:, 0] = 1
# two empty lists to collect the indices along the warping path
text_indices = []
time_indices = []
# loop while either i or j is greater than 0
while i > 0 or j > 0:
# append the current i - 1 to text_indices
text_indices.append(i - 1)
# append the current j - 1 to time_indices
time_indices.append(j - 1)
# act according to the value stored in the trace matrix
if trace[i, j] == 0:
# value 0: move diagonally towards the top-left
i -= 1
j -= 1
elif trace[i, j] == 1:
# value 1: move up
i -= 1
elif trace[i, j] == 2:
# value 2: move left
j -= 1
else:
# any other value indicates an internal error
raise RuntimeError(
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
)
# convert the lists to numpy arrays, reverse them so the path runs forward, and return
text_indices = np.array(text_indices)[::-1]
time_indices = np.array(time_indices)[::-1]
return text_indices, time_indices
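# Illustrative sketch (not part of the library source): running the DTW helper on a small toy
# cost matrix of shape (output_length, input_length). Lower values mean a better match; both
# returned index arrays are monotonically non-decreasing and can be zipped into
# (token index, frame index) alignment pairs.
def _demo_dynamic_time_warping():
    toy_cost = np.random.rand(5, 8).astype(np.float32)
    text_indices, time_indices = _dynamic_time_warping(toy_cost)
    return list(zip(text_indices.tolist(), time_indices.tolist()))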
# fetch an attribute from the first processor of type `logit_processor_class` in `logits_processor`
def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
# find the first instance of `logit_processor_class` in the list
logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
# if such a processor exists, return its `attribute_name` attribute, otherwise return None
if logit_processor:
return getattr(logit_processor, attribute_name, None)
return None
# pad the current segment sequences to the maximum length in the batch
def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None):
# track the maximum total length and collect the per-sample sequences
max_total_length = 0
sequences = []
# the padding side must be either "right" or "left"
if padding not in ["right", "left"]:
raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
# iterate over the list of segment lists
for current_segment_list in current_segments:
# if the segment list is non-empty and contains at least one dict with a "tokens" field
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
# concatenate the "tokens" tensors of all segments in the list
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
# if `cut_off_length` is given, keep only the trailing part of the sequence
if cut_off_length is not None:
sequence = sequence[-cut_off_length:]
# if a `bos_token_tensor` is given, prepend it to the sequence
if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])
# store the processed sequence
sequences.append(sequence)
# update the running maximum total length
max_total_length = max(max_total_length, len(sequences[-1]))
# no segments, but a `bos_token_tensor` exists: use it as the sequence
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
# otherwise append an empty tensor
else:
sequences.append(torch.tensor([]))
# pad every sequence so that it matches the maximum total length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
# stack the padded sequences into a single tensor and return it
sequences = torch.stack(sequences, dim=0)
return sequences
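# Illustrative sketch (not part of the library source): left-padding two segment lists of
# unequal length so they can be stacked into one batch tensor, as done when conditioning the
# decoder on previously generated segments.
def _demo_pad_to_max_length():
    segments = [
        [{"tokens": torch.tensor([5, 6, 7])}],
        [{"tokens": torch.tensor([8, 9])}],
    ]
    # result: tensor([[5, 6, 7], [0, 8, 9]]) -- the shorter row is padded on the left with 0
    return _pad_to_max_length(segments, pad_token_id=0, padding="left")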
# WhisperGenerationMixin provides the generation logic for Whisper models
class WhisperGenerationMixin:
# main `generate` entry point with many optional arguments
def generate(
self,
input_features: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: bool = False,
return_timestamps: Optional[bool] = None,
task: Optional[str] = None,
language: Optional[str] = None,
is_multilingual: Optional[bool] = None,
prompt_ids: Optional[torch.Tensor] = None,
prompt_condition_type: Optional[str] = None, # first-segment, all-segments
condition_on_prev_tokens: Optional[bool] = None,
temperature: Optional[Union[float, Tuple[float, ...]]] = None,
compression_ratio_threshold: Optional[float] = None,
logprob_threshold: Optional[float] = None,
no_speech_threshold: Optional[float] = None,
num_segment_frames: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
time_precision: float = 0.02,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
# transcribes the given audio features, handling both short-form and long-form inputs
...
def generate_with_fallback(
self,
segment_input,
decoder_input_ids,
cur_bsz,
batch_idx_map,
seek,
num_segment_frames,
max_frames,
temperatures,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_token_timestamps,
do_condition_on_prev_tokens,
kwargs,
):
# generate text output, falling back to alternative settings on failure
# uses the given inputs to generate the text segments
...
@staticmethod
def _prepare_segments(prompt_ids, batch_size, generation_config):
# prepare the per-sample segment lists used during generation
# if `prompt_ids` are given and the prompt condition type is "first-segment"
if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
# fetch `prev_sot_token_id` if present
prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
# drop the first token of `prompt_ids` if it is the `prev_sot_token_id`
prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
# start every batch entry with a segment holding the prompt tokens
current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
else:
# otherwise start every batch entry with an empty segment list
current_segments = [[] for _ in range(batch_size)]
return current_segments
def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
# post-process the generated outputs
# if `seek_outputs` is a plain tensor, simply cut off the decoder prompt
if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :]
return seek_outputs, seek_outputs
# if token timestamps are requested and the generation config defines `alignment_heads`
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
# fetch the `num_frames` attribute
num_frames = getattr(generation_config, "num_frames", None)
# extract token timestamps and cut off the part covering the decoder prompt
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
# cut the decoder prompt off the output sequences
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
def split_by_batch_index(values, key, batch_idx):
# split `values` for the given `key` by batch index
if key == "scores":
return [v[batch_idx].cpu() for v in values]
elif key == "past_key_values":
# do not save `past_key_values`, it is too costly
return None
elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
# tuples of tensors are split element-wise along the batch dimension
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
return values[batch_idx].cpu()
# split every entry of `seek_outputs` per batch element
sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]
return sequence_tokens, seek_outputs
def _need_fallback(
self,
seek_sequence,
seek_outputs,
index,
logits_processor,
generation_config,
vocab_size,
temperature,
):
# decide whether generation should fall back or the segment should be skipped
# initialize the fallback and skip flags
needs_fallback = False
should_skip = False
# if a compression-ratio threshold is set, compute the compression ratio
if generation_config.compression_ratio_threshold is not None:
compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
# fall back if the compression ratio exceeds the threshold
if compression_ratio > generation_config.compression_ratio_threshold:
needs_fallback = True
# if a log-probability threshold is set, check the average log-probability
if generation_config.logprob_threshold is not None:
if "sequences_scores" in seek_outputs[0]:
logprobs = [s["sequences_scores"] for s in seek_outputs][index]
else:
scores = seek_outputs[index]["scores"]
logprobs = self._retrieve_avg_logprobs(
scores, seek_sequence, generation_config.eos_token_id, temperature
)
# fall back if the average log-probability is below the threshold
if logprobs < generation_config.logprob_threshold:
needs_fallback = True
# if a no-speech threshold is set, check the no-speech probability
if generation_config.no_speech_threshold is not None:
# fetch the no-speech probability from the logits processor
no_speech_prob = _get_attr_from_logit_processors(
logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
)
# if the log-probability is below its threshold and the no-speech probability is above its
# threshold, the segment is treated as silence: no fallback, but the segment is skipped
if (
logprobs < generation_config.logprob_threshold
and no_speech_prob[index] > generation_config.no_speech_threshold
):
needs_fallback = False
should_skip = True
# return the fallback and skip flags
return needs_fallback, should_skip
@staticmethod
def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
# fetch the no-speech detector's `set_inputs` method from the logits processors and pass it the current inputs
set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
@staticmethod
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
# if input features are provided, return the batch size and the total number of input frames
if input_features is not None:
return input_features.shape[0], input_features.shape[-1]
# if encoder outputs are provided, derive the total number of input frames from their shape and the input stride
if "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
# if neither input features nor encoder outputs were provided, raise a ValueError
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
@staticmethod
def _maybe_warn_unused_inputs(
condition_on_prev_tokens,
temperature,
compression_ratio_threshold,
logprob_threshold,
no_speech_threshold,
total_input_frames,
):
# prefix for the warnings below: the audio input is short, so short-form transcription is activated
warning_prefix = (
f"Audio input consists of only {total_input_frames}. "
"Short-form transcription is activated."
"{}, but will be ignored."
)
# warn if `condition_on_prev_tokens` was set, since it will be ignored
if condition_on_prev_tokens is not None:
logger.warn(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
# warn if `compression_ratio_threshold` was set, since it will be ignored
if compression_ratio_threshold is not None:
logger.warn(warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}"))
# warn if `logprob_threshold` was set, since it will be ignored
if logprob_threshold is not None:
logger.warn(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
# warn if `no_speech_threshold` was set, since it will be ignored
if no_speech_threshold is not None:
logger.warn(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
# a list/tuple `temperature` cannot simply be ignored for short-form generation, so raise an error
if isinstance(temperature, (list, tuple)):
raise ValueError(
f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
)
@staticmethod
def _set_return_outputs(
return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config
):
# if `return_dict_in_generate` is None, fall back to the default from the generation config
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
# record whether token timestamps should be returned
generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True
# long-form generation with a `logprob_threshold` also needs the scores to be output
if not is_shortform and logprob_threshold is not None:
return_dict_in_generate = True
generation_config.output_scores = True
# write the resolved setting back into the generation config
generation_config.return_dict_in_generate = return_dict_in_generate
# static method `_set_return_timestamps`: configure whether timestamp tokens are returned
def _set_return_timestamps(return_timestamps, is_shortform, generation_config):
# long-form generation requires the model to predict timestamp tokens
if not is_shortform:
if return_timestamps is False:
# more than 3000 mel input features (> 30 seconds) were passed, which automatically enables
# long-form generation, so `return_timestamps=False` is not allowed
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
# log that timestamps are enabled for long-form generation
logger.info("Setting `return_timestamps=True` for long-form generation.")
return_timestamps = True
# returning timestamps requires `no_timestamps_token_id` to be set on the generation config
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
# store the result on the generation config
generation_config.return_timestamps = return_timestamps
# set the language and task on the generation config; update the multilingual flag if given
def _set_language_and_task(language, task, is_multilingual, generation_config):
# if `is_multilingual` was passed, update the corresponding attribute on the generation config
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
# the generation config is outdated, raise a ValueError
raise ValueError(
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.is_multilingual = is_multilingual
# an English-only model cannot take a `task` or `language` argument
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
if task is not None or language is not None:
raise ValueError(
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
)
# if a language was passed, store it (lower-cased) on the generation config
if language is not None:
if not hasattr(generation_config, "lang_to_id"):
# the generation config is outdated, raise a ValueError
raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
language = language.lower()
generation_config.language = language
# if a task was passed, store it on the generation config
if task is not None:
if not hasattr(generation_config, "task_to_id"):
# the generation config is outdated, raise a ValueError
raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.task = task
# set the special token ids on the generation config, preferring explicitly passed values
def _set_token_ids(generation_config, config, kwargs):
# pop the end-of-sequence token id from the kwargs
eos_token_id = kwargs.pop("eos_token_id", None)
# pop the decoder start token id from the kwargs
decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# use the passed eos token id if given, otherwise the one from the generation config
eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
# use the passed decoder start token id if given, otherwise the one from the generation config
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
)
# write the resolved eos token id back, falling back to the model config if still unset
generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
# write the resolved decoder start token id back, falling back to the model config if still unset
generation_config.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
)
@staticmethod
# set `num_frames` on the generation config when token-level timestamps are requested
def _set_num_frames(return_token_timestamps, generation_config, kwargs):
# only relevant if token-level timestamps are requested
if return_token_timestamps:
# warn when the task is "translate", since token-level timestamps may be unreliable
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
# `alignment_heads` must be present on the generation config
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
# pop `num_frames` from the kwargs (None if absent)
generation_config.num_frames = kwargs.pop("num_frames", None)
@staticmethod
# set the fallback thresholds and previous-token conditioning on the generation config
def _set_thresholds_and_condition(
generation_config,
logprob_threshold,
compression_ratio_threshold,
no_speech_threshold,
condition_on_prev_tokens,
):
# log-probability threshold: prefer the passed value, otherwise keep the existing config value
generation_config.logprob_threshold = (
logprob_threshold
if logprob_threshold is not None
else getattr(generation_config, "logprob_threshold", None)
)
# compression-ratio threshold: prefer the passed value, otherwise keep the existing config value
generation_config.compression_ratio_threshold = (
compression_ratio_threshold
if compression_ratio_threshold is not None
else getattr(generation_config, "compression_ratio_threshold", None)
)
# no-speech threshold: prefer the passed value, otherwise keep the existing config value
generation_config.no_speech_threshold = (
no_speech_threshold
if no_speech_threshold is not None
else getattr(generation_config, "no_speech_threshold", None)
)
# previous-token conditioning: prefer the passed value, otherwise keep the existing config value
generation_config.condition_on_prev_tokens = (
condition_on_prev_tokens
if condition_on_prev_tokens is not None
else getattr(generation_config, "condition_on_prev_tokens", None)
)
# set the prompt condition type on the generation config
def _set_prompt_condition_type(generation_config, prompt_condition_type):
allowed_cond_types = ["first-segment", "all-segments"]
# default to "first-segment" unless another value was given
prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
# validate the chosen prompt condition type
if prompt_condition_type not in allowed_cond_types:
raise ValueError(
f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
)
# "all-segments" conditioning requires `condition_on_prev_tokens=True`
if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
raise ValueError(
"Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'."
)
# store the prompt condition type on the generation config
generation_config.prompt_condition_type = prompt_condition_type
@staticmethod
# resolve whether generation should be conditioned on previously generated tokens
def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
# if not explicitly passed, fall back to the value from the generation config (default False)
condition_on_prev_tokens = (
condition_on_prev_tokens
if condition_on_prev_tokens is not None
else getattr(generation_config, "condition_on_prev_tokens", False)
)
# store the resolved value on the generation config
generation_config.condition_on_prev_tokens = condition_on_prev_tokens
@staticmethod
# compute the maximum number of frames per sample and the initial seek positions
def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames):
# batched long-form transcription requires an attention mask
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
# with a batch, derive each sample's maximum frames from the attention mask and start seeking at zero
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
seek = torch.zeros((batch_size,), dtype=torch.long)
# with a single sample, use the total number of input frames and start seeking at zero
else:
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
seek = torch.zeros((1,), dtype=torch.long)
# return the per-sample frame limits and seek positions
return max_frames, seek
# build the list of logits processors required by the generation config
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
# if timestamps should be returned
if generation_config.return_timestamps is True:
# create the timestamp logits processor
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
# prepend it to the processor list (create the list if necessary)
logits_processor = (
[timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
)
# if tokens should be suppressed
if generation_config.suppress_tokens is not None:
# create the suppress-tokens logits processor
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens)
# prepend it to the processor list (create the list if necessary)
logits_processor = (
[suppress_tokens_processor]
if logits_processor is None
else [suppress_tokens_processor] + logits_processor
)
# clear `suppress_tokens` on the config to avoid applying it twice
generation_config.suppress_tokens = None
# if tokens should be suppressed at the beginning of generation
if generation_config.begin_suppress_tokens is not None:
# create the begin-suppress-tokens logits processor
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index
)
# prepend it to the processor list (create the list if necessary)
logits_processor = (
[begin_suppress_processor]
if logits_processor is None
else [begin_suppress_processor] + logits_processor
)
# clear `begin_suppress_tokens` on the config to avoid applying it twice
generation_config.begin_suppress_tokens = None
# for long-form generation with a no-speech threshold, add the no-speech detector
if generation_config.no_speech_threshold is not None and not is_shortform:
# create the no-speech detector
no_speech_detector = WhisperNoSpeechDetection(
no_speech_token=generation_config.no_timestamps_token_id - 1,
begin_index=begin_index,
scores_is_logprobs=num_beams > 1,
)
# prepend it to the processor list (create the list if necessary)
logits_processor = (
[no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
)
# give the detector a reference to the model
no_speech_detector.set_model(self)
# return the final processor list
return logits_processor
# static method: possibly reduce the batch size by dropping samples that are fully transcribed
@staticmethod
def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
# remember the previous batch size
prev_bsz = cur_bsz
# new mapping from position in the (possibly reduced) batch to the original batch index
new_batch_idx_map = []
# iterate over the samples of the previous batch
for i in range(prev_bsz):
# original batch index of this sample
prev_i = batch_idx_map[i]
# if this sample has already consumed all of its frames
if seek[prev_i] >= max_frames[prev_i]:
# index of the sample to cut out of the current tensors
cut_index = i + (cur_bsz - prev_bsz)
# shrink the batch
cur_bsz -= 1
# drop the finished sample from the input features
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
else:
# keep the index of samples that are not finished yet
new_batch_idx_map.append(prev_i)
# return the reduced input features, the new batch size and the new index map
return input_features, cur_bsz, new_batch_idx_map
# slice the current input segment for every sample in the batch
def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
# collect the per-sample input segments
segment_input = []
# iterate over the samples of the current batch
for i in range(cur_bsz):
# original batch index of this sample
prev_i = batch_idx_map[i]
# slice the segment for this sample from the input features
segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
# if the slice is shorter than the expected segment length
if segment_input_slice.shape[-1] < num_segment_frames:
# pad the last dimension up to the expected segment length of 3000 frames
segment_input_slice = F.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)
# collect the processed slice
segment_input.append(segment_input_slice)
# concatenate all samples along the batch dimension
segment_input = torch.cat(segment_input, dim=0)
# return the batched segment input
return segment_input
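# Note: for Whisper, `num_segment_frames` is 3000 log-mel frames, i.e. 30 seconds of audio at
# a 10 ms hop length, so every segment fed to the encoder has the same fixed width.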
# static method preparing the decoder input ids for the next generation call
@staticmethod
def _prepare_decoder_input_ids(
cur_bsz,
init_tokens,
current_segments,
batch_idx_map,
do_condition_on_prev_tokens,
prompt_ids,
generation_config,
config,
device,
suppress_tokens,
kwargs,
):
# at most half of the maximum target positions (minus one) may come from previous context
cut_off_length = config.max_target_positions // 2 - 1
# helper tensor of ones with shape (cur_bsz, 1) on the target device
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
# build the decoder input ids from the initial tokens
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
# id of the "previous start of text" token; fall back to the second-to-last suppressed token
prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
if prev_start_of_text is None:
prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
# if any sample should be conditioned on its previous tokens and previous segments exist
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
# choose the tokens to prepend depending on the prompt condition type
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
# left-pad the previous tokens to a common length, applying the cut-off length
prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
padding="left",
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
# prepend the padded previous tokens to the decoder input ids
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
# build a decoder attention mask that excludes the padding tokens
kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
elif prompt_ids is not None:
# repeat the prompt tokens for every sample and prepend them to the decoder input ids
prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
# make sure no `decoder_attention_mask` is passed to the forward call
kwargs.pop("decoder_attention_mask", None)
else:
# make sure no `decoder_attention_mask` is passed to the forward call
kwargs.pop("decoder_attention_mask", None)
# return the decoder input ids together with the updated kwargs
return decoder_input_ids, kwargs
def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
# number of initial (context) tokens, capped at half the maximum target positions minus one
num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
# pop `max_length` from the kwargs, if passed
passed_max_length = kwargs.pop("max_length", None)
# pop `max_new_tokens` from the kwargs, if passed
passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
# `max_length` from the generation config, if set
max_length_config = getattr(generation_config, "max_length", None)
# `max_new_tokens` from the generation config, if set
max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
# resolved values, filled below
max_new_tokens = None
max_length = None
# make sure the resolved `max_length` never exceeds `config.max_target_positions`
if passed_max_length is not None and passed_max_new_tokens is None:
# extend `max_length` by the number of initial tokens, capped at `config.max_target_positions`
max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
)
elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
# extend `max_length` by the number of initial tokens, capped at `config.max_target_positions`
max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
)
elif (
passed_max_new_tokens is not None
and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
):
# cap `max_new_tokens` so the total length stays within `config.max_target_positions`
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
elif (
passed_max_new_tokens is None
and max_new_tokens_config is not None
and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
):
# cap `max_new_tokens` so the total length stays within `config.max_target_positions`
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
# write the resolved `max_new_tokens` back into the kwargs, if any
if max_new_tokens is not None:
kwargs["max_new_tokens"] = max_new_tokens
# write the resolved `max_length` back into the kwargs, if any
if max_length is not None:
kwargs["max_length"] = max_length
# return the updated kwargs
return kwargs
# compute the average log-probability of the generated tokens, used to score model outputs
def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
# rescale the scores by the temperature; non-positive temperatures default to 1
rescale_temperature = temperature if temperature > 0.0 else 1
# stack the per-step scores into one tensor on the same device as the tokens
scores = torch.stack(scores).to(tokens.device)
# if there are more score steps than tokens, truncate the scores
if scores.shape[0] > tokens.shape[0]:
scores = scores[: tokens.shape[0]]
else:
# otherwise truncate the tokens to match the number of score steps
tokens = tokens[-scores.shape[0] :]
# log-softmax over the rescaled scores gives per-step log-probabilities
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
# sum the log-probabilities of the selected non-EOS tokens
sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
# number of non-EOS tokens (or the full token count if no EOS token id is defined)
length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
# average log-probability, normalized by the sequence length plus one
avg_logprobs = sum_logprobs / (length + 1)
return avg_logprobs
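# Worked example (illustrative): for a 3-token sequence with per-token log-probabilities
# [-0.5, -1.0, -0.2] and no EOS token among them, this returns
# (-0.5 - 1.0 - 0.2) / (3 + 1) = -0.425, the value later compared against
# `generation_config.logprob_threshold` in `_need_fallback`.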
@staticmethod
def _retrieve_segment(
seek_sequence,
seek_outputs,
time_offset,
timestamp_begin,
seek_num_frames,
time_precision,
input_stride,
prev_idx,
idx,
return_token_timestamps,
.\models\whisper\modeling_flax_whisper.py
# docstring describing the WHISPER model, its base class and supported features
WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.) This model is also a Flax Linen
flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters
related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization
Parameters:
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
WHISPER_INPUTS_DOCSTRING = r"""
"""
WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
is not used. By default the silence in the input log mel spectrogram are ignored.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
WHISPER_DECODE_INPUTS_DOCSTRING = r"""
"""
注释:
Args:
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
# 解码器输入序列的标记索引,对应词汇表中的位置。索引可通过 `WhisperTokenizer` 获取。参见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__` 获取更多细节。
# [decoder input IDs 是什么?](../glossary#decoder-input-ids)
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
encoder_outputs (`tuple(tuple(numpy.ndarray)`):
# 编码器的输出元组,包含 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
# `last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,*可选* 是编码器最后一层的隐藏状态序列。用于解码器的交叉注意力。
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
# 编码器注意力掩码。Whisper 不支持 `input_features` 的屏蔽,该参数为了兼容性而保留,但不被使用。
# 默认情况下,会忽略输入 log mel 频谱图中的静默部分。
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
# 解码器注意力掩码。默认行为:生成一个张量,忽略 `decoder_input_ids` 中的填充标记。
# 默认还使用因果掩码。如需更改填充行为,应根据需求进行修改。参见论文中的图 1 获取更多关于默认策略的信息。
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
# 解码器输入序列中每个标记的位置索引,在位置嵌入中选择的范围为 `[0, config.max_position_embeddings - 1]`。
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
# 预计算的隐藏状态字典,包含用于快速自回归解码的注意力块中的键和值的隐藏状态。预计算的键和值的隐藏状态形状为 *[batch_size, max_length]*。
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
# 是否返回所有注意力层的注意力张量。有关更多详细信息,请参见返回张量中的 `attentions`。
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
# 是否返回所有层的隐藏状态。有关更多详细信息,请参见返回张量中的 `hidden_states`。
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
# 是否返回一个 `~utils.ModelOutput` 而不是简单的元组。
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# A custom Flax module implementing the multi-head attention mechanism used by Whisper.
class FlaxWhisperAttention(nn.Module):
config: WhisperConfig  # model configuration
embed_dim: int  # embedding dimension
num_heads: int  # number of attention heads
dropout: float = 0.0  # dropout rate, defaults to 0.0
causal: bool = False  # whether to use causal (masked) attention, defaults to False
bias: bool = True  # whether the projections use a bias term, defaults to True
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads  # dimension of each attention head
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
# partially applied dense layer sharing the common configuration
dense = partial(
nn.Dense,
self.embed_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj = dense(use_bias=self.bias)  # query projection
self.k_proj = dense(use_bias=False)  # key projection (no bias)
self.v_proj = dense(use_bias=self.bias)  # value projection
self.out_proj = dense(use_bias=self.bias)  # output projection
if self.causal:
# build the causal mask when causal attention is enabled
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool"
)
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
):
"""
Run the attention mechanism.
Args:
hidden_states: input hidden-state tensor
key_value_states: optional key/value states (set when attending over the encoder outputs)
attention_mask: optional attention mask controlling which positions may attend to each other
init_cache: whether to initialize the autoregressive decoding cache
deterministic: whether to run deterministically (disables dropout)
"""
...
def _split_heads(self, hidden_state) -> jnp.ndarray:
"""
Split the hidden-state tensor into multiple attention heads.
Args:
hidden_state: input hidden-state tensor
Returns:
jnp.ndarray: the reshaped tensor with a separate head dimension
"""
return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_state) -> jnp.ndarray:
"""
Merge the attention heads back into a single hidden-state tensor.
Args:
hidden_state: multi-head attention tensor
Returns:
jnp.ndarray: the merged tensor
"""
return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
@nn.compact
"""
def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
# detect whether we are initializing by the absence of existing cache data
is_initialized = self.has_variable("cache", "cached_key")
# fetch or create the cached key, initialized with zeros
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
# fetch or create the cached value, initialized with zeros
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
# fetch or create the cache index, initialized to 0
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
# read off the batch dimensions, maximum length, number of heads and per-head depth
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update the key/value caches with the new 1d spatial slice
cur_index = cache_index.value
# indices of the dynamic update slice
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
# write the new key and value states into the caches
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
# advance the cache index by the number of newly written vectors
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for the cached decoder self-attention:
# the current query positions should only attend to key positions that have already been
# generated and cached, not to the remaining zero-initialized elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
# combine the generated mask with the incoming attention mask
attention_mask = combine_masks(pad_mask, attention_mask)
# return the updated key, value and attention mask
return key, value, attention_mask
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer and adapted to FlaxWhisperEncoderLayer
class FlaxWhisperEncoderLayer(nn.Module):
# WhisperConfig holding the model configuration
config: WhisperConfig
# computation dtype, defaults to float32
dtype: jnp.dtype = jnp.float32
# setup method initializing the sub-modules of the layer
def setup(self) -> None:
# the encoder layer width equals d_model from the config
self.embed_dim = self.config.d_model
# self-attention module
self.self_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
# layer norm applied around the self-attention block
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# dropout layer randomly masking elements to prevent overfitting
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# activation function chosen from the config
self.activation_fn = ACT2FN[self.config.activation_function]
# dropout applied after the activation
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
# first feed-forward layer, weights initialized from a normal distribution
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# second feed-forward layer projecting back to embed_dim
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
# final layer norm
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# forward pass of the encoder layer
def __call__(
self,
hidden_states: jnp.ndarray,  # input hidden states
attention_mask: jnp.ndarray,  # attention mask for invalid positions
output_attentions: bool = True,  # whether to return the attention weights
deterministic: bool = True,  # whether to run deterministically (no dropout)
) -> Tuple[jnp.ndarray]:  # returns a tuple of arrays
# keep the residual for the skip connection
residual = hidden_states
# layer-normalize the input
hidden_states = self.self_attn_layer_norm(hidden_states)
# run self-attention and collect the attention weights
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
# apply dropout
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + hidden_states
# keep the residual for the feed-forward skip connection
residual = hidden_states
# layer-normalize before the feed-forward block
hidden_states = self.final_layer_norm(hidden_states)
# first feed-forward layer followed by the activation
hidden_states = self.activation_fn(self.fc1(hidden_states))
# dropout after the activation
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
# second feed-forward layer
hidden_states = self.fc2(hidden_states)
# final dropout
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + hidden_states
# collect the outputs as a tuple holding the hidden states
outputs = (hidden_states,)
# append the attention weights if requested
if output_attentions:
outputs += (attn_weights,)
# return the outputs
return outputs
# collection holding all encoder layers
class FlaxWhisperEncoderLayerCollection(nn.Module):
# WhisperConfig holding the model configuration
config: WhisperConfig
# computation dtype, defaults to float32
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
# gradient checkpointing flag, defaults to False
gradient_checkpointing: bool = False
# initialize the layers
def setup(self):
# with gradient checkpointing enabled, use rematerialized encoder layers; otherwise regular ones
if self.gradient_checkpointing:
FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
# one rematerialized encoder layer per configured encoder layer
self.layers = [
FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
else:
# one regular encoder layer per configured encoder layer
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
# LayerDrop probability for the encoder layers
self.layerdrop = self.config.encoder_layerdrop
# forward pass over all encoder layers
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# tuple collecting the attention weights of every layer, if requested
all_attentions = () if output_attentions else None
# tuple collecting the hidden states of every layer, if requested
all_hidden_states = () if output_hidden_states else None
# iterate over the encoder layers
for encoder_layer in self.layers:
# record the current hidden states before the layer, if requested
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# LayerDrop: randomly skip the layer with probability `layerdrop` during training
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
# the layer is skipped, so its outputs are None
layer_outputs = (None, None)
else:
# otherwise run the encoder layer
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
# the first element of the layer outputs is the new hidden states
hidden_states = layer_outputs[0]
# record the layer's attention weights, if requested
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# record the final hidden states, if requested
if output_hidden_states:
all_hidden_states += (hidden_states,)
# assemble the outputs: final hidden states, all hidden states and all attentions
outputs = (hidden_states, all_hidden_states, all_attentions)
# without `return_dict`, return a tuple of the non-None outputs
if not return_dict:
return tuple(v for v in outputs if v is not None)
# otherwise return a FlaxBaseModelOutput
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer and renamed to FlaxWhisperDecoderLayer
class FlaxWhisperDecoderLayer(nn.Module):
# WhisperConfig holding the model configuration
config: WhisperConfig
# computation dtype, defaults to float32
dtype: jnp.dtype = jnp.float32
# setup method initializing the sub-modules of the layer
def setup(self) -> None:
# the embedding dimension equals d_model from the config
self.embed_dim = self.config.d_model
# causal self-attention module
self.self_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
# dropout layer applied after self-attention
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# activation function chosen from the config
self.activation_fn = ACT2FN[self.config.activation_function]
# dropout applied after the activation
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
# layer norm applied around the self-attention block
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# encoder-decoder (cross-) attention module
self.encoder_attn = FlaxWhisperAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
# layer norm applied around the cross-attention block
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# first feed-forward layer
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# second feed-forward layer projecting back to the embedding dimension
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
# final layer norm
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# forward pass of the decoder layer
def __call__(
self,
hidden_states: jnp.ndarray,  # input hidden states
attention_mask: jnp.ndarray,  # mask for self-attention
encoder_hidden_states: Optional[jnp.ndarray] = None,  # encoder hidden states (optional)
encoder_attention_mask: Optional[jnp.ndarray] = None,  # encoder attention mask (optional)
init_cache: bool = False,  # whether to initialize the decoding cache
output_attentions: bool = True,  # whether to return the attention weights
deterministic: bool = True,  # whether to run deterministically (no dropout)
) -> Tuple[jnp.ndarray]:
# keep the residual for the self-attention skip connection
residual = hidden_states
# layer-normalize the input
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# run causal self-attention, returning the new hidden states and attention weights
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
# apply dropout
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
# keep the residual for the cross-attention skip connection
residual = hidden_states
# layer-normalize before cross-attention
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# attend over the encoder hidden states, returning new states and cross-attention weights
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
# apply dropout
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + hidden_states
# Fully Connected
# keep the residual for the feed-forward skip connection
residual = hidden_states
# layer-normalize before the feed-forward block
hidden_states = self.final_layer_norm(hidden_states)
# first feed-forward layer followed by the activation
hidden_states = self.activation_fn(self.fc1(hidden_states))
# dropout after the activation
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
# second feed-forward layer
hidden_states = self.fc2(hidden_states)
# final dropout
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + hidden_states
# assemble the outputs
outputs = (hidden_states,)
# append the self- and cross-attention weights if requested
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
# return the outputs
return outputs
# collection holding all decoder layers
class FlaxWhisperDecoderLayerCollection(nn.Module):
# WhisperConfig holding the model configuration
config: WhisperConfig
# computation dtype, defaults to float32
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
# gradient checkpointing flag, defaults to False
gradient_checkpointing: bool = False
# initialize the layer list
def setup(self):
# with gradient checkpointing enabled
if self.gradient_checkpointing:
# create a rematerialized decoder layer class with the static argument positions
FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
# one rematerialized decoder layer per configured decoder layer
self.layers = [
FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
else:
# otherwise one regular decoder layer per configured decoder layer
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
# LayerDrop probability for the decoder layers
self.layerdrop = self.config.decoder_layerdrop
# forward pass over all decoder layers
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# tuple collecting the hidden states of every layer, if requested
all_hidden_states = () if output_hidden_states else None
# tuple collecting the self-attention weights of every layer, if requested
all_self_attns = () if output_attentions else None
# tuple collecting the cross-attention weights, if requested and encoder states are given
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# iterate over the decoder layers
for decoder_layer in self.layers:
# record the current hidden states before the layer, if requested
if output_hidden_states:
all_hidden_states += (hidden_states,)
# LayerDrop (see https://arxiv.org/abs/1909.11556)
# draw a random number between 0 and 1 as the drop probability
dropout_probability = random.uniform(0, 1)
# in non-deterministic mode, skip the layer if the draw falls below the layerdrop rate
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
# otherwise run the decoder layer
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
output_attentions,
deterministic,
)
# the first element of the layer outputs is the new hidden states
hidden_states = layer_outputs[0]
# record the layer's self-attention weights, if requested
if output_attentions:
all_self_attns += (layer_outputs[1],)
# record the layer's cross-attention weights if encoder states are given
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# record the hidden states after the last decoder layer, if requested
if output_hidden_states:
all_hidden_states += (hidden_states,)
# collect all outputs
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
# without `return_dict`, return a tuple of the non-None outputs
if not return_dict:
return tuple(v for v in outputs if v is not None)
# otherwise return a FlaxBaseModelOutputWithPastAndCrossAttentions
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
# the Whisper encoder module
class FlaxWhisperEncoder(nn.Module):
# WhisperConfig holding the model configuration
config: WhisperConfig
# computation dtype, defaults to float32
dtype: jnp.dtype = jnp.float32
# gradient checkpointing flag, defaults to False
gradient_checkpointing: bool = False
# setup method, no return value
def setup(self) -> None:
# first convolution layer
self.conv1 = nn.Conv(
self.config.d_model,  # number of output features is d_model
kernel_size=(3,),  # kernel size 3
padding=1,  # padding of 1
# kernel initialized from a normal distribution with std config.init_std
kernel_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,  # computation dtype
)
# second convolution layer
self.conv2 = nn.Conv(
self.config.d_model,  # number of output features is d_model
kernel_size=(3,),  # kernel size 3
strides=2,  # stride of 2 (halves the sequence length)
padding=1,  # padding of 1
# kernel initialized from a normal distribution with std config.init_std
kernel_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,  # computation dtype
)
# dropout layer with rate config.dropout
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# collection of encoder layers
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,  # encoder configuration
dtype=self.dtype,  # computation dtype
gradient_checkpointing=self.gradient_checkpointing,  # gradient checkpointing flag
)
# position embedding table
self.embed_positions = nn.Embed(
self.config.max_source_positions,  # maximum number of source positions
self.config.d_model,  # embedding dimension d_model
dtype=self.dtype,  # computation dtype
embedding_init=sinusoidal_embedding_init,  # sinusoidal initialization
)
# layer norm applied to the encoder output
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# forward pass of the encoder
def __call__(
self,
input_features: jnp.ndarray,  # input log-mel features
output_attentions: bool = False,  # whether to return attention weights
output_hidden_states: bool = False,  # whether to return all hidden states
return_dict: bool = True,  # whether to return a ModelOutput instead of a tuple
deterministic: bool = True,  # whether to run deterministically (no dropout)
# the return type is a tuple of arrays
) -> Tuple[jnp.ndarray]:
# the trailing dimensions must equal (num_mel_bins, max_source_positions * 2)
if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2):
# otherwise raise an informative error
raise ValueError(
"input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
)
# move the feature dimension last: (batch, frames, num_mel_bins)
input_features = input_features.transpose(0, 2, 1)
# first convolution followed by a GELU non-linearity
hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
# second convolution followed by a GELU non-linearity
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
# sinusoidal position embeddings for positions 0 .. max_source_positions - 1
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
# stop gradients so the position embeddings stay fixed during training
embed_positions = jax.lax.stop_gradient(embed_positions)
# add the position embeddings to the hidden states
hidden_states = hidden_states + embed_positions
# apply dropout (a no-op when deterministic is True)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# run the encoder layers
outputs = self.layers(
hidden_states,
attention_mask=None,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# last hidden states from the layer stack
last_hidden_states = outputs[0]
# apply the final layer norm
last_hidden_states = self.layer_norm(last_hidden_states)
# if hidden states are returned, replace the last entry with the layer-normed states
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
# without `return_dict`, assemble and return a tuple of the non-None outputs
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
# otherwise return a FlaxBaseModelOutput with the last hidden state, hidden states and attentions
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
# the Whisper decoder module
class FlaxWhisperDecoder(nn.Module):
# configuration, computation dtype (default float32) and gradient checkpointing flag (default False)
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
# setup method, no return value
def setup(self) -> None:
# token embedding table of size (vocab_size, d_model)
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
# position embedding table of size (max_target_positions, d_model)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
# collection of decoder layers
self.layers = FlaxWhisperDecoderLayerCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# dropout layer with rate config.dropout
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# layer norm with epsilon 1e-5
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5)
# 定义 __call__ 方法,接受多个输入和返回一个元组的输出
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: jnp.ndarray,
position_ids: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
# 获取输入和位置嵌入
input_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)
# 计算隐藏状态,将输入嵌入和位置嵌入相加
hidden_states = input_embeds + position_embeds
# 对隐藏状态应用 Dropout 层
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 调用解码器层集合 layers 进行解码器的前向传播
outputs = self.layers(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 获取最后的隐藏状态,并应用 LayerNorm 层
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
# 更新输出的 hidden_states 变量,如果需要输出隐藏状态
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
# 如果 return_dict 为 False,则返回一个元组,包含最后的隐藏状态和隐藏状态列表(如果有)
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
# 如果 return_dict 为 True,则返回 FlaxBaseModelOutputWithPastAndCrossAttentions 对象
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
# 定义 FlaxWhisperModule 类,继承自 nn.Module
class FlaxWhisperModule(nn.Module):
# 定义类属性 config,类型为 WhisperConfig,dtype 默认为 jnp.float32,梯度检查点默认为 False
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
# 设置模型的初始化,初始化编码器和解码器
def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.decoder = FlaxWhisperDecoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# 调用模型时执行的方法,将输入特征和解码器的输入传递给编码器和解码器,生成序列到序列的输出
def __call__(
self,
input_features: jnp.ndarray,
decoder_input_ids: jnp.ndarray,
decoder_attention_mask: jnp.ndarray,
decoder_position_ids: jnp.ndarray,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
# 调用编码器进行编码器输出的计算
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 调用解码器进行解码器输出的计算,传入编码器输出作为输入
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 如果不返回字典形式的输出,则将解码器和编码器的输出合并返回
if not return_dict:
return decoder_outputs + encoder_outputs
# 返回字典形式的序列到序列模型输出
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
# 获取编码器模块的方法
def _get_encoder_module(self):
return self.encoder
# 获取解码器模块的方法
def _get_decoder_module(self):
return self.decoder
# 定义一个继承自FlaxPreTrainedModel的类,用于预训练的Whisper模型
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
# 配置类为WhisperConfig
config_class = WhisperConfig
# 基础模型的前缀名为"model"
base_model_prefix: str = "model"
# 主要输入的名称为"input_features"
main_input_name = "input_features"
# 模块类初始化为None
module_class: nn.Module = None
# 初始化函数
def __init__(
self,
config: WhisperConfig,
input_shape: Tuple[int] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
# 使用给定的config、dtype和gradient_checkpointing参数来初始化模块
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
# 如果未提供input_shape,则默认为(1, num_mel_bins, 2 * max_source_positions)
if input_shape is None:
input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
# 调用父类的初始化方法,传递config、module、input_shape等参数
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# 启用梯度检查点的函数
def enable_gradient_checkpointing(self):
# 设置模块的_gradient_checkpointing属性为True
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# 初始化权重的函数
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# 初始化输入张量input_features,全零张量,并设置最后一个位置为eos_token_id
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
# 初始化decoder_input_ids为全零张量,decoder_attention_mask为全1张量
decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 获取decoder_input_ids的批次大小和序列长度,初始化decoder_position_ids
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# 分割随机数生成器rng,获取params_rng和dropout_rng
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# 使用module的init方法初始化随机参数
random_params = self.module.init(
rngs,
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
)["params"]
# 如果提供了params,则将缺失的键补全并返回冻结后的参数
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
# 否则直接返回随机初始化的参数
return random_params
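The merge step at the end of `init_weights` can be reproduced in isolation: both parameter trees are flattened to `{path: array}` dictionaries, any key missing from the user-supplied params is filled from the freshly initialized ones, and the result is re-frozen. A minimal stand-alone sketch with toy dictionaries; the helper names follow `flax.traverse_util` / `flax.core.frozen_dict`, which is presumably where the surrounding file imports them from:

```
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
import jax.numpy as jnp

# Toy "randomly initialized" params and user-provided params with one key missing.
random_params = {"encoder": {"kernel": jnp.zeros((2, 2))}, "lm_head": {"kernel": jnp.ones((2, 2))}}
params = {"encoder": {"kernel": jnp.full((2, 2), 5.0)}}

missing_keys = {("lm_head", "kernel")}  # analogous to self._missing_keys

flat_random = flatten_dict(unfreeze(freeze(random_params)))
flat_params = flatten_dict(unfreeze(freeze(params)))
for key in missing_keys:
    flat_params[key] = flat_random[key]   # copy missing leaves from the random init

merged = freeze(unflatten_dict(flat_params))
print(merged["lm_head"]["kernel"])        # filled from random_params
```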
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper
# Initializes the cache used for fast auto-regressive decoding
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
用于快速自回归解码的批大小。定义了初始化缓存时的批大小。
max_length (`int`):
自回归解码的最大可能长度。定义了初始化缓存时的序列长度。
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` 包含 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
`last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,*可选* 是编码器最后一层的隐藏状态的序列。
在解码器的交叉注意力中使用。
"""
# 初始化解码器输入的标识符,默认为全1矩阵,形状为(batch_size, max_length)
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
# 初始化解码器注意力遮罩,与decoder_input_ids形状相同的全1矩阵
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 初始化解码器位置标识符,广播初始化为解码器输入标识符的长度
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
# 获取解码器模块
decoder_module = module._get_decoder_module()
# 调用解码器模块进行前向传播
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
# 初始化模型参数,用于获取缓存
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0], # 使用编码器的最后隐藏状态
init_cache=True,
method=_decoder_forward, # 只需调用解码器以初始化缓存
)
# 返回解冻后的缓存变量
return unfreeze(init_variables["cache"])
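A hedged usage sketch of this cache: it is created once from the encoder outputs and then passed to `decode` as `past_key_values`, together with explicit `decoder_position_ids`, which the cached path requires. The checkpoint name mirrors the docstring examples below; `max_length=32` is an arbitrary illustration value:

```
import jax.numpy as jnp
from transformers import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

# Encode one dummy utterance (all-zero features, just to exercise the shapes).
input_features = jnp.zeros((1, model.config.num_mel_bins, 2 * model.config.max_source_positions))
encoder_outputs = model.encode(input_features=input_features)

# Build the decoder cache for a batch of 1 and at most 32 generated tokens.
past_key_values = model.init_cache(batch_size=1, max_length=32, encoder_outputs=encoder_outputs)

# Feed the first decoder token; the cached path also expects explicit position ids.
decoder_input_ids = jnp.array([[model.config.decoder_start_token_id]], dtype="i4")
decoder_position_ids = jnp.array([[0]], dtype="i4")
outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    past_key_values=past_key_values,
    decoder_position_ids=decoder_position_ids,
)
```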
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
# 编码方法,用于将输入特征编码为输出
def encode(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
r"""
Returns:
Example:
```
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> encoder_outputs = model.encode(input_features=input_features)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_features, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_features, **kwargs)
return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
"""
Decoder function for the Whisper model. Transforms decoder input into model predictions.
Args:
decoder_input_ids: Input IDs for the decoder.
encoder_outputs: Outputs from the encoder model.
encoder_attention_mask: Mask for encoder attention.
decoder_attention_mask: Mask for decoder attention.
decoder_position_ids: Position IDs for the decoder.
past_key_values: Cached key-value states for fast decoding.
output_attentions: Whether to output attention weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a dictionary of outputs.
train: Whether in training mode.
params: Model parameters to use.
dropout_rng: Random number generator for dropout.
Returns:
Output with past and cross attentions as specified by FlaxBaseModelOutputWithPastAndCrossAttentions.
Example:
```
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> encoder_outputs = model.encode(input_features=input_features)
>>> decoder_inputs = {"input_ids": decoder_input_ids}
>>> outputs = model.decode(**decoder_inputs, encoder_outputs=encoder_outputs)
```
"""
def __call__(
self,
input_features: jnp.ndarray,
decoder_input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
# 设置输出注意力权重的选择,如果未指定则使用默认配置
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 设置输出隐藏状态的选择,如果未指定则使用默认配置
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 设置返回字典的选择,如果未指定则使用默认配置
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 准备解码器输入位置信息
if decoder_position_ids is None:
# 如果解码器位置信息未提供,根据解码器注意力掩码生成位置信息
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
else:
# 否则,根据解码器输入的形状生成默认位置信息
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# 如果解码器注意力掩码未提供,使用全1的掩码
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 如果需要处理任何的随机数生成器
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# 调用模块的应用方法,传递参数和输入数据
return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
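The expression used above, `(decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1`, turns a left-padded attention mask into position ids that start at 0 on the first real token and stay at -1 on padding. A small worked example (toy mask, not taken from the model):

```
import jax.numpy as jnp

# One padded row: two padding positions followed by three real tokens.
decoder_attention_mask = jnp.array([[0, 0, 1, 1, 1]])

decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
print(decoder_position_ids)  # [[-1 -1  0  1  2]]
```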
# 使用装饰器添加文档字符串到类 FlaxWhisperModel 上,描述其作用是生成不带特定头部的原始隐藏状态的 Whisper 模型转换器。
@add_start_docstrings(
"The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
# 定义 FlaxWhisperModel 类,继承自 FlaxWhisperPreTrainedModel
class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
# 配置项,指定为 WhisperConfig 类型
config: WhisperConfig
# 计算中使用的数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32
# 模块类别为 FlaxWhisperModule
module_class = FlaxWhisperModule
# 调用函数 append_call_sample_docstring,添加示例代码的文档字符串到 FlaxWhisperModel 类上
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
# 定义 FlaxWhisperForConditionalGenerationModule 类,继承自 nn.Module
class FlaxWhisperForConditionalGenerationModule(nn.Module):
# 配置项,指定为 WhisperConfig 类型
config: WhisperConfig
# 计算中使用的数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32
# 是否使用梯度检查点,默认为 False
gradient_checkpointing: bool = False
# 初始化函数,设置模型和语言模型头部
def setup(self) -> None:
# 创建 FlaxWhisperModule 实例作为模型
self.model = FlaxWhisperModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# 创建 nn.Dense 实例作为语言模型头部
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# 获取编码器模块
def _get_encoder_module(self):
return self.model.encoder
# 获取解码器模块
def _get_decoder_module(self):
return self.model.decoder
# 定义 __call__ 方法,实现类的可调用功能
def __call__(
self,
input_features,
decoder_input_ids,
decoder_attention_mask: jnp.ndarray = None,
decoder_position_ids: jnp.ndarray = None,
position_ids: jnp.ndarray = None,
attention_mask: jnp.ndarray = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
# 调用模型的 __call__ 方法,传入参数,并接收输出
outputs = self.model(
input_features=input_features,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 获取模型的隐藏状态输出
hidden_states = outputs[0]
# 如果配置中要求共享词嵌入
if self.config.tie_word_embeddings:
# 获取共享的嵌入层参数
shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
# 应用语言模型头部到隐藏状态上,使用共享的嵌入参数
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
# 应用语言模型头部到隐藏状态上
lm_logits = self.lm_head(hidden_states)
# 如果不要求返回字典格式
if not return_dict:
# 组装输出元组
output = (lm_logits,) + outputs[1:]
return output
# 返回 FlaxSeq2SeqLMOutput 类型的输出对象,包括 logits 和可能的其他属性
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
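The `tie_word_embeddings` branch reuses the decoder's token-embedding matrix as the kernel of the LM head instead of keeping a separate `(d_model, vocab_size)` weight. The effect reduces to a bare matrix product; the sketch below uses toy shapes and plain `jax.numpy` to mirror the `lm_head.apply({"params": {"kernel": shared_embedding.T}}, ...)` call, but it is not the module code itself:

```
import jax.numpy as jnp

vocab_size, d_model = 10, 4                      # toy sizes
shared_embedding = jnp.arange(vocab_size * d_model, dtype=jnp.float32).reshape(vocab_size, d_model)
hidden_states = jnp.ones((1, 3, d_model))        # (batch, seq_len, d_model)

# A Dense layer without bias is just `x @ kernel`; tying uses the transposed embedding as kernel.
lm_logits = hidden_states @ shared_embedding.T   # (1, 3, vocab_size)
print(lm_logits.shape)                           # (1, 3, 10)
```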
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
# 使用装饰器为类添加文档字符串,描述它是一个带有语言建模头的 Whisper 模型。
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
# 设置模块类为 FlaxWhisperForConditionalGenerationModule
module_class = FlaxWhisperForConditionalGenerationModule
# 设定数据类型为 jnp.float32
dtype: jnp.dtype = jnp.float32
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
# 使用装饰器添加解码方法的输入文档字符串,描述解码方法的输入参数含义。
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
# 使用装饰器替换返回值的文档字符串,指定输出类型和配置类为 FlaxCausalLMOutputWithCrossAttentions 和 WhisperConfig。
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
# Decode method, responsible for the decoding task
# decoder_input_ids: token IDs fed to the decoder
# encoder_outputs: outputs of the encoder
# encoder_attention_mask: attention mask for the encoder, optional
# decoder_attention_mask: attention mask for the decoder, optional
# decoder_position_ids: position IDs for the decoder, optional
# past_key_values: dictionary holding cached key/value states, optional
# output_attentions: whether to return attention weights, optional
# output_hidden_states: whether to return hidden states, optional
# return_dict: whether to return the result as a dict, optional
# train: whether the model is in training mode, defaults to False
# params: model parameters as a dict, optional
# dropout_rng: random number generator used for dropout, optional
def generate(
self,
input_features,
generation_config=None,
logits_processor=None,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
**kwargs,
):
# Generation method that produces the output token sequence (transcription or translation)
# input_features: log-mel input features extracted from the audio
# generation_config: generation configuration controlling decoding behaviour, optional
# logits_processor: logits processors used to adjust the output logits, optional
# return_timestamps: whether to return timestamp tokens, optional
# task: task type, e.g. "transcribe" or "translate", optional
# language: language of the generated text, optional
# is_multilingual: whether generation is multilingual, optional
# **kwargs: additional keyword arguments
# 如果 generation_config 参数为 None,则使用类中的默认生成配置
if generation_config is None:
generation_config = self.generation_config
# 如果 return_timestamps 参数不为 None,则设置生成配置中的 return_timestamps 属性
if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps
# 如果 task 参数不为 None,则设置生成配置中的 task 属性
if task is not None:
generation_config.task = task
# 如果 is_multilingual 参数不为 None,则设置生成配置中的 is_multilingual 属性
if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual
# 如果 language 参数不为 None,则设置生成配置中的 language 属性
if language is not None:
generation_config.language = language
# 如果 kwargs 参数不为 None 并且包含 "decoder_input_ids" 键,则获取其长度作为 decoder_input_length
# 否则,将 decoder_input_length 设置为 1
if kwargs is not None and "decoder_input_ids" in kwargs:
decoder_input_length = len(kwargs["decoder_input_ids"])
else:
decoder_input_length = 1
# 初始化强制解码器输入列表
forced_decoder_ids = []
# 如果生成配置中具有 "is_multilingual" 属性且为 True,则处理多语言设置
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
# 如果生成配置中具有 "language" 属性,则根据语言映射添加到强制解码器输入列表
if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
else:
# 否则,添加一个空语言 ID
forced_decoder_ids.append((1, None))
# 如果生成配置中具有 "task" 属性,则根据任务映射添加到强制解码器输入列表
if hasattr(generation_config, "task"):
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
# 否则,默认添加一个 "transcribe" 任务 ID
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
# 如果生成配置中具有 "return_timestamps" 属性且为 True,或者 return_timestamps 参数为 True,则配置 logits_processor
if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [
FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
]
else:
# 否则,如果存在强制解码器输入且最后一个元素不等于 no_timestamps_token_id,则添加一个默认的时间戳标记
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# 如果强制解码器输入列表长度大于 0,则将其设置到生成配置中的 forced_decoder_ids 属性
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids
# 调用父类的 generate 方法,生成文本序列
return super().generate(
input_features,
generation_config,
logits_processor=logits_processor,
**kwargs,
)
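Putting the branches above together: for a multilingual checkpoint the method pins the language token at position 1, the task token at position 2, and, when timestamps are disabled, a "no timestamps" token right after. Below is a condensed re-statement of that multilingual branch with purely hypothetical token ids standing in for the checkpoint's `lang_to_id` / `task_to_id` maps (the timestamp logits-processor branch is omitted); it is an illustration, not the model's own code:

```
# Hypothetical stand-ins for fields of a multilingual checkpoint's generation_config.
lang_to_id = {"<|fr|>": 3001}
task_to_id = {"transcribe": 3002, "translate": 3003}
no_timestamps_token_id = 3004

def build_forced_decoder_ids(language=None, task=None, return_timestamps=False):
    """Condensed restatement of the forced-token logic for a multilingual model."""
    forced = [(1, lang_to_id[language] if language is not None else None)]
    forced.append((2, task_to_id[task] if task is not None else task_to_id["transcribe"]))
    if not return_timestamps:
        forced.append((forced[-1][0] + 1, no_timestamps_token_id))
    return forced

print(build_forced_decoder_ids(language="<|fr|>", task="translate"))
# [(1, 3001), (2, 3003), (3, 3004)]
```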
# 准备生成输入的方法,根据给定参数设置解码器输入
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
# initializing the cache
# 解码器输入的批量大小和序列长度
batch_size, seq_length = decoder_input_ids.shape
# 使用 self.init_cache 方法初始化过去的键值对
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# 注意:通常需要将注意力掩码中超出 input_ids.shape[-1] 和小于 cache_length 的位置设置为 0,
# 但由于解码器使用因果掩码,这些位置已经被掩码处理。
# 因此,我们可以在此处创建一个静态的注意力掩码,对编译更有效。
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
# 如果存在解码器的注意力掩码,则计算位置 ids
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(-1) - 1
# 使用 lax.dynamic_update_slice 将 decoder_attention_mask 更新到 extended_attention_mask 中
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
# 否则,使用广播方式创建位置 ids
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
# 返回更新后的字典,包括过去的键值对、编码器输出、编码器注意力掩码、解码器注意力掩码和位置 ids
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
# 将模型输出中的过去的键值对更新到模型参数中
model_kwargs["past_key_values"] = model_outputs.past_key_values
# 更新解码器位置 ids 以便生成下一个词
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
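Between decoding steps only two things change: the cache returned by the model replaces `past_key_values`, and the single-column `decoder_position_ids` is advanced by one. A toy illustration of the position update, independent of the model:

```
import jax.numpy as jnp

decoder_position_ids = jnp.array([[0], [0]])          # step 0 for a batch of 2
for step in range(3):
    # ... one decode step would run here ...
    decoder_position_ids = decoder_position_ids[:, -1:] + 1
    print(step, decoder_position_ids.tolist())
# 0 [[1], [1]]
# 1 [[2], [2]]
# 2 [[3], [3]]
```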
# 覆盖并修改 FlaxWhisperForConditionalGeneration 类的文档字符串,添加了条件生成的例子和返回说明
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
Returns:
Transcription example:
```
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(input_ids=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
# 调用函数覆盖并修改 FlaxWhisperForConditionalGeneration 类的文档字符串
overwrite_call_docstring(
FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
)
# 追加并替换 FlaxWhisperForConditionalGeneration 类的返回文档字符串
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxWhisperForAudioClassificationModule(nn.Module):
# 定义 FlaxWhisperForAudioClassificationModule 类
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self) -> None:
# 设置函数,初始化模型组件
self.encoder = FlaxWhisperEncoder(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# 设置编码器组件
self.config.is_encoder_decoder = False
# 确定非编码器解码器模式
num_layers = self.config.num_hidden_layers + 1
# 计算层数
if self.config.use_weighted_layer_sum:
# 如果使用加权层求和
self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
# 设置层权重
self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
# 定义分类投影器
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
# 定义分类器
def __call__(
self,
input_features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states: bool = True,
return_dict: bool = True,
# 重载调用操作,接受输入特征等参数,返回字典格式结果
):
# 如果未指定输出注意力的设置,则使用配置中的默认值
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 如果未指定输出隐藏状态的设置,则使用配置中的默认值
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 如果未指定返回字典的设置,则使用配置中的默认值
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 如果编码器输出为空,则调用编码器进行编码
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 如果配置中使用加权层求和
if self.config.use_weighted_layer_sum:
# 将编码器输出堆叠起来形成张量
hidden_states = jnp.stack(encoder_outputs, axis=1)
# 对层权重进行 softmax 归一化
norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
# 加权求和后的隐藏状态
hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
# 否则直接使用编码器的第一个输出作为隐藏状态
hidden_states = encoder_outputs[0]
# 将隐藏状态投影到新的空间
hidden_states = self.projector(hidden_states)
# 对隐藏状态进行平均池化,生成池化输出
pooled_output = jnp.mean(hidden_states, axis=1)
# 将池化输出送入分类器得到预测 logits
logits = self.classifier(pooled_output)
# 如果不需要返回字典形式的输出
if not return_dict:
# 返回 logits 和编码器的其他输出状态
return (logits,) + encoder_outputs[1:]
# 否则以 FlaxSequenceClassifierOutput 的形式返回结果
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
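With `use_weighted_layer_sum` enabled, the classifier consumes a learned softmax-weighted combination of the per-layer encoder outputs instead of only the last layer. A toy sketch of that reduction with random stand-in hidden states (the real ones come from `encoder_outputs`):

```
import jax
import jax.numpy as jnp

num_layers, batch, seq_len, hidden = 3, 1, 5, 4       # toy sizes
key = jax.random.PRNGKey(0)
# Stack of per-layer hidden states, shape (batch, num_layers, seq_len, hidden).
hidden_states = jax.random.normal(key, (batch, num_layers, seq_len, hidden))

layer_weights = jnp.repeat(1 / num_layers, num_layers)     # uniform init, as in setup()
norm_weights = jax.nn.softmax(layer_weights, axis=-1)      # normalized mixing weights
mixed = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
print(mixed.shape)                                         # (1, 5, 4)
```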
@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForAudioClassificationModule # 设置模型的模块类为FlaxWhisperForAudioClassificationModule
dtype: jnp.dtype = jnp.float32 # 设置数据类型为32位浮点数
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# 初始化输入张量
input_features = jnp.zeros(input_shape, dtype="f4") # 创建全零的输入特征张量,数据类型为32位浮点数
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) # 将输入张量的最后一个位置设置为配置中的eos_token_id
params_rng, dropout_rng = jax.random.split(rng) # 使用随机数生成器rng分割得到参数随机数生成器和dropout随机数生成器
rngs = {"params": params_rng, "dropout": dropout_rng} # 创建随机数生成器字典
random_params = self.module.init( # 使用模块的初始化方法初始化随机参数
rngs,
input_features=input_features,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params)) # 展开和解冻随机参数
params = flatten_dict(unfreeze(params)) # 展开和解冻给定参数
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key] # 将缺失的键添加到参数字典中
self._missing_keys = set() # 清空缺失键集合
return freeze(unflatten_dict(params)) # 冻结并返回重构的参数字典
else:
return random_params # 返回随机初始化的参数
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # 根据参数设置是否输出注意力权重
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # 根据参数设置是否输出隐藏状态
)
return_dict = return_dict if return_dict is not None else self.config.return_dict # 根据参数设置是否返回字典形式的结果
# 如果需要处理任何随机数生成器
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng # 将dropout随机数生成器添加到随机数生成器字典中
return self.module.apply( # 使用模块的应用方法进行前向传播
{"params": params or self.params}, # 使用给定参数或默认参数
input_features=jnp.array(input_features, dtype="f4"), # 将输入特征转换为32位浮点数的JAX数组
output_attentions=output_attentions, # 输出注意力权重
output_hidden_states=output_hidden_states, # 输出隐藏状态
return_dict=return_dict, # 返回字典形式的结果
rngs=rngs, # 随机数生成器字典
)
# Usage example for audio classification (this is the example carried by the
# FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING constant used further below)
# Load the validation split of the "google/fleurs" dataset in streaming mode
ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
# Fetch the next sample from the dataset
sample = next(iter(ds))
# Extract features from the audio sample with the feature extractor, returned as NumPy tensors
inputs = feature_extractor(
sample["audio"]["array"],  # audio data array
sampling_rate=sample["audio"]["sampling_rate"],  # audio sampling rate
return_tensors="np"  # return NumPy-format tensors
)
# Get the input features
input_features = inputs.input_features
# Run the model on the input features to obtain the predicted logits (unnormalized scores)
logits = model(input_features).logits
# Take the argmax of the logits to get the predicted class id
predicted_class_ids = jnp.argmax(logits).item()
# Map the predicted class id to its label name via the model config's id2label mapping
predicted_label = model.config.id2label[predicted_class_ids]
# The predicted label name
predicted_label
"""
Call `overwrite_call_docstring` to overwrite the call docstring of the specified class with the given docstring constants.
"""
overwrite_call_docstring(
FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
"""
Call `append_replace_return_docstrings` to append and replace the return-value documentation in the specified class's docstring.
"""
append_replace_return_docstrings(
FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)
.\models\whisper\modeling_tf_whisper.py
# coding=utf-8
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow Whisper model."""
from __future__ import annotations
import math
import random
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
# 导入自定义模块
from ...activations_tf import get_tf_activation
from ...generation.configuration_utils import GenerationConfig
from ...generation.tf_logits_process import TFLogitsProcessorList
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
# 获取日志记录器
logger = logging.get_logger(__name__)
# 用于文档的配置信息
_CONFIG_FOR_DOC = "WhisperConfig"
# 预训练模型的存档列表
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai/whisper-base",
# 查看所有 Whisper 模型:https://huggingface.co/models?filter=whisper
]
# 定义一个大负数常量
LARGE_NEGATIVE = -1e8
def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
"""Returns sinusoids for positional embedding"""
# 解构形状元组
length, channels = shape
# 如果通道数不能被2整除,则抛出异常
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
# 计算时间尺度增量的对数
log_timescale_increment = math.log(10000) / (channels // 2 - 1)
# 计算时间尺度的倒数
inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
# 缩放时间
scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
# 合并正弦和余弦的时间编码
return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
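Interleaving `sin` and `cos` along the channel axis gives each position a unique, smoothly varying signature. A quick shape and value check of the initializer defined above, using a small toy table (not part of the original file):

```
import tensorflow as tf

# Toy positional table: 6 positions, 4 channels (the channel count must be even).
table = sinusoidal_embedding_init((6, 4))
print(table.shape)          # (6, 4)
print(table[0].numpy())     # position 0 -> [sin(0), sin(0), cos(0), cos(0)] = [0., 0., 1., 1.]
```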
# 从 transformers.models.bart.modeling_tf_bart.shift_tokens_right 复制的函数
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
# 将 pad_token_id 和 decoder_start_token_id 转换为与 input_ids 相同的数据类型
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
# 创建起始标记的张量,用于decoder输入的起始
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
# 将输入的ids向左移动一个位置,用于生成decoder输入序列
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# 将labels中可能的-100值替换为`pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# 断言`shifted_input_ids`中的值大于等于0或者为-100
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# 确保断言操作被调用,通过在结果外包装一个identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
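A concrete view of the shift: labels `[a, b, c]` become decoder inputs `[start, a, b]`, and any `-100` (the ignore index used for loss masking) is replaced by the pad token. The ids below are toy values; the real `decoder_start_token_id` and `pad_token_id` come from the config:

```
import tensorflow as tf

labels = tf.constant([[5, 6, -100]], dtype=tf.int32)
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1)
print(shifted.numpy())   # [[1 5 6]]
```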
# 从transformers库中复制的函数,用于生成用于自注意力的因果遮罩
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
# 获取批量大小
bsz = input_ids_shape[0]
# 获取目标序列长度
tgt_len = input_ids_shape[1]
# 创建一个形状为(tgt_len, tgt_len)的矩阵,并用大负数初始化
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
# 创建一个与tgt_len长度相等的序列
mask_cond = tf.range(shape_list(mask)[-1])
# 将 mask 中小于 mask_cond + 1 的位置置为0
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
# 如果过去键值长度大于0,则在mask的左侧连接一个形状为(tgt_len, past_key_values_length)的零矩阵
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
# 返回形状为(bsz, 1, tgt_len, tgt_len)的mask矩阵
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
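For a target length of 3 the mask keeps the lower triangle at 0 and pushes future positions to the large negative constant, which effectively zeroes their attention weights after the softmax. A toy evaluation using the helper defined above (illustrative only):

```
import tensorflow as tf

mask = _make_causal_mask(tf.TensorShape([1, 3]))     # batch 1, target length 3
print(mask.shape)                                     # (1, 1, 3, 3)
print(mask[0, 0].numpy())
# [[ 0.e+00 -1.e+08 -1.e+08]
#  [ 0.e+00  0.e+00 -1.e+08]
#  [ 0.e+00  0.e+00  0.e+00]]
```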
# 从transformers库中复制的函数,用于将注意力遮罩从[bsz, seq_len]扩展到[bsz, 1, tgt_seq_len, src_seq_len]
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
# 获取输入mask的源序列长度
src_len = shape_list(mask)[1]
# 如果未提供tgt_len,则使用源序列长度
tgt_len = tgt_len if tgt_len is not None else src_len
# 创建一个常数张量,值为1.0
one_cst = tf.constant(1.0)
# 将mask转换为与one_cst相同的数据类型
mask = tf.cast(mask, dtype=one_cst.dtype)
# 在第二维上将mask扩展为形状为[bsz, 1, tgt_len, src_len]的张量
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
# 返回形状为[bsz, 1, tgt_len, src_len]的扩展后的遮罩,其中未覆盖区域的值乘以一个大负数
return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFWhisperPositionalEmbedding(keras.layers.Layer):
def __init__(
self,
num_positions: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
embedding_initializer=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_positions = num_positions
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embedding_initializer = keras.initializers.get(embedding_initializer)
def build(self, input_shape):
# 添加名为'weight'的权重,形状为[num_positions, embedding_dim],由embedding_initializer初始化
self.weight = self.add_weight(
name="weight",
shape=[self.num_positions, self.embedding_dim],
initializer=self.embedding_initializer,
trainable=True,
)
super().build(input_shape)
def call(self, input_ids, past_key_values_length=0):
# 将past_key_values_length转换为tf.int32类型
past_key_values_length = tf.cast(past_key_values_length, tf.int32)
# 创建一个序列,从past_key_values_length开始,步长为1,长度为input_ids的第二个维度的长度
gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
# 返回根据gather_indices从self.weight中收集的张量
return tf.gather(self.weight, gather_indices)
class TFWhisperAttention(keras.layers.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
# 初始化用于处理输入的线性层,不使用偏置
self.k_proj = keras.layers.Dense(embed_dim, use_bias=False, name="k_proj")
self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
# 从 transformers.models.bart.modeling_tf_bart.TFBartAttention._shape 复制而来,用于整形张量
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
# 从 transformers.models.bart.modeling_tf_bart.TFBartAttention.call 复制而来,用于执行注意力计算
def call(
self,
hidden_states: tf.Tensor,
key_value_states: tf.Tensor | None = None,
past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
attention_mask: tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
training: Optional[bool] = False,
):
def build(self, input_shape=None):
# If the layer has already been built, return immediately
if self.built:
return
self.built = True
# 构建各个线性层
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# 从transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextEncoderLayer复制并修改为Whisper
class TFWhisperEncoderLayer(keras.layers.Layer):
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(**kwargs)
# 初始化层参数
self.embed_dim = config.d_model # 设置嵌入维度为config中的d_model
# 创建自注意力层对象,使用Whisper的注意力头数和dropout参数
self.self_attn = TFWhisperAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
)
# 创建LayerNormalization层,用于自注意力层的归一化
self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
# 创建dropout层,用于全连接层之前的随机失活
self.dropout = keras.layers.Dropout(config.dropout)
# 获取激活函数
self.activation_fn = get_tf_activation(config.activation_function)
# 创建dropout层,用于激活函数之后的随机失活
self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
# 创建全连接层1,输入维度为config中的encoder_ffn_dim,输出维度保持一致
self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
# 创建全连接层2,输入和输出维度都为嵌入维度
self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
# 创建LayerNormalization层,用于最终层的归一化
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
# 保存配置信息
self.config = config
def call(
self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False
):
"""
Args:
hidden_states (`tf.Tensor`): 输入到层的张量,形状为`(batch, seq_len, embed_dim)`
attention_mask (`tf.Tensor`): 注意力遮罩,形状为`(batch, 1, tgt_len, src_len)`,用极大的负值表示填充元素
layer_head_mask (`tf.Tensor`): 给定层中注意力头的遮罩,形状为`(encoder_attention_heads,)`
training (bool): 是否处于训练模式
"""
residual = hidden_states # 保存输入张量作为残差连接的起始点
hidden_states = self.self_attn_layer_norm(hidden_states) # 对输入进行自注意力归一化处理
# 调用自注意力层,获取输出张量、注意力权重和未使用的附加信息
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
training=training,
)
# 断言自注意力层没有改变查询张量的形状
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training) # 应用dropout到自注意力层输出
hidden_states = residual + hidden_states # 执行残差连接
residual = hidden_states # 更新残差连接的起始点
hidden_states = self.final_layer_norm(hidden_states) # 对输出进行最终归一化处理
hidden_states = self.activation_fn(self.fc1(hidden_states)) # 应用激活函数到第一个全连接层
hidden_states = self.activation_dropout(hidden_states, training=training) # 应用dropout到激活函数输出
hidden_states = self.fc2(hidden_states) # 应用第二个全连接层
hidden_states = self.dropout(hidden_states, training=training) # 应用dropout到第二个全连接层输出
hidden_states = residual + hidden_states # 执行最终的残差连接
return hidden_states, self_attn_weights # 返回处理后的张量和自注意力权重
# 在构建网络层之前,检查是否已经构建过,如果已构建则直接返回,避免重复构建
def build(self, input_shape=None):
if self.built:
return
# 设置标志位,表示网络已经构建
self.built = True
# 如果存在 self_attn 属性,则构建 self attention 层
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
# 调用 self attention 层的 build 方法,传入 None 作为输入形状
self.self_attn.build(None)
# 如果存在 self_attn_layer_norm 属性,则构建 layer normalization 层
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
# 调用 layer normalization 层的 build 方法,传入形状为 [None, None, self.embed_dim]
self.self_attn_layer_norm.build([None, None, self.embed_dim])
# 如果存在 fc1 属性,则构建第一个全连接层
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
# 调用第一个全连接层的 build 方法,传入形状为 [None, None, self.embed_dim]
self.fc1.build([None, None, self.embed_dim])
# 如果存在 fc2 属性,则构建第二个全连接层
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
# 调用第二个全连接层的 build 方法,传入形状为 [None, None, self.config.encoder_ffn_dim]
self.fc2.build([None, None, self.config.encoder_ffn_dim])
# 如果存在 final_layer_norm 属性,则构建最终的 layer normalization 层
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
# 调用最终 layer normalization 层的 build 方法,传入形状为 [None, None, self.embed_dim]
self.final_layer_norm.build([None, None, self.embed_dim])
# 从transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextDecoderLayer复制而来,更名为TFWhisperDecoderLayer
class TFWhisperDecoderLayer(keras.layers.Layer):
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model # 设置嵌入维度为config中的d_model值
# 创建自注意力层,用于处理decoder自身的注意力机制
self.self_attn = TFWhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
name="self_attn",
is_decoder=True,
)
self.dropout = keras.layers.Dropout(config.dropout) # dropout层,用于模型训练过程中的随机失活
self.activation_fn = get_tf_activation(config.activation_function) # 激活函数,根据config中的激活函数类型获取对应的激活函数
self.activation_dropout = keras.layers.Dropout(config.activation_dropout) # 激活函数后的dropout层
self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") # 自注意力层的归一化层
# 创建与encoder交互的注意力层,用于decoder与encoder交互信息
self.encoder_attn = TFWhisperAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
name="encoder_attn",
is_decoder=True,
)
self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") # 与encoder交互注意力层的归一化层
self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") # 全连接层1
self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") # 全连接层2,输出维度与嵌入维度相同
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") # 最终的归一化层
self.config = config # 保存配置信息
def call(
self,
hidden_states,
attention_mask: tf.Tensor | None = None,
encoder_hidden_states: tf.Tensor | None = None,
encoder_attention_mask: tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
cross_attn_layer_head_mask: tf.Tensor | None = None,
past_key_value: Tuple[tf.Tensor] | None = None,
training=False,
):
def build(self, input_shape=None):
# If the layer has already been built, return immediately without rebuilding
if self.built:
return
# 设置标志位,表示模型已经建立
self.built = True
# 如果存在自注意力层,则构建自注意力层
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
# 如果存在自注意力层归一化层,则构建该层,输入形状为[None, None, self.embed_dim]
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
# 如果存在编码器注意力层,则构建编码器注意力层
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
# 如果存在编码器注意力层归一化层,则构建该层,输入形状为[None, None, self.embed_dim]
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
# 如果存在第一个全连接层,则构建该层,输入形状为[None, None, self.embed_dim]
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
# 如果存在第二个全连接层,则构建该层,输入形状为[None, None, self.config.decoder_ffn_dim]
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
# 如果存在最终归一化层,则构建该层,输入形状为[None, None, self.embed_dim]
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFWhisperPreTrainedModel(TFPreTrainedModel):
# 指定配置类为WhisperConfig,用于配置模型参数
config_class = WhisperConfig
# 模型基础名称前缀为"model"
base_model_prefix = "model"
# 主要输入名称为"input_features"
main_input_name = "input_features"
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor) -> int:
"""
计算卷积层的输出长度
"""
# 根据公式计算卷积层的输出长度
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
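The stride-2 second convolution is the only layer that changes the time resolution, so the helper applies a single `(n - 1) // 2 + 1` step; 3000 mel frames (the 30-second window produced by the standard Whisper feature extractor, stated here as an assumption) map to 1500 encoder positions. A quick check with toy and typical values:

```
import tensorflow as tf

input_lengths = tf.constant([3000, 200, 7])
output_lengths = (input_lengths - 1) // 2 + 1
print(output_lengths.numpy())   # [1500  100    4]
```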
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
构建网络所需的虚拟输入
Returns:
`Dict[str, tf.Tensor]`: 虚拟输入字典
"""
return {
# 创建形状为[1, num_mel_bins, max_source_positions * 2 - 1]的均匀分布随机张量
self.main_input_name: tf.random.uniform(
[1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
),
# 固定形状为[[1, 3]]的整数张量作为decoder的输入id
"decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
}
@property
def input_signature(self):
# 定义输入签名,指定输入张量的形状和数据类型
return {
"input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}
WHISPER_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
Parameters:
config ([`WhisperConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
WHISPER_INPUTS_DOCSTRING = r"""
"""
@keras_serializable
class TFWhisperEncoder(keras.layers.Layer):
# 指定配置类为WhisperConfig,用于配置编码器参数
config_class = WhisperConfig
"""
Transformer编码器,包含config.encoder_layers个自注意力层。每一层是一个[`TFWhisperEncoderLayer`].
Args:
config: WhisperConfig
embed_tokens (TFWhisperEmbedding): 输出嵌入
"""
# 初始化方法,接收一个WhisperConfig对象和其他关键字参数
def __init__(self, config: WhisperConfig, **kwargs):
# 调用父类的初始化方法
super().__init__(**kwargs)
# 将传入的config对象保存到self.config中
self.config = config
# 从config对象中获取encoder_layerdrop属性并保存到self.layerdrop中
self.layerdrop = config.encoder_layerdrop
# 从config对象中获取d_model属性并保存到self.embed_dim中
self.embed_dim = config.d_model
# 从config对象中获取num_mel_bins属性并保存到self.num_mel_bins中
self.num_mel_bins = config.num_mel_bins
# 从config对象中获取pad_token_id属性并保存到self.padding_idx中
self.padding_idx = config.pad_token_id
# 从config对象中获取max_source_positions属性并保存到self.max_source_positions中
self.max_source_positions = config.max_source_positions
# 如果config对象中的scale_embedding为True,则计算并保存self.embed_scale为self.embed_dim的平方根,否则为1.0
self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0
# 在call()方法中添加填充以匹配PyTorch实现
# 创建第一个卷积层,设置卷积核大小为3,步长为1,padding方式为"valid",并命名为"conv1"
self.conv1 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1")
# 创建第二个卷积层,设置卷积核大小为3,步长为2,padding方式为"valid",并命名为"conv2"
self.conv2 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
# 创建位置嵌入层TFWhisperPositionalEmbedding对象,设置位置数量为self.max_source_positions,嵌入维度为self.embed_dim,
# 使用sinusoidal_embedding_init作为初始化方法,并命名为"embed_positions"
self.embed_positions = TFWhisperPositionalEmbedding(
num_positions=self.max_source_positions,
embedding_dim=self.embed_dim,
embedding_initializer=sinusoidal_embedding_init,
name="embed_positions",
)
# 设置位置嵌入层为不可训练状态
self.embed_positions.trainable = False
# 创建编码器层列表,包含config.encoder_layers个TFWhisperEncoderLayer对象,每个对象命名为"layers.{i}",其中i为层的索引
self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
# 创建LayerNormalization层,设置epsilon为1e-5,并命名为"layer_norm"
self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
# 创建Dropout层,设置dropout率为config.dropout
self.dropout = keras.layers.Dropout(config.dropout)
# 解包输入参数的装饰器函数,定义在call()方法上
@unpack_inputs
# 定义call()方法,接收多个参数,用于模型的前向传播
def call(
self,
input_features=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
@keras_serializable
class TFWhisperDecoder(keras.layers.Layer):
config_class = WhisperConfig
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFWhisperDecoderLayer`]
Args:
config: WhisperConfig
"""
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = keras.layers.Dropout(config.dropout) # 初始化一个丢弃层,使用配置中的丢弃率
self.layerdrop = config.decoder_layerdrop # 设置层级丢弃率
self.padding_idx = config.pad_token_id # 设置填充标记索引
self.max_target_positions = config.max_target_positions # 最大目标位置数
self.max_source_positions = config.max_source_positions # 最大源位置数
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 # 如果配置中启用了嵌入缩放,则计算嵌入缩放值,否则为1.0
self.embed_tokens = keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="embed_tokens",
) # 初始化嵌入层,用于将输入标记映射到向量空间
self.embed_positions = TFWhisperPositionalEmbedding(
self.max_target_positions, config.d_model, name="embed_positions"
) # 初始化位置编码器,用于为输入位置信息生成嵌入向量
self.decoder_layers = [TFWhisperDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
# 初始化多层解码器层,每一层是一个 TFWhisperDecoderLayer 对象,索引命名为 layers.{i}
self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") # 初始化层归一化层
def get_input_embeddings(self):
return self.embed_tokens # 返回输入嵌入层对象
def set_input_embeddings(self, value):
self.embed_tokens = value # 设置输入嵌入层对象为指定值
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
batch_size, seq_len = input_shape[0], input_shape[1]
combined_attention_mask = tf.cond(
tf.math.greater(seq_len, 1),
lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
) # 根据输入形状和过去键值长度生成解码器注意力掩码
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
) # 如果存在输入的注意力掩码,则扩展和组合注意力掩码
return combined_attention_mask
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
# Build the model; if it has already been built, return immediately
def build(self, input_shape=None):
if self.built:
return
# 设置标志为已构建
self.built = True
# 如果存在嵌入词向量(embed_tokens)属性,则构建它
if getattr(self, "embed_tokens", None) is not None:
with tf.name_scope(self.embed_tokens.name):
self.embed_tokens.build(None)
# 如果存在嵌入位置信息(embed_positions)属性,则构建它
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
# 如果存在层归一化(layer_norm)属性,则构建它
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
# 构建层归一化,传入的形状为 [None, None, self.config.d_model]
self.layer_norm.build([None, None, self.config.d_model])
# 如果存在解码器层(decoder_layers)属性,则依次构建每一层
if getattr(self, "decoder_layers", None) is not None:
for layer in self.decoder_layers:
with tf.name_scope(layer.name):
# 构建当前解码器层,传入的形状为 None(未指定具体输入形状)
layer.build(None)
# 添加模型的文档字符串,描述该层的输出是裸的隐藏状态,没有特定的头部信息
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
# 使该类可以序列化为Keras模型
@keras_serializable
class TFWhisperMainLayer(keras.layers.Layer):
# 指定配置类为WhisperConfig
config_class = WhisperConfig
# 初始化方法,接受WhisperConfig对象作为参数
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
# 创建Whisper编码器对象
self.encoder = TFWhisperEncoder(config, name="encoder")
# 创建Whisper解码器对象
self.decoder = TFWhisperDecoder(config, name="decoder")
# 返回解码器的嵌入层
def get_input_embeddings(self):
return self.decoder.embed_tokens
# 设置解码器的嵌入层
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
# 返回编码器对象
def get_encoder(self):
return self.encoder
# 返回解码器对象
def get_decoder(self):
return self.decoder
# 模型前向传播方法,处理输入特征和解码器的各种输入及其掩码
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_features=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
# 方法内部建立模型结构,确保只建立一次
def build(self, input_shape=None):
if self.built:
return
self.built = True
# 如果存在编码器对象,则在名称作用域内建立编码器
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
# 如果存在解码器对象,则在名称作用域内建立解码器
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
# 添加模型的文档字符串,描述该模型输出裸的隐藏状态,没有特定的头部信息
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
# TFWhisperModel继承自TFWhisperPreTrainedModel类
class TFWhisperModel(TFWhisperPreTrainedModel):
# 初始化方法,接受WhisperConfig对象作为参数
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(config, **kwargs)
# 创建TFWhisperMainLayer模型对象作为该模型的一部分
self.model = TFWhisperMainLayer(config, name="model")
# 返回解码器的嵌入层
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
# 设置解码器的嵌入层
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
# 返回模型的编码器对象
def get_encoder(self):
return self.model.encoder
# 返回模型的解码器对象
def get_decoder(self):
return self.model.decoder
# 返回模型的解码器对象
def decoder(self):
return self.model.decoder
# 返回模型的编码器对象
def encoder(self):
return self.model.encoder
# 模型前向传播方法,处理输入特征和解码器的各种输入及其掩码
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_features: TFModelInputType | None = None, # 输入特征,可以是 TensorFlow 模型的输入类型或空
decoder_input_ids: np.ndarray | tf.Tensor | None = None, # 解码器输入的 token IDs,可以是 NumPy 数组、TensorFlow 张量或空
decoder_attention_mask: np.ndarray | tf.Tensor | None = None, # 解码器的注意力掩码,可以是 NumPy 数组、TensorFlow 张量或空
decoder_position_ids: np.ndarray | tf.Tensor | None = None, # 解码器的位置 IDs,可以是 NumPy 数组、TensorFlow 张量或空
head_mask: np.ndarray | tf.Tensor | None = None, # 头部掩码,可以是 NumPy 数组、TensorFlow 张量或空
decoder_head_mask: np.ndarray | tf.Tensor | None = None, # 解码器头部掩码,可以是 NumPy 数组、TensorFlow 张量或空
cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, # 跨注意力头部掩码,可以是 NumPy 数组、TensorFlow 张量或空
encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, # 编码器的输出,可选,包含 NumPy 数组或 TensorFlow 张量的元组的元组
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, # 过去的键值对,可选,包含 NumPy 数组或 TensorFlow 张量的元组的元组
decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None, # 解码器的嵌入输入,可选,包含 NumPy 数组或 TensorFlow 张量的元组
use_cache: Optional[bool] = None, # 是否使用缓存,可选布尔值
output_attentions: Optional[bool] = None, # 是否输出注意力权重,可选布尔值
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态,可选布尔值
return_dict: Optional[bool] = None, # 是否返回字典格式结果,可选布尔值
training: bool = False, # 是否处于训练模式,默认为 False
) -> Union[Tuple[tf.Tensor], TFSeq2SeqModelOutput]: # 返回值可以是 TensorFlow 张量的元组或 TFSeq2SeqModelOutput 类型
"""
Returns:
Example:
```
>>> import tensorflow as tf
>>> from transformers import TFWhisperModel, AutoFeatureExtractor
>>> from datasets import load_dataset
>>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
>>> input_features = inputs.input_features
>>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```
"""
outputs = self.model( # 调用模型的主体部分,传入各种参数进行计算
input_features=input_features, # 输入特征
decoder_input_ids=decoder_input_ids, # 解码器输入的 token IDs
decoder_attention_mask=decoder_attention_mask, # 解码器的注意力掩码
decoder_position_ids=decoder_position_ids, # 解码器的位置 IDs
head_mask=head_mask, # 头部掩码
decoder_head_mask=decoder_head_mask, # 解码器头部掩码
cross_attn_head_mask=cross_attn_head_mask, # 跨注意力头部掩码
encoder_outputs=encoder_outputs, # 编码器的输出
past_key_values=past_key_values, # 过去的键值对
decoder_inputs_embeds=decoder_inputs_embeds, # 解码器的嵌入输入
use_cache=use_cache, # 是否使用缓存
output_attentions=output_attentions, # 是否输出注意力权重
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否返回字典格式结果
training=training, # 是否处于训练模式
)
return outputs # 返回模型计算的结果
# 定义一个方法用于生成服务端输出
def serving_output(self, output):
# 如果配置要求使用缓存,则获取输出中的过去键值对的第二个元素
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
# 如果配置要求输出隐藏状态,则将输出中的解码器隐藏状态转换为张量
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
# 如果配置要求输出注意力分布,则将输出中的解码器注意力分布转换为张量
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
# 如果配置要求输出交叉注意力分布,则将输出中的交叉注意力分布转换为张量
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
# 如果配置要求输出隐藏状态,则将输出中的编码器隐藏状态转换为张量
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
# 如果配置要求输出注意力分布,则将输出中的编码器注意力分布转换为张量
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
# 返回一个 TFSeq2SeqModelOutput 对象,包含以下属性
return TFSeq2SeqModelOutput(
last_hidden_state=output.last_hidden_state, # 最后一个隐藏状态
past_key_values=pkv, # 过去的键值对
decoder_hidden_states=dec_hs, # 解码器隐藏状态
decoder_attentions=dec_attns, # 解码器注意力分布
cross_attentions=cross_attns, # 交叉注意力分布
encoder_last_hidden_state=output.encoder_last_hidden_state, # 编码器最后一个隐藏状态
encoder_hidden_states=enc_hs, # 编码器隐藏状态
encoder_attentions=enc_attns, # 编码器注意力分布
)
# 构建方法用于创建模型
def build(self, input_shape=None):
# 如果已经构建过,直接返回
if self.built:
return
# 设置已构建标志为 True
self.built = True
# 如果存在模型对象
if getattr(self, "model", None) is not None:
# 使用模型的名称空间,构建模型
with tf.name_scope(self.model.name):
self.model.build(None)
@add_start_docstrings(
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
WHISPER_START_DOCSTRING,
)
class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"encoder.version",
r"decoder.version",
r"proj_out.weight",
]
_keys_to_ignore_on_save = [
r"proj_out.weight",
]
def __init__(self, config: WhisperConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = TFWhisperMainLayer(config, name="model")
# 返回模型的编码器部分
def get_encoder(self):
return self.model.get_encoder()
# 返回模型的解码器部分
def get_decoder(self):
return self.model.get_decoder()
# 返回模型的输出嵌入层
def get_output_embeddings(self):
return self.get_input_embeddings()
# 设置模型的输出嵌入层
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
# 调整模型的Token嵌入层大小
def resize_token_embeddings(self, new_num_tokens: int) -> keras.layers.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
# 模型前向传播函数,用于生成输出序列
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_features: TFModelInputType | None = None,
decoder_input_ids: np.ndarray | tf.Tensor | None = None,
decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
decoder_position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
decoder_head_mask: np.ndarray | tf.Tensor | None = None,
cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
labels: np.ndarray | tf.Tensor | None = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
):
# 此处实现模型的具体前向计算逻辑,生成对应的输出
# 生成函数,用于生成模型的输出序列
def generate(
self,
inputs: Optional[tf.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
seed: Optional[List[int]] = None,
return_timestamps: Optional[bool] = None,
task: Optional[str] = None,
language: Optional[str] = None,
is_multilingual: Optional[bool] = None,
prompt_ids: Optional[tf.Tensor] = None,
return_token_timestamps=None,
**kwargs,
):
# 此处实现生成函数的逻辑,用于根据输入生成模型的输出序列
def serving_output(self, output):
# If caching is enabled, take the second element of tf.tuple(output.past_key_values); otherwise None
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
# 如果配置要求输出隐藏状态,则将输出的解码器隐藏状态转换为张量,否则为 None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
# 如果配置要求输出注意力权重,则将输出的解码器注意力权重转换为张量,否则为 None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
# 如果配置要求输出注意力权重,则将输出的交叉注意力权重转换为张量,否则为 None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
# 如果配置要求输出隐藏状态,则将输出的编码器隐藏状态转换为张量,否则为 None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
# 如果配置要求输出注意力权重,则将输出的编码器注意力权重转换为张量,否则为 None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
# 返回 TFSeq2SeqLMOutput 对象,其中包括 logits、过去键值对、解码器隐藏状态、解码器注意力权重、
# 交叉注意力权重、编码器最后隐藏状态、编码器隐藏状态、编码器注意力权重
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
**kwargs,
):
# 如果 past_key_values 不为 None,则仅保留 decoder_input_ids 的最后一个位置的标记
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
# 如果存在 decoder_attention_mask,则使用累积和计算 decoder_position_ids 的最后一个位置
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
# 如果没有 xla 并且存在 past,则使用 past_key_values 中的信息计算 decoder_position_ids
elif past_key_values is not None: # no xla + past
decoder_position_ids = past_key_values[0][0].shape[2]
# 否则,计算 decoder_position_ids 为 decoder_input_ids 的长度范围
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
# 将 decoder_position_ids 广播到与 decoder_input_ids 形状相同
decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape)
# 返回输入生成的准备数据字典,包括输入特征、编码器输出、过去键值对、解码器输入标识、缓存使用情况、
# 解码器注意力掩码、解码器位置标识
return {
"input_features": None, # 传递 None 是为了满足 Keras.layer.__call__ 的要求
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def build(self, input_shape=None):
# 如果已经构建,则直接返回
if self.built:
return
# 标记为已经构建
self.built = True
# 如果模型存在,则在其命名作用域内构建模型
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
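# Illustrative sketch, not part of the Transformers source: how the exclusive cumulative sum
# used in prepare_inputs_for_generation above turns a decoder attention mask into the position
# id of the token currently being generated (the XLA-friendly path).
import tensorflow as tf

example_mask = tf.constant([[1, 1, 1, 1]])  # four decoder tokens seen so far
example_position_ids = tf.math.cumsum(example_mask, axis=-1, exclusive=True)[:, -1:]
# example_position_ids -> [[3]]: the newest token sits at 0-based position 3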
.\models\whisper\modeling_whisper.py
# 设置 Python 文件的编码格式为 UTF-8
# 版权声明和许可信息
# 此处版权归 OpenAI 和 HuggingFace Inc. 团队所有,保留所有权利
""" PyTorch Whisper model. """
# 导入必要的库和模块
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
# 导入自定义模块和函数
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
# 导入 Whisper 配置类
from .configuration_whisper import WhisperConfig
# 导入 Whisper 生成混合类
from .generation_whisper import WhisperGenerationMixin
# 检查是否可用 Flash Attention 2.0
if is_flash_attn_2_available():
# 如果可用,则导入 Flash Attention 相关函数和模块
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# 获取日志记录器
logger = logging.get_logger(__name__)
# 隐藏状态的起始位置
_HIDDEN_STATES_START_POSITION = 1
# 用于文档的配置示例
_CONFIG_FOR_DOC = "WhisperConfig"
# 用于文档的检查点示例
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
# Whisper 预训练模型的存档列表
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai/whisper-base",
# 更多 Whisper 模型请见 https://huggingface.co/models?filter=whisper
]
# 从 transformers.models.llama.modeling_llama._get_unpad_data 复制的函数
def _get_unpad_data(attention_mask):
"""从注意力掩码中获取非填充数据"""
# 计算批次中的序列长度
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
# 找到非填充数据的索引
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# 找到批次中最大的序列长度
max_seqlen_in_batch = seqlens_in_batch.max().item()
# 计算累积序列长度
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
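# Illustrative sketch, not part of the source: what _get_unpad_data computes for a small
# right-padded mask (relies on the torch / F imports at the top of this file).
example_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])
example_seqlens = example_mask.sum(dim=-1, dtype=torch.int32)       # tensor([3, 2])
example_indices = torch.nonzero(example_mask.flatten()).flatten()   # tensor([0, 1, 2, 4, 5])
example_cu_seqlens = F.pad(torch.cumsum(example_seqlens, dim=0, dtype=torch.int32), (1, 0))
# example_cu_seqlens -> tensor([0, 3, 5]); max_seqlen_in_batch would be 3.
# flash_attn_varlen_func consumes exactly these indices / cumulative lengths.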
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""为位置嵌入返回正弦波"""
# 检查通道数是否为偶数
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
# 计算用于时间缩放的对数时间尺度增量
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
# 计算逆时间尺度,通过 torch.exp 函数对每个通道的对数时间尺度增量进行指数运算
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
# 创建一个二维张量,其中每一行代表一个时间步,每列代表一个通道的缩放时间
# 通过乘以逆时间尺度张量,将时间线性缩放到不同的频率
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
# 返回一个张量,包含了缩放时间的正弦和余弦值,沿着通道维度连接
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
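# Illustrative sketch, not part of the source: shape and boundary values of the sinusoidal
# table. The sizes are only an example (whisper-tiny uses 1500 source positions, d_model=384).
example_pe = sinusoids(length=1500, channels=384)
assert example_pe.shape == (1500, 384)
assert torch.all(example_pe[0, :192] == 0)  # sin(0) == 0 for every frequency
assert torch.all(example_pe[0, 192:] == 1)  # cos(0) == 1 for every frequency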
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
# 创建一个和 input_ids 形状相同的全零张量 shifted_input_ids
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
# 将 input_ids 的每一行向右移动一位,将结果复制到 shifted_input_ids 中
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
# 将每一行的第一个位置填充为 decoder_start_token_id
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
# 如果 pad_token_id 为 None,则抛出数值错误
raise ValueError("self.model.config.pad_token_id has to be defined.")
# 将 shifted_input_ids 中值为 -100 的位置用 pad_token_id 替换
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
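# Illustrative sketch, not part of the source: labels are shifted one step to the right to
# form decoder_input_ids, and the -100 ignore index is replaced by the pad token. The ids
# 50258 / 50257 are used here only as example start-of-transcript / pad token ids.
example_labels = torch.tensor([[50258, 50259, 123, -100, -100]])
example_shifted = shift_tokens_right(example_labels, pad_token_id=50257, decoder_start_token_id=50258)
# example_shifted -> tensor([[50258, 50258, 50259, 123, 50257]])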
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
# 如果 mask_length 小于 1,则抛出数值错误
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
# 如果 mask_length 大于 sequence_length,则抛出数值错误
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon 用于概率舍入
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
# 计算应该屏蔽的 span 数量
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# 确保 num_masked_span 不超过 sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# 确保 num_masked_span 不超过 input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# 计算批次中每个序列的长度
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# 创建用于 SpecAugment 的掩码
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
# 计算最大允许的 masked span 数量
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# 计算当前输入长度下的 masked span 数量
num_masked_span = compute_num_masked_span(input_length)
# 获取随机的掩码索引
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# 选取第一个作为 dummy 索引,用于填充向量,确保所有批次维度相同
if len(spec_aug_mask_idx) == 0:
# 只有在 input_length 严格小于 sequence_length 时才可能发生这种情况,
# 此时最后一个标记必须是填充标记,可以用作虚拟掩码 id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
# 填充掩码索引数组,确保每个批次的维度相同
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# 将掩码索引扩展为掩码 spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# 将起始索引添加偏移量,使索引现在创建一个 span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# 确保我们不能使用大于 sequence_length - 1 的索引
if spec_aug_mask_idxs.max() > sequence_length - 1:
# 将 spec_aug_mask_idxs 中大于 sequence_length - 1 的索引置为 sequence_length - 1
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# 在 spec_aug_mask 上根据 spec_aug_mask_idxs 的索引位置散布值
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
# 返回 spec_aug_mask 结果
return spec_aug_mask
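# Illustrative sketch, not part of the source: drawing SpecAugment spans for two sequences of
# length 10. The mask itself is random; only shape, dtype and the minimum span count are fixed.
example_spec_mask = _compute_mask_indices(shape=(2, 10), mask_prob=0.3, mask_length=2, min_masks=1)
assert example_spec_mask.shape == (2, 10) and example_spec_mask.dtype == bool
# Each row contains at least one run of 2 consecutive True values that the model will zero out.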
class WhisperPositionalEmbedding(nn.Embedding):
# 继承自 nn.Embedding 的类 WhisperPositionalEmbedding,用于位置编码的嵌入
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
# 前向传播函数,根据输入的位置 ids 返回对应的嵌入向量
def forward(self, input_ids, past_key_values_length=0, position_ids=None):
if position_ids is None:
# 如果未提供 position_ids,则根据输入的 input_ids 和历史键值的长度返回相应的嵌入向量
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
else:
# 如果提供了 position_ids,则直接返回对应位置的嵌入向量
return self.weight[position_ids]
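# Illustrative sketch, not part of the source: during incremental decoding the table is sliced
# by how many tokens are already cached. 448 positions / 384 dims match whisper-tiny's decoder.
example_emb = WhisperPositionalEmbedding(num_positions=448, embedding_dim=384)
example_step = torch.zeros(1, 1, dtype=torch.long)                   # one freshly generated token
example_pos = example_emb(example_step, past_key_values_length=10)
assert example_pos.shape == (1, 384)                                 # embedding of position 10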
class WhisperAttention(nn.Module):
"""来自 'Attention Is All You Need' 论文的多头注意力模块"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[WhisperConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
# 初始化线性变换层,用于计算查询、键、值以及输出的投影
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# 从 transformers.models.bart.modeling_bart.BartAttention._shape 复制而来,用于调整张量的形状
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# 从 transformers.models.bart.modeling_bart.BartAttention.forward 复制而来,前向传播函数
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
# 实现注意力机制的前向传播,包括查询、键、值的投影以及输出的投影
# 注意力掩码、层头掩码等参数用于控制注意力的行为
# 返回值包括输出张量以及可选的注意力权重
pass
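# Illustrative sketch, not part of the source: how _shape splits the embedding dimension into
# heads, giving the (batch, num_heads, seq_len, head_dim) layout the attention math expects.
# 6 heads over 384 dims are whisper-tiny's encoder sizes, used here only as an example.
example_attn = WhisperAttention(embed_dim=384, num_heads=6)
example_x = torch.randn(2, 50, 384)                                  # (batch, seq_len, embed_dim)
example_q = example_attn._shape(example_attn.q_proj(example_x), 50, 2)
assert example_q.shape == (2, 6, 50, 64)                             # head_dim = 384 // 6 = 64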
# 从 Bart->Whisper 改名,并继承自 WhisperAttention,用于实现 Flash Attention 机制
class WhisperFlashAttention2(WhisperAttention):
"""
Whisper flash attention 模块。此模块继承自 `WhisperAttention`,保持模块权重不变。
在前向传播中正确调用 Flash Attention 的公共 API,并处理可能包含的填充令牌。
"""
# 从 transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 复制而来
# The constructor simply forwards everything to WhisperAttention and records which
# causal-mask alignment the installed flash-attn version uses (see the comment below).
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which is default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
# 设置一个属性来处理 Flash Attention 版本间的差异,当 flash_attn<2.1 时,生成的是左上对齐的因果掩码,而我们需要的是右下对齐的掩码,这在 flash_attn>=2.1 中是默认的行为。
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
# 重新塑形张量,将其形状变为 (bsz, seq_len, num_heads, head_dim)
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
# 从隐藏状态开始向前传播,支持可选的键值状态、过去的键值、注意力掩码和层头掩码等参数
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
# Flash Attention 前向传播函数,接受查询状态、键状态、值状态、注意力掩码、查询长度以及可选的 dropout 和 softmax_scale 参数
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Determine if causal masking should be applied based on configuration and query length
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# Conditionally adjust causal based on a specific condition for Flash Attention in RoCm
causal = self.is_causal and query_length != 1
# Check if there are any padding tokens in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
# Unpad the input states using a helper method _upad_input
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
# Retrieve sequence lengths from the computed values
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
# Compute attention scores with variable length support using flash_attn_varlen_func
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
# Pad the attention scores back to the original sequence length
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Compute attention scores without considering any padding tokens
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
# Return the final attention scores
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
# 定义一个私有方法,用于处理输入数据,对查询、键和值进行调整,以及相关的注意力掩码
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
# 获取未填充数据的索引、当前序列长度和批次中的最大序列长度信息
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
# 获取批次大小、键值序列长度、键值头数和头维度
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
# 重新形状化键层和值层,按照未填充数据的索引进行索引
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
# 根据查询长度选择不同的处理分支
if query_length == kv_seq_len:
# 如果查询长度等于键值序列长度,则对查询层进行索引操作
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
# 如果查询长度为1,则处理为标量情况
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # 这里有一个memcpy操作,性能较差。
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# 否则,根据查询长度和注意力掩码进行输入数据的解压缩操作
# 注意,这里的 -query_length: 切片假设左填充。
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
# 返回调整后的查询层、键层、值层、查询索引、当前序列长度信息和批次最大序列长度信息
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class WhisperSdpaAttention(WhisperAttention):
# 从 transformers.models.bart.modeling_bart.BartSdpaAttention.forward 复制而来,将 BART->whisper, Bart->Whisper
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
# 注意力机制的实现,用于计算注意力分数并加权隐藏状态
pass
# WHISPER_ATTENTION_CLASSES 定义了不同实现的注意力类别映射
WHISPER_ATTENTION_CLASSES = {
"eager": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
"sdpa": WhisperSdpaAttention, # 使用 WhisperSdpaAttention 作为一种注意力实现
}
# 从 transformers.models.mbart.modeling_mbart.MBartEncoderLayer 复制而来,将 MBart->Whisper, MBART->WHISPER
class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
# self_attn 是自注意力层,根据配置选择不同的注意力实现类别
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function] # 激活函数的选择
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) # 第一个全连接层
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) # 第二个全连接层
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
):
# 编码器层的前向传播,包括自注意力、前馈神经网络和层归一化
pass
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
# 保存输入的原始状态,用于残差连接
residual = hidden_states
# 对输入的 hidden_states 进行 layer normalization
hidden_states = self.self_attn_layer_norm(hidden_states)
# 使用 self-attention 模块处理 normalized 后的 hidden_states
# 返回处理后的 hidden_states、attention 权重和额外的信息
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# 对处理后的 hidden_states 进行 dropout
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# 将残差连接到处理后的 hidden_states 上
hidden_states = residual + hidden_states
# 再次保存当前的 hidden_states 用于残差连接
residual = hidden_states
# 对当前的 hidden_states 进行 layer normalization
hidden_states = self.final_layer_norm(hidden_states)
# 使用激活函数处理第一个全连接层的输出
hidden_states = self.activation_fn(self.fc1(hidden_states))
# 对处理后的 hidden_states 进行 dropout
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
# 使用第二个全连接层处理 hidden_states
hidden_states = self.fc2(hidden_states)
# 对处理后的 hidden_states 进行 dropout
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# 将残差连接到处理后的 hidden_states 上
hidden_states = residual + hidden_states
# 如果 hidden_states 的数据类型是 torch.float16 并且包含无穷大或 NaN 的元素
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
# 对 hidden_states 进行截断处理,避免溢出
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# 构建输出元组,包含处理后的 hidden_states
outputs = (hidden_states,)
# 如果需要输出 attentions,将 attentions 加入输出元组中
if output_attentions:
outputs += (attn_weights,)
# 返回最终的输出元组
return outputs
# 从transformers.models.mbart.modeling_mbart.MBartDecoderLayer复制而来,MBart->Whisper, MBART->WHISPER
class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model # 设置嵌入维度为配置中的d_model
# 初始化自注意力层,根据配置选择的注意力机制类别进行设置
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout # 设置dropout概率
self.activation_fn = ACT2FN[config.activation_function] # 激活函数根据配置选择
self.activation_dropout = config.activation_dropout # 激活函数的dropout概率
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) # 初始化自注意力层的LayerNorm
# 初始化编码器注意力层,根据配置选择的注意力机制类别进行设置
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) # 初始化编码器注意力层的LayerNorm
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) # 第一个全连接层
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) # 第二个全连接层
self.final_layer_norm = nn.LayerNorm(self.embed_dim) # 最终输出的LayerNorm
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
计算卷积层的输出长度
将输入长度减去1,然后整除2,并加上1,计算卷积层的输出长度。
"""
# 将输入长度减去1,然后整除2,并加上1,得到卷积层的输出长度
input_lengths = (input_lengths - 1) // 2 + 1
# 返回计算得到的卷积层输出长度
return input_lengths
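# Illustrative sketch, not part of the source: a 30 s clip produces 3000 mel frames, and the
# stride-2 convolution in the encoder halves that, so the encoder emits 1500 hidden states.
example_input_lengths = torch.tensor([3000, 1000])
example_output_lengths = (example_input_lengths - 1) // 2 + 1
assert example_output_lengths.tolist() == [1500, 500]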
# Generic start docstring: describes the PreTrainedModel inheritance and general usage shared by all Whisper model classes
WHISPER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`WhisperConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# 空白的输入文档字符串,待后续补充输入参数的描述
WHISPER_INPUTS_DOCSTRING = r"""
"""
# 定义了用于WhisperEncoder类的输入参数的文档字符串,详细描述了每个参数的类型和作用
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class WhisperEncoder(WhisperPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`WhisperEncoderLayer`].
Args:
    config: WhisperConfig
"""
def __init__(self, config: WhisperConfig):
super().__init__(config) # 调用父类的初始化方法,传入配置对象
self.dropout = config.dropout # 设置 dropout 概率
self.layerdrop = config.encoder_layerdrop # 设置层丢弃率
embed_dim = config.d_model # 获取嵌入维度
self.num_mel_bins = config.num_mel_bins # 获取梅尔频谱的数量
self.padding_idx = config.pad_token_id # 获取填充标记的索引
self.max_source_positions = config.max_source_positions # 获取最大源序列位置
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 # 计算嵌入缩放因子,根据配置选择是否开启
# 初始化两个一维卷积层,用于特征提取
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
# 初始化位置嵌入层,并设置为不需要梯度计算
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.requires_grad_(False)
# 使用 WhisperEncoderLayer 构建编码器层的列表,根据配置中的编码器层数量
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model) # 初始化层归一化层
self.gradient_checkpointing = False # 是否启用梯度检查点
# 初始化权重并应用最终处理
self.post_init()
def _freeze_parameters(self):
# 冻结所有参数,使其不需要梯度计算
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False # 设置不需要梯度计算标志为 False
def get_input_embeddings(self) -> nn.Module:
return self.conv1 # 返回输入嵌入层 conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value # 设置输入嵌入层 conv1 的值为给定的 value
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
# 定义一个名为 WhisperDecoder 的类,继承自 WhisperPreTrainedModel 类
class WhisperDecoder(WhisperPreTrainedModel):
"""
Transformer 解码器,由 *config.decoder_layers* 层组成。每层是一个 [`WhisperDecoderLayer`]
Args:
config: WhisperConfig 对象,包含模型的配置信息
"""
# 主要输入名称为 "input_ids"
main_input_name = "input_ids"
# 初始化方法,接收一个 WhisperConfig 类型的参数 config
def __init__(self, config: WhisperConfig):
super().__init__(config)
# 设置 dropout 概率
self.dropout = config.dropout
# 设置层级丢弃概率
self.layerdrop = config.decoder_layerdrop
# 设置填充索引
self.padding_idx = config.pad_token_id
# 设置最大目标位置
self.max_target_positions = config.max_target_positions
# 设置最大源位置
self.max_source_positions = config.max_source_positions
# 如果开启了 scale_embedding,则使用 sqrt(config.d_model) 作为嵌入尺度,否则为 1.0
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
# 嵌入 tokens,使用 nn.Embedding 创建一个嵌入层
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
# 嵌入位置编码,使用 WhisperPositionalEmbedding 创建一个位置编码嵌入层
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
# 创建解码器层列表,包含 config.decoder_layers 个 WhisperDecoderLayer 层
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
# 根据 config._attn_implementation 决定是否使用 Flash Attention 2.0 注意力机制
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# 根据 config._attn_implementation 决定是否使用 SDPA 注意力机制
self._use_sdpa = config._attn_implementation == "sdpa"
# 层归一化,使用 nn.LayerNorm 进行归一化处理
self.layer_norm = nn.LayerNorm(config.d_model)
# 梯度检查点,默认为 False,是否使用梯度检查点
self.gradient_checkpointing = False
# 初始化权重并应用最终处理
self.post_init()
# 获取输入嵌入层对象,返回 self.embed_tokens
def get_input_embeddings(self):
return self.embed_tokens
# 设置输入嵌入层对象为 value
def set_input_embeddings(self, value):
self.embed_tokens = value
# 前向传播函数定义,接收多个参数用于解码器的输入和控制
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features
# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.size()
if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0
return input_features
# Mask the extracted features along the time and/or feature axis, following SpecAugment
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# 添加文档字符串到类定义,描述了 WhisperForConditionalGeneration 类的用途和功能
@add_start_docstrings(
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
WHISPER_START_DOCSTRING,
)
class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
# 设置基础模型前缀,用于指定模型中与权重共享相关的键
base_model_prefix = "model"
# 指定应当共享权重的键名列表
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: WhisperConfig):
# 调用父类的初始化方法,传入 WhisperConfig 对象
super().__init__(config)
# 创建 WhisperModel 对象,并将其保存在实例变量 self.model 中
self.model = WhisperModel(config)
# 创建线性层,用于输出模型的预测结果,不带偏置项
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
# 调用额外的初始化方法,用于权重初始化和最终处理
self.post_init()
def get_encoder(self):
# 返回模型中的编码器部分,通过调用 self.model 的 get_encoder 方法实现
return self.model.get_encoder()
def get_decoder(self):
# 返回模型中的解码器部分,通过调用 self.model 的 get_decoder 方法实现
return self.model.get_decoder()
def get_output_embeddings(self):
# 返回输出嵌入层,即预测输出的线性层 self.proj_out
return self.proj_out
def set_output_embeddings(self, new_embeddings):
# 设置新的输出嵌入层,更新 self.proj_out 的值为 new_embeddings
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
# 返回模型中的输入嵌入层,通过调用 self.model 的 get_input_embeddings 方法实现
return self.model.get_input_embeddings()
def freeze_encoder(self):
"""
调用此方法将禁用 Whisper 编码器的梯度计算,使其在训练过程中不会更新参数。
"""
self.model.encoder._freeze_parameters()
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
覆盖父类中的 forward 方法,实现 Whisper 模型的前向传播。
Args:
input_features (Optional[torch.FloatTensor], optional): 输入特征张量。默认为 None。
attention_mask (Optional[torch.LongTensor], optional): 注意力掩码张量。默认为 None。
decoder_input_ids (Optional[torch.LongTensor], optional): 解码器输入 ID 张量。默认为 None。
decoder_attention_mask (Optional[torch.LongTensor], optional): 解码器注意力掩码张量。默认为 None。
head_mask (Optional[torch.Tensor], optional): 头部掩码张量。默认为 None。
decoder_head_mask (Optional[torch.Tensor], optional): 解码器头部掩码张量。默认为 None。
cross_attn_head_mask (Optional[torch.Tensor], optional): 交叉注意力头部掩码张量。默认为 None。
encoder_outputs (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): 编码器输出元组。默认为 None。
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): 过去的键值元组。默认为 None。
decoder_inputs_embeds (Optional[Tuple[torch.FloatTensor]], optional): 解码器输入嵌入张量元组。默认为 None。
decoder_position_ids (Optional[Tuple[torch.LongTensor]], optional): 解码器位置 ID 张量元组。默认为 None。
labels (Optional[torch.LongTensor], optional): 标签张量。默认为 None。
use_cache (Optional[bool], optional): 是否使用缓存。默认为 None。
output_attentions (Optional[bool], optional): 是否输出注意力。默认为 None。
output_hidden_states (Optional[bool], optional): 是否输出隐藏状态。默认为 None。
return_dict (Optional[bool], optional): 是否返回字典。默认为 None。
Returns:
Seq2SeqLMOutput: 序列到序列的语言模型输出。
"""
# 实际的前向传播逻辑将在此处实现
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
**kwargs,
):
"""
准备生成过程中的输入,以便在生成文本时使用。
Args:
decoder_input_ids: 解码器输入 ID。
past_key_values: 过去的键值对。
use_cache: 是否使用缓存。
encoder_outputs: 编码器输出。
attention_mask: 注意力掩码。
decoder_attention_mask: 解码器注意力掩码。
**kwargs: 其他关键字参数。
Returns:
dict: 包含生成过程输入的字典。
"""
# 实现生成输入准备的逻辑
# 初始化变量 decoder_position_ids 为 None
decoder_position_ids = None
# 如果存在 decoder_attention_mask,计算每个位置累积和后减一,并确保不小于零
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
# 如果存在 past_key_values,则获取其长度
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# 某些生成方法可能只传递最后一个输入 ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# 默认行为:保留最后一个 ID
remove_prefix_length = decoder_input_ids.shape[1] - 1
# 仅保留 decoder_input_ids 中的后缀部分
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# 如果存在 decoder_position_ids 并且其长度大于 decoder_input_ids 的长度,则也截断之
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
# 返回重构后的信息字典
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
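# Illustrative sketch, not part of the source: with 4 cached positions, only the freshly
# generated token is fed back in, and its position id equals the number of tokens before it.
# The token ids are just example Whisper prompt tokens.
example_ids = torch.tensor([[50258, 50259, 50359, 50363, 123]])       # 5 decoder tokens so far
example_mask = torch.ones_like(example_ids)
example_past_length = 4                                               # length of the cached keys/values
example_position_ids = (example_mask.cumsum(-1) - 1).clamp(min=0)
example_ids = example_ids[:, example_past_length:]                    # tensor([[123]])
example_position_ids = example_position_ids[:, example_past_length:]  # tensor([[4]])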
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
# 初始化重新排序的 past_key_values
reordered_past = ()
# 遍历 past_key_values 中的每个层的过去状态
for layer_past in past_key_values:
# 使用 beam_idx 对每个 past_state 进行重新排序,并将结果添加到 reordered_past 中
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
# 返回重新排序后的 past_key_values
return reordered_past
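# Illustrative sketch, not part of the source: what the reordering above does for a one-layer
# cache when beam search decides that beam 0 continues old beam 2 and beams 1-2 continue old beam 0.
example_layer_past = tuple(torch.arange(3.0).view(3, 1, 1, 1) for _ in range(4))  # (k, v, cross_k, cross_v)
example_past = (example_layer_past,)
example_beam_idx = torch.tensor([2, 0, 0])
example_reordered = tuple(
    tuple(past_state.index_select(0, example_beam_idx) for past_state in layer_past)
    for layer_past in example_past
)
# example_reordered[0][0][:, 0, 0, 0] -> tensor([2., 0., 0.])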
class WhisperDecoderWrapper(WhisperPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
"""
def __init__(self, config):
super().__init__(config)
# 设置当前模型不是编码器-解码器结构
config.is_encoder_decoder = False
# 初始化一个WhisperDecoder对象作为解码器
self.decoder = WhisperDecoder(config)
def get_input_embeddings(self):
# 返回当前模型的解码器的嵌入层
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
# 设置当前模型的解码器的嵌入层
self.decoder.embed_tokens = value
def forward(self, *args, **kwargs):
# 前向传播,调用当前模型的解码器进行处理
return self.decoder(*args, **kwargs)
@add_start_docstrings(
"""
Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
""",
WHISPER_START_DOCSTRING,
)
class WhisperForCausalLM(WhisperPreTrainedModel):
_tied_weights_keys = ["proj_out.weight"]
main_input_name = "input_ids"
def __init__(self, config):
super().__init__(config)
# 设置当前模型不是编码器-解码器结构
config.is_encoder_decoder = False
# 初始化一个WhisperDecoderWrapper对象作为当前模型的主模型
self.model = WhisperDecoderWrapper(config)
# 初始化一个线性层,作为模型的输出投影层,将隐藏状态映射到词汇表大小的向量空间
self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 初始化权重并进行最终处理
self.post_init()
def get_output_embeddings(self):
# 返回当前模型的输出投影层
return self.proj_out
def set_output_embeddings(self, new_embeddings):
# 设置当前模型的输出投影层
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
# 返回当前模型主模型的解码器的嵌入层
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
# 设置当前模型主模型的解码器的嵌入层
self.model.set_input_embeddings(value)
def set_decoder(self, decoder):
# 设置当前模型主模型的解码器
self.model.decoder = decoder
def get_decoder(self):
# 返回当前模型主模型的解码器
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# 前向传播函数,调用当前模型的主模型进行处理
pass # 实际操作在self.model.forward中定义
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
**kwargs,
):
# Prepare the inputs used during generation; the implementation follows directly below
# 如果过去的键值不为 None,则获取过去键值的第一个元素的第三维度长度作为过去长度
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# 某些生成方法可能只传递最后一个输入 ID
if input_ids.shape[1] > past_length:
# 如果输入的 ID 数量大于过去长度,则移除前缀长度为过去长度
remove_prefix_length = past_length
else:
# 否则,默认行为:只保留最后一个 ID
remove_prefix_length = input_ids.shape[1] - 1
# 更新输入的 ID,移除前缀部分
input_ids = input_ids[:, remove_prefix_length:]
# 返回一个包含各种输出和参数的字典
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"input_ids": input_ids,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
# 遍历过去键值中的每一层,并重新排序以匹配 beam_idx 的顺序
for layer_past in past_key_values:
reordered_past += (
# 对于每个过去状态,根据 beam_idx 在设备上选择相应的索引
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
# 返回重新排序后的过去键值
return reordered_past
@add_start_docstrings(
"""
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
like SUPERB Keyword Spotting.
""",
WHISPER_ENCODER_INPUTS_DOCSTRING,
)
class WhisperForAudioClassification(WhisperPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.encoder = WhisperEncoder(config) # 初始化Whisper编码器,使用给定的配置
num_layers = config.num_hidden_layers + 1 # 计算层数,包括transformer层和输入嵌入层
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) # 如果使用加权层求和,初始化权重参数
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) # 初始化线性投影层
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) # 初始化分类器线性层
# Initialize weights and apply final processing
self.post_init() # 执行初始化权重和最终处理步骤
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training. Only the projection layers and classification head will be updated.
"""
self.encoder._freeze_parameters() # 冻结Whisper编码器的参数,使其在训练过程中不更新梯度,只更新投影层和分类头部
def get_input_embeddings(self) -> nn.Module:
return self.encoder.get_input_embeddings() # 返回Whisper编码器的输入嵌入层模块
def set_input_embeddings(self, value: nn.Module):
self.encoder.set_input_embeddings(value) # 设置Whisper编码器的输入嵌入层模块
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
.\models\whisper\processing_whisper.py
# coding=utf-8
# 文件编码声明,使用 UTF-8 编码
# Copyright 2022 The HuggingFace Inc. team.
# 版权声明,标明代码版权归 HuggingFace Inc. 团队所有
# Licensed under the Apache License, Version 2.0 (the "License");
# 遵循 Apache License, Version 2.0 许可协议,允许在特定条件下使用本代码
# you may not use this file except in compliance with the License.
# 除非符合许可协议的条件,否则禁止使用本文件
# You may obtain a copy of the License at
# 可以在上述许可协议链接处获取完整的许可协议文本
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 根据适用法律或书面同意,本软件按“原样”分发,不提供任何担保或条件
# See the License for the specific language governing permissions and
# limitations under the License.
# 请参阅许可协议以了解权限和限制的具体条款
"""
Speech processor class for Whisper
"""
# 注释: Whisper 的语音处理器类
from ...processing_utils import ProcessorMixin
# 导入处理工具类 ProcessorMixin
class WhisperProcessor(ProcessorMixin):
r"""
Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
processor.
构建 Whisper 处理器,将 Whisper 特征提取器和 Whisper 分词器封装到一个处理器中
[`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.
WhisperProcessor 提供了所有 WhisperFeatureExtractor 和 WhisperTokenizer 的功能。查看 `~WhisperProcessor.__call__` 和 `~WhisperProcessor.decode` 获取更多信息
Args:
feature_extractor (`WhisperFeatureExtractor`):
An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`WhisperTokenizer`):
An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
"""
# 参数说明:feature_extractor 是 WhisperFeatureExtractor 实例,tokenizer 是 WhisperTokenizer 实例
feature_extractor_class = "WhisperFeatureExtractor"
# 类属性:特征提取器类名为 WhisperFeatureExtractor
tokenizer_class = "WhisperTokenizer"
# 类属性:分词器类名为 WhisperTokenizer
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
# 调用父类的初始化方法,使用 feature_extractor 和 tokenizer 进行初始化
self.current_processor = self.feature_extractor
# 设置当前处理器为特征提取器对象
self._in_target_context_manager = False
# 设置目标上下文管理器状态为 False
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
# 调用分词器的方法获取解码器提示符的 ID
def __call__(self, *args, **kwargs):
"""
Forwards the `audio` argument to WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] and the `text`
argument to [`~WhisperTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
information.
"""
# 如果在目标上下文管理器中,则调用当前处理器并返回结果
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
# 从 kwargs 中弹出 `audio`, `sampling_rate`, `text` 参数
audio = kwargs.pop("audio", None)
sampling_rate = kwargs.pop("sampling_rate", None)
text = kwargs.pop("text", None)
# 如果有额外的位置参数,将第一个作为 `audio` 参数处理,其余作为 args
if len(args) > 0:
audio = args[0]
args = args[1:]
# 如果 `audio` 和 `text` 都为 None,则抛出 ValueError
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
# 如果有 `audio`,调用特征提取器处理音频输入
if audio is not None:
inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
# 如果有 `text`,调用分词器处理文本输入
if text is not None:
encodings = self.tokenizer(text, **kwargs)
# 如果 `text` 为 None,返回特征提取器处理的结果
if text is None:
return inputs
# 如果 `audio` 为 None,返回分词器处理的结果
elif audio is None:
return encodings
else:
# 将分词器的编码结果作为特征提取器结果的标签
inputs["labels"] = encodings["input_ids"]
return inputs
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
# Delegate to the tokenizer's batch_decode and return the result
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
# Delegate to the tokenizer's decode and return the result
return self.tokenizer.decode(*args, **kwargs)
def get_prompt_ids(self, text: str, return_tensors="np"):
# Delegate to the tokenizer's get_prompt_ids and return the result
return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
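# Illustrative usage sketch (not part of the original file): pairing audio features with text
# labels through the processor. It assumes the "openai/whisper-tiny" checkpoint and uses one
# second of silence in place of real speech.
import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
audio = np.zeros(16000, dtype=np.float32)  # 1 s of "audio" at 16 kHz
inputs = processor(audio=audio, sampling_rate=16000, text="hello world", return_tensors="pt")
print(inputs["input_features"].shape)  # log-mel features, (1, 80, 3000) for whisper-tiny
print(inputs["labels"])                # token ids of the transcription, usable as labels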
.\models\whisper\tokenization_whisper.py
# Helper that maps UTF-8 bytes to printable unicode characters, so that BPE can operate on
# strings without whitespace or control characters.
def bytes_to_unicode():
"""
Returns a list of utf-8 bytes and a mapping to unicode strings. Whitespace/control characters are deliberately
avoided because they break the BPE code.
The reversible BPE codes work on unicode strings. This means you need a large number of unicode characters in
your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing
around 5K characters for decent coverage. This is a significant percentage of a normal, say, 32K BPE vocab. To
avoid that, we want lookup tables between utf-8 bytes and unicode strings.
"""
# Byte values that already map to printable characters
bs = (
list(range(ord("!"), ord("~") + 1)) + # printable ASCII
list(range(ord("¡"), ord("¬") + 1)) + # printable Latin-1 (first block)
list(range(ord("®"), ord("ÿ") + 1)) # printable Latin-1 (second block)
)
cs = bs[:] # start the character list from the same code points
n = 0
# Walk over all 2**8 byte values so that every byte gets a mapping
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs] # map the shifted code points to unicode characters
return dict(zip(bs, cs)) # byte -> unicode character lookup table
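# Quick illustration (not part of the original file): the mapping assigns every byte a printable
# character and is invertible, which is exactly what the byte-level BPE relies on.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
encoded = "".join(byte_encoder[b] for b in "héllo".encode("utf-8"))
print(encoded)  # 'hÃ©llo' – the two UTF-8 bytes of 'é' map to printable characters
assert bytearray(byte_decoder[c] for c in encoded).decode("utf-8") == "héllo"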
# Module-level logger
logger = logging.get_logger(__name__)
# Helper that returns the set of adjacent symbol pairs in a word
def get_pairs(word):
"""
Return the set of symbol pairs in a word.
A word is represented as a tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0] # first symbol of the word
# Iterate over every symbol after the first one
for char in word[1:]:
# Add the (previous, current) pair to the set
pairs.add((prev_char, char))
# The current symbol becomes the previous one for the next iteration
prev_char = char
return pairs
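# Small illustration (not part of the original file): the adjacent symbol pairs of "hello".
print(get_pairs(tuple("hello")))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')} – set order may vary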
# Supported languages: mapping from language code to language name
LANGUAGES = {
"en": "english", # 英语
"zh": "chinese", # 中文
"de": "german", # 德语
"es": "spanish", # 西班牙语
"ru": "russian", # 俄语
"ko": "korean", # 韩语
"fr": "french", # 法语
"ja": "japanese", # 日语
"pt": "portuguese", # 葡萄牙语
"tr": "turkish", # 土耳其语
"pl": "polish", # 波兰语
"ca": "catalan", # 加泰罗尼亚语
"nl": "dutch", # 荷兰语
"ar": "arabic", # 阿拉伯语
"sv": "swedish", # 瑞典语
"it": "italian", # 意大利语
"id": "indonesian", # 印尼语
"hi": "hindi", # 印地语
"fi": "finnish", # 芬兰语
"vi": "vietnamese", # 越南语
"he": "hebrew", # 希伯来语
"uk": "ukrainian", # 乌克兰语
"el": "greek", # 希腊语
"ms": "malay", # 马来语
"cs": "czech", # 捷克语
"ro": "romanian", # 罗马尼亚语
"da": "danish", # 丹麦语
"hu": "hungarian", # 匈牙利语
"ta": "tamil", # 泰米尔语
"no": "norwegian", # 挪威语
"th": "thai", # 泰语
"ur": "urdu", # 乌尔都语
"hr": "croatian", # 克罗地亚语
"bg": "bulgarian", # 保加利亚语
"lt": "lithuanian", # 立陶宛语
"la": "latin", # 拉丁语
"mi": "maori", # 毛利语
"ml": "malayalam", # 马拉雅拉姆语
"cy": "welsh", # 威尔士语
"sk": "slovak", # 斯洛伐克语
"te": "telugu", # 泰卢固语
"fa": "persian", # 波斯语
"lv": "latvian", # 拉脱维亚语
"bn": "bengali", # 孟加拉语
"sr": "serbian", # 塞尔维亚语
"az": "azerbaijani", # 阿塞拜疆语
"sl": "slovenian", # 斯洛文尼亚语
"kn": "kannada", # 卡纳达语
"et": "estonian", # 爱沙尼亚语
"mk": "macedonian", # 马其顿语
"br": "breton", # 布列塔尼语
"eu": "basque", # 巴斯克语
"is": "icelandic", # 冰岛语
"hy": "armenian", # 亚美尼亚语
"ne": "nepali", # 尼泊尔语
"mn": "mongolian", # 蒙古语
"bs": "bosnian", # 波斯尼亚语
"kk": "kazakh", # 哈萨克语
"sq": "albanian", # 阿尔巴尼亚语
"sw": "swahili", # 斯瓦希里语
"gl": "galician", # 加利西亚语
"mr": "marathi", # 马拉地语
"pa": "punjabi", # 旁遮普语
"si": "sinhala", # 僧伽罗语
"km": "khmer", # 高棉语
"sn": "shona", # 绍纳语
"yo": "yoruba", # 约鲁巴语
"so": "somali", # 索马里语
"af": "afrikaans", # 南非荷兰语
"oc": "occitan", # 奥克语
"ka": "georgian", # 格鲁吉亚语
"be": "belarusian", # 白俄罗斯语
"tg": "tajik", # 塔吉克语
"sd": "sindhi", # 信德语
"gu": "gujarati", # 古吉拉特语
"am": "amharic", # 阿姆哈拉语
"yi": "yiddish", # 意第绪语
"lo": "lao", # 老挝语
"uz": "uzbek", # 乌兹别克语
"fo": "faroese", # 法罗语
"ht": "haitian creole", # 海地克里奥尔语
"ps": "pashto", # 普什图语
"tk": "turkmen", # 土库曼语
"nn": "nynorsk", # 新挪威语
"mt": "maltese", # 马耳他语
"sa": "sanskrit", # 梵语
"lb": "luxembourgish", # 卢森堡语
"my": "myanmar", # 缅甸语
"bo": "tibetan", # 藏语
"tl": "tagalog", # 菲律宾语
"mg": "malagasy", # 马达加斯加语
"as": "assamese", # 阿萨姆语
"tt": "tatar", # 鞑靼语
"haw": "hawaiian", # 夏威夷语
"ln": "lingala", # 林加拉语
"ha": "hausa", # 豪萨语
"ba": "bashkir", # 巴什基尔语
"jw": "javanese", # 爪哇语
"su": "sundanese", # 巽他语
"yue": "cantonese", # 粤语
}
# Reverse mapping from language name to language code, plus a few commonly used name aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
# Class attributes of the Whisper tokenizer: constants and mappings used to locate the
# vocabulary files, bound the model input size, and name the model inputs.
vocab_files_names = VOCAB_FILES_NAMES # vocabulary file names
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP # map of pretrained vocabulary files
max_model_input_sizes = MAX_MODEL_INPUT_SIZES # maximum model input sizes
model_input_names = ["input_ids", "attention_mask"] # names of the model inputs
def __init__(
self,
vocab_file,
merges_file,
normalizer_file=None,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
pad_token=None,
add_prefix_space=False,
language=None,
task=None,
predict_timestamps=False,
**kwargs,
):
# Constructor: builds the Whisper tokenizer from a vocabulary file, a BPE merges file and an
# optional normalizer file. `errors` sets the byte-decoding error policy, the special tokens
# default to "<|endoftext|>", `add_prefix_space` prepends a space to the input, and
# `language`/`task`/`predict_timestamps` configure the generation prefix tokens.
# 如果给定的 bos_token 是字符串,则创建一个特殊的 AddedToken 对象
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(bos_token, str)
else bos_token
)
# 如果给定的 eos_token 是字符串,则创建一个特殊的 AddedToken 对象
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(eos_token, str)
else eos_token
)
# 如果给定的 unk_token 是字符串,则创建一个特殊的 AddedToken 对象
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(unk_token, str)
else unk_token
)
# 如果给定的 pad_token 是字符串,则创建一个特殊的 AddedToken 对象
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, normalized=False, special=True)
if isinstance(pad_token, str)
else pad_token
)
# 使用 UTF-8 编码打开词汇文件,并加载其中的编码器
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
# 根据编码器创建解码器,反转键值对
self.decoder = {v: k for k, v in self.encoder.items()}
# 设置解码时的错误处理方式
self.errors = errors # how to handle errors in decoding
# 创建字节到 Unicode 的编码映射
self.byte_encoder = bytes_to_unicode()
# 创建 Unicode 到字节的解码映射
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
# 使用 UTF-8 编码打开 BPE 合并文件,读取并处理成列表
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split("\n")[1:-1]
# 将 BPE 合并操作转换为元组并创建一个排名字典
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# 初始化缓存字典
self.cache = {}
# 设置是否在词前添加空格的选项
self.add_prefix_space = add_prefix_space
# 如果提供了正规化器文件,则使用 UTF-8 编码打开它并加载英语拼写正规化器
if normalizer_file is not None:
with open(normalizer_file, encoding="utf-8") as vocab_handle:
self.english_spelling_normalizer = json.load(vocab_handle)
else:
# 否则,将英语拼写正规化器设置为 None
self.english_spelling_normalizer = None
# Pre-tokenization pattern (uses `\p{L}`/`\p{N}` classes, which require the `regex` module):
# contractions, letters, numbers, other symbols and runs of whitespace. Note that re.IGNORECASE
# is not applied, so capitalized contractions are not matched case-insensitively.
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
# Initialize the parent tokenizer class (the implementation mirrors GPT2Tokenizer), forwarding the special tokens
self.language = language
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
# 设置语言模型的任务类型和是否预测时间戳的选项
self.task = task
self.predict_timestamps = predict_timestamps
@property
# 返回词汇表的大小
def vocab_size(self) -> int:
return len(self.encoder)
# 返回当前词汇表及额外添加的 token
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
# 从 transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe 复制,修改为使用 Whisper
def bpe(self, token):
# 如果缓存中已经存在该 token 的处理结果,则直接返回缓存中的结果
if token in self.cache:
return self.cache[token]
# 将 token 转换成元组形式
word = tuple(token)
# 获取 token 的所有字符对
pairs = get_pairs(word)
# 如果 token 没有字符对,则直接返回原始 token
if not pairs:
return token
# 开始 BPE 算法的处理过程,直到不能再合并字符对为止
while True:
# Pick the pair with the lowest merge rank, i.e. the highest-priority merge
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
# 如果找到的字符对不在预训练的 BPE 词汇表中,则停止合并
if bigram not in self.bpe_ranks:
break
# 将词汇表中的第一个字符和第二个字符分开
first, second = bigram
new_word = []
i = 0
# 遍历 token 的所有字符
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
# 如果找不到第一个字符,则直接添加剩余的字符
new_word.extend(word[i:])
break
else:
# 将第一个字符前面的字符添加到新的单词中
new_word.extend(word[i:j])
i = j
# 如果当前字符和下一个字符组成一个字符对,则合并
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
# 否则将当前字符添加到新的单词中
new_word.append(word[i])
i += 1
# 更新处理后的 token 为元组形式
new_word = tuple(new_word)
word = new_word
# 如果只剩下一个字符,则停止处理
if len(word) == 1:
break
else:
# 否则继续获取新的字符对
pairs = get_pairs(word)
# 将处理后的 token 转换为字符串形式
word = " ".join(word)
# 将处理结果加入缓存中
self.cache[token] = word
# 返回处理后的结果
return word
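# To make the merge loop above concrete, here is a self-contained toy version with a
# hand-written rank table (illustrative names and ranks, not part of the library):
#
#     toy_ranks = {("h", "e"): 0, ("l", "l"): 1, ("he", "ll"): 2}
#
#     def toy_bpe(token, ranks):
#         word = tuple(token)
#         while len(word) > 1:
#             pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
#             bigram = min(pairs, key=lambda p: ranks.get(p, float("inf")))
#             if bigram not in ranks:
#                 break
#             first, second = bigram
#             new_word, i = [], 0
#             while i < len(word):
#                 if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
#                     new_word.append(first + second)
#                     i += 2
#                 else:
#                     new_word.append(word[i])
#                     i += 1
#             word = tuple(new_word)
#         return " ".join(word)
#
#     print(toy_bpe("hello", toy_ranks))  # "hell o": h+e, l+l, then he+ll merge; no ranked pair remains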
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
"""
Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to
update the prefix tokens as required when fine-tuning. Example:
```
>>> # instantiate the tokenizer and set the prefix token to Spanish
>>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
>>> # now switch the prefix token from Spanish to French
>>> tokenizer.set_prefix_tokens(language="french")
```
Args:
language (`str`, *optional*, defaults to `None`):
The language of the transcription text.
task (`str`, *optional*, defaults to `None`):
Task identifier to append at the start of sequence (if any).
predict_timestamps (`bool`, *optional*, defaults to `None`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""
# 更新语言设置,如果未提供则保持原样
self.language = language if language is not None else self.language
# 更新任务标识,如果未提供则保持原样
self.task = task if task is not None else self.task
# 更新是否预测时间戳的设置,如果未提供则保持原样
self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps
@property
# 返回一个包含特殊前缀 token 的列表,用于初始化模型输入
def prefix_tokens(self) -> List[int]:
# 将特殊开始转录标记转换为其对应的 token ID
bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
# 将特殊翻译标记转换为其对应的 token ID
translate_token_id = self.convert_tokens_to_ids("<|translate|>")
# 将特殊转录标记转换为其对应的 token ID
transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>")
# 将特殊无时间戳标记转换为其对应的 token ID
notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>")
# 取得所有语言代码的元组
langs = tuple(LANGUAGES.keys())
# 如果指定了语言
if self.language is not None:
# 将语言名称转换为小写
self.language = self.language.lower()
# 如果语言在语言到语言代码的映射中
if self.language in TO_LANGUAGE_CODE:
# 获取语言代码
language_id = TO_LANGUAGE_CODE[self.language]
# 如果语言在语言代码列表中
elif self.language in TO_LANGUAGE_CODE.values():
# 直接使用语言代码
language_id = self.language
else:
# 判断语言是否是两位字母代码
is_language_code = len(self.language) == 2
# 抛出不支持的语言异常,提示支持的语言列表
raise ValueError(
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
# 如果指定了任务
if self.task is not None:
# 如果任务不在任务ID列表中,则抛出异常
if self.task not in TASK_IDS:
raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")
# 构建前缀序列,起始于开始 token
bos_sequence = [bos_token_id]
# 如果指定了语言,则将语言相关 token ID 添加到序列中
if self.language is not None:
bos_sequence.append(bos_token_id + 1 + langs.index(language_id))
# 如果指定了任务,则根据任务类型添加对应的 token ID
if self.task is not None:
bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
# 如果不需要预测时间戳,则添加不含时间戳的 token ID
if not self.predict_timestamps:
bos_sequence.append(notimestamps_token_id)
# 返回构建好的前缀序列
return bos_sequence
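# Illustrative check of the property above (assumes the "openai/whisper-tiny" checkpoint):
#
#     tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="french", task="transcribe")
#     tokenizer.convert_ids_to_tokens(tokenizer.prefix_tokens)
#     # ['<|startoftranscript|>', '<|fr|>', '<|transcribe|>', '<|notimestamps|>']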
# Build model inputs from a sequence by adding the prefix tokens and appending the EOS token
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
# Single sequence: prefix tokens + sequence + EOS
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
# Sequence pair: kept for API consistency – prefix tokens + both sequences + EOS
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
# Build the special-tokens mask for a sequence (or pair of sequences)
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
# All prefix tokens are special: mark them with 1
prefix_ones = [1] * len(self.prefix_tokens)
# The trailing EOS token is special as well
suffix_ones = [1]
if token_ids_1 is None:
# Single sequence: prefix 1s, then 0s for the sequence tokens, then the suffix 1
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
# Sequence pair: prefix 1s, 0s for both sequences, then the suffix 1
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
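# Illustrative example: with language/task/notimestamps configured, the prefix holds four special
# tokens, so a three-token label sequence yields
#
#     tokenizer.get_special_tokens_mask([10, 20, 30])
#     # [1, 1, 1, 1, 0, 0, 0, 1]  – prefix tokens, the three sequence tokens, then EOS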
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
# Find all pre-tokenization matches in the text
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
# Apply BPE and extend the list with the resulting space-separated sub-tokens
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
# 从GPT2Tokenizer._convert_token_to_id复制并改名为Whisper的私有方法_convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
# 根据词汇表将token转换为对应的ID,若未找到使用unk_token对应的ID
return self.encoder.get(token, self.encoder.get(self.unk_token))
# 从GPT2Tokenizer._convert_id_to_token复制并改名为Whisper的私有方法_convert_id_to_token
def _convert_id_to_token(self, index):
"""
Converts an index (integer) in a token (str) using the vocab. Whisper's base tokenizer always decodes OOV
tokens as "", thus we do not use the `unk_token` here.
"""
# 根据词汇表将ID转换为对应的token,若未找到使用空字符串表示未知token
return self.decoder.get(index, "")
# Deprecated private method: `_normalize` will be removed in v5 of Transformers; use `normalize` instead
def _normalize(self, text):
warnings.warn(
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper English normalizer using the `normalize` method."
)
# Simply forward to `normalize`
return self.normalize(text)
# Deprecated private method: `_basic_normalize` will be removed in v5 of Transformers; use `basic_normalize` instead
def _basic_normalize(self, text, remove_diacritics=False):
warnings.warn(
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
)
# Simply forward to `basic_normalize`
return self.basic_normalize(text, remove_diacritics=remove_diacritics)
def normalize(self, text):
"""
Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
English text.
"""
# Build an `EnglishTextNormalizer` backed by the optional spelling-normalization mapping
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text)
@staticmethod
def basic_normalize(text, remove_diacritics=False):
"""
Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
multilingual text.
"""
# Build a `BasicTextNormalizer`, optionally stripping diacritics
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
return normalizer(text)
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with the timestamp tokens annotated, e.g. "<|1.08|>".
"""
# First id after the special tokens: the start of the timestamp token range
timestamp_begin = self.all_special_ids[-1] + 1
# Output buffer: alternating lists of text tokens and already-formatted timestamp strings
outputs = [[]]
# Highest timestamp seen in the current segment
cur_max_timestamp = 0.0
# Accumulated duration of the previous segments
prev_segments_len = 0.0
for token in token_ids:
if token >= timestamp_begin:
# Convert the timestamp token into seconds
timestamp = float((token - timestamp_begin) * time_precision)
if timestamp < cur_max_timestamp:
# A new segment has started: add the previous segment's length to the offset
prev_segments_len += cur_max_timestamp
cur_max_timestamp = timestamp
# Emit the formatted timestamp marker and start a new token buffer
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
outputs.append([])
else:
# Regular token: append it to the current segment
outputs[-1].append(token)
# Decode each token buffer and join everything into a single string
outputs = [
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
]
return "".join(outputs)
def _compute_offsets(self, token_ids, time_precision=0.02):
"""
Compute offsets for a given tokenized input
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
"""
offsets = []
# 确保 token_ids 是放置在 CPU 上的 torch 张量
if "torch" in str(type(token_ids)) and (hasattr(token_ids, "cpu") and callable(token_ids.cpu)):
token_ids = token_ids.cpu()
# 将 token_ids 转换为 numpy 数组
token_ids = np.array(token_ids)
# 检查是否只能处理单个输入
if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
raise ValueError("Can only process a single input at a time")
# 确定时间戳开始的位置
timestamp_begin = self.all_special_ids[-1] + 1
# 标记出时间戳所在的位置
timestamp_tokens = token_ids >= timestamp_begin
# 找出连续的时间戳位置
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
# 如果没有连续的时间戳或者只有一个时间戳,则返回空列表
if consecutive.shape[0] == 0 and timestamp_tokens.sum() <= 1:
return []
elif np.where(timestamp_tokens)[0][-1] + 1 not in consecutive:
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
# 初始化最后一个时间戳位置
last_slice = np.where(timestamp_tokens)[0][0]
# 遍历连续的时间戳位置
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
# 如果切片长度大于1,则处理时间戳位置并进行预处理
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
# 从文本输出中去除时间戳标记的 token
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
# 解码处理后的 token
text = self._decode(sliced_tokens)
# 过滤文本中的时间戳标记
text = self._filter_timestamp_ids(text)
# Append the decoded text and its (start, end) timestamps to the offsets list
offsets.append(
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)
last_slice = current_slice
return offsets
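# Illustrative output shape (values made up): decode(..., output_offsets=True) relies on the
# offsets computed above and returns something like
#
#     {"text": " Hello world.", "offsets": [{"text": " Hello world.", "timestamp": (0.0, 1.08)}]}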
def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
"""
Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
removed.
"""
# When skipping special tokens, also strip the prompt ids from the sequence
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
# Drop everything before the decoder start token (the prompt), if present
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
return token_ids
def _filter_timestamp_ids(self, token_ids):
"""
Filter out timestamp ids from token_ids using regex pattern.
Args:
token_ids (`str`): Token ids to filter.
Returns:
`str`: Token ids with timestamps removed.
"""
# Strip timestamp markers such as "<|1.08|>" using the precompiled regex pattern
return re.sub(self.timestamp_pat, "", token_ids)
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
output_offsets: bool = False,
time_precision: float = 0.02,
decode_with_timestamps: bool = False,
normalize: bool = False,
basic_normalize: bool = False,
remove_diacritics: bool = False,
**kwargs,
):
"""
Convert token ids into human-readable text.
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids during decoding.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `None`):
Whether to clean up extra spaces around tokenized output.
output_offsets (`bool`, *optional*, defaults to `False`):
Whether to return the offsets of tokens in the original input.
time_precision (`float`, *optional*, defaults to `0.02`):
Precision of time-related information in seconds.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether to decode timestamps along with token ids.
normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the decoded text.
basic_normalize (`bool`, *optional*, defaults to `False`):
Whether to apply basic normalization to the decoded text.
remove_diacritics (`bool`, *optional*, defaults to `False`):
Whether to remove diacritics from the decoded text.
**kwargs: Additional keyword arguments.
Returns:
`str` or (`str`, `List[Tuple[int, int]]`) or (`str`, `List[Tuple[int, int]]`, `List[str]`): Depending on
the combination of arguments, returns decoded text, offsets of tokens, and possibly normalized forms.
"""
# Converts the token ids into human-readable text, honoring the skip/offset/timestamp/normalization
# options documented above (the full implementation is not reproduced in this excerpt)
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
normalize: bool = False,
basic_normalize: bool = False,
remove_diacritics: bool = False,
**kwargs,
) -> str:
"""
Decode token ids into human-readable text.
Args:
token_ids (`Union[int, List[int]]`):
Tokenized input ids.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to skip decoding special tokens.
normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the decoded text.
basic_normalize (`bool`, *optional*, defaults to `False`):
Whether to apply basic normalization to the decoded text.
remove_diacritics (`bool`, *optional*, defaults to `False`):
Whether to remove diacritics from the decoded text.
**kwargs: Additional keyword arguments.
Returns:
`str`: Decoded text based on token_ids and specified options.
"""
# Decode the ids into text, applying the normalization options above
# Pop "use_source_tokenizer" from kwargs and remember it on the instance
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
# Convert the ids to tokens, optionally skipping special tokens
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# To avoid mixing byte-level and unicode for byte-level BPE, we need to build the string separately
# for added tokens and byte-level tokens, cf. https://github.com/huggingface/transformers/issues/1133
sub_texts = []
current_sub_text = []
for token in filtered_tokens:
# Skip special tokens when requested
if skip_special_tokens and token in self.all_special_ids:
continue
# Added tokens are emitted as-is: flush the current byte-level sub-text first
if token in self.added_tokens_encoder:
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
# Flush any remaining byte-level tokens
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
# Join all sub-texts into the final string
text = "".join(sub_texts)
# Apply the requested normalization, if any, and return
if normalize:
clean_text = self.normalize(text)
return clean_text
elif basic_normalize:
clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics)
return clean_text
else:
return text
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
# Concatenate the tokens into one string
text = "".join(tokens)
# Map the byte-level characters back to raw bytes and decode them as UTF-8
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
# 将词汇表保存到指定目录下的文件中,并返回保存的文件路径
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# 检查保存目录是否存在,如果不存在则记录错误并返回
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
# 构建词汇表文件的路径
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
# 构建合并文件的路径
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
# 构建规范化文件的路径
normalizer_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["normalizer_file"]
)
# 将编码器(encoder)的内容以 JSON 格式写入词汇表文件
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
# 将BPE合并信息写入合并文件
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
# 如果存在英语拼写规范化器,将其内容以 JSON 格式写入规范化文件
if self.english_spelling_normalizer is not None:
with open(normalizer_file, "w", encoding="utf-8") as f:
f.write(
json.dumps(self.english_spelling_normalizer, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
)
# 返回保存的文件路径:词汇表文件、合并文件、规范化文件
return vocab_file, merge_file, normalizer_file
# 从GPT2Tokenizer.prepare_for_tokenization复制,准备文本进行分词处理
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
# 如果文本已经被分成词或需要在文本前添加空格,则在文本前加一个空格
if is_split_into_words or add_prefix_space:
text = " " + text
return (text, kwargs)
@property
# 从GPT2Tokenizer.default_chat_template复制,默认聊天模板
# 返回一个简单的聊天模板,忽略角色信息,仅将消息与EOS标记连接起来。
def default_chat_template(self):
# 警告日志:如果未定义聊天模板,则使用默认模板。
logger.warning_once(
"\nNo chat template is defined for this tokenizer - using the default template "
f"for the {self.__class__.__name__} class. If the default is not appropriate for "
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
# 返回格式化后的聊天模板字符串,包含消息内容和EOS标记。
return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
# Build the forced decoder ids from the configured task/language prefix tokens
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
# Prefix tokens have the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
# We do not want to force the BOS token at position 1, since that is the starting token
# when generating, so the prefix tokens are sliced to: <|lang_id|> <|task|> <|notimestamps|>
forced_tokens = self.prefix_tokens[1:]
# Return (position, token_id) pairs for the forced decoder ids
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
return forced_decoder_ids
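# Illustrative example (multilingual vocabulary, e.g. "openai/whisper-tiny"):
#
#     tokenizer.get_decoder_prompt_ids(task="transcribe", language="french")
#     # [(1, 50265), (2, 50359), (3, 50363)] – (position, id) pairs for <|fr|> <|transcribe|> <|notimestamps|>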
# Delegate to the module-level `_decode_asr` helper (defined below) to decode ASR pipeline outputs
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
return _decode_asr(
self,
model_outputs,
return_timestamps=return_timestamps,
return_language=return_language,
time_precision=time_precision,
)
# Convert prompt text into ids that can be passed to generate(), rejecting special tokens
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
# Check for special tokens smuggled in via the prompt text
prompt_text_ids = batch_encoding["input_ids"][1:]
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
if special_token_id is not None:
token = self.convert_ids_to_tokens(special_token_id)
# Disallowed special tokens in the prompt raise an error
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
# Convert the batch encoding to the requested tensor type
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]
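# Illustrative example: conditioning generation on a prompt. The returned ids start with
# <|startofprev|> and can be passed to WhisperForConditionalGeneration.generate(..., prompt_ids=...).
#
#     prompt_ids = tokenizer.get_prompt_ids("Mr. Quilter", return_tensors="pt")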
# Static helper: strip the prompt (everything before the decoder start token) from a list of ids
@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
# Keep everything from the decoder start token onwards
return token_ids[token_ids.index(decoder_start_token_id):]
else:
# No decoder start token found: nothing usable remains
return []
# No prompt present: return the ids unchanged
return token_ids
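# Toy check of _strip_prompt (illustrative ids: 100 = prompt token, 200 = decoder start token):
print(WhisperTokenizer._strip_prompt([100, 7, 8, 200, 9, 10], prompt_token_id=100, decoder_start_token_id=200))
# -> [200, 9, 10]: everything before the decoder start token is dropped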
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
the various options not allowed in other seq2seq models
"""
# =========== Overview ============
# - iterate over all outputs
# - all tokens within output
# - Each token can be
# - language token
# - special token
# - timestamp token
# - text token
# - We accumulate the text tokens.
# - We split on end timestamps
# - Lots of complexity comes from stride and timestamps
last_language = None
def new_chunk():
return {"language": last_language, "timestamp": [None, None], "text": ""}
# Welcome to the state machine !
chunks = [] # finished chunks (text plus optional timestamps/language)
chunk = new_chunk() # the chunk currently being built
time_offset = 0.0 # running time offset across strided windows
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 # first timestamp token id
previous_tokens = [] # text tokens accumulated from previous outputs
previous_token_timestamps = [] # their token-level timestamps
skip = False # whether the current output should be skipped
right_stride_start = None # start of the right stride region
all_special_ids = set(tokenizer.all_special_ids) # ids of all special tokens
# - iterate over all outputs
if previous_tokens:
if return_timestamps:
logger.warning(
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
"Also make sure WhisperTimeStampLogitsProcessor was used during generation."
)
# Happens when we don't use timestamps
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
previous_tokens, previous_token_timestamps
)
resolved_text = tokenizer.decode(resolved_tokens) # decode the leftover tokens into text
chunk["text"] = resolved_text # store the decoded text on the current chunk
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
) # collate word-level timestamps
chunks.append(chunk) # flush the final chunk
# Preparing and cleaning up the pipeline output
full_text = "".join(chunk["text"] for chunk in chunks) # concatenate all chunk texts
if return_timestamps or return_language:
for chunk in chunks:
if not return_timestamps:
chunk.pop("timestamp") # drop timestamps when they were not requested
else:
chunk["timestamp"] = tuple(chunk["timestamp"]) # expose the timestamp as a (start, end) tuple
if not return_language:
chunk.pop("language") # drop language info when it was not requested
if return_timestamps == "word":
new_chunks = []
for chunk in chunks:
new_chunks.extend(chunk["words"]) # flatten the word-level entries
optional = {"chunks": new_chunks}
else:
optional = {"chunks": chunks}
else:
optional = {} # neither timestamps nor language requested
return full_text, optional # return the full text plus the optional chunk info
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
# It would be much harder to do O(n) because of fault tolerance.
# We actually have a really good property which is that the total sequence
# MUST be those subsequences in order.
# If `token_timestamp_sequences` is provided, those sequences are split in exactly the same way.
# Start from the first sequence
left_sequence = sequences[0]
left_length = len(left_sequence)
# The merged result, built up as we go
total_sequence = []
# Track the matching timestamp sequence, if provided
if token_timestamp_sequences:
left_token_timestamp_sequence = token_timestamp_sequences[0]
total_token_timestamp_sequence = []
# 遍历 sequences 列表中除第一个元素外的所有序列,同时获取它们的索引值 seq_idx 和序列内容 right_sequence
for seq_idx, right_sequence in enumerate(sequences[1:]):
# 初始化 max_ 变量为 0.0,用于存储最大匹配值
max_ = 0.0
# 初始化 max_indices 为元组 (left_length, left_length, 0, 0),记录最大匹配时的索引范围
max_indices = (left_length, left_length, 0, 0)
# 这里我们正在滑动匹配
# [a, b, c, d]
# [c, d, f]
# = [c] == [d]
#
# [a, b, c, d]
# [c, d, f]
# = [c, d] == [c, d]
#
# (省略中间部分)
# 获取 right_sequence 的长度
right_length = len(right_sequence)
# 遍历左侧序列 left_sequence 和右侧序列 right_sequence 的所有可能匹配位置
for i in range(1, left_length + right_length):
# epsilon 用于偏向长的完美匹配
eps = i / 10000.0
# 针对左侧序列和右侧序列进行切片,确保不越界
left_start = max(0, left_length - i)
left_stop = min(left_length, left_length + right_length - i)
left = np.array(left_sequence[left_start:left_stop])
right_start = max(0, i - left_length)
right_stop = min(right_length, i)
right = np.array(right_sequence[right_start:right_stop])
# 只能匹配相同长度的子序列
if len(left) != len(right):
# 如果长度不同,抛出运行时错误
raise RuntimeError(
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
)
# 计算左右序列的匹配度
matches = np.sum(left == right)
matching = matches / i + eps
# 如果匹配数大于 1 并且匹配度大于 max_,更新 max_ 和 max_indices
if matches > 1 and matching > max_:
max_ = matching
max_indices = (left_start, left_stop, right_start, right_stop)
# 将 max_indices 解构为 left_start, left_stop, right_start, right_stop
(left_start, left_stop, right_start, right_stop) = max_indices
# 这是一个小冲突优化,因为这些序列在音频中有重叠
# 对于重叠的左侧,我们会更加信任左侧序列
# 对于重叠的右侧,我们会更加信任右侧序列
left_mid = (left_stop + left_start) // 2
right_mid = (right_stop + right_start) // 2
# 将 left_sequence 的一部分添加到 total_sequence 中,并更新 left_sequence 和 left_length
total_sequence.extend(left_sequence[:left_mid])
left_sequence = right_sequence[right_mid:]
left_length = len(left_sequence)
# 如果 token_timestamp_sequences 存在,则将其对应部分也加入 total_token_timestamp_sequence 中
if token_timestamp_sequences:
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
# 将剩余的 left_sequence 加入 total_sequence 中
total_sequence.extend(left_sequence)
# 如果 token_timestamp_sequences 不存在,则返回 total_sequence
if token_timestamp_sequences is None:
return total_sequence
# 如果 token_timestamp_sequences 列表长度大于 0,则执行以下操作
if len(token_timestamp_sequences) > 0:
# 将 left_token_timestamp_sequence 扩展到 total_token_timestamp_sequence 中
total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
# 返回总序列和合并后的总 token 时间戳序列
return total_sequence, total_token_timestamp_sequence
else:
# 如果 token_timestamp_sequences 列表为空,则返回总序列和空列表作为 token 时间戳序列
return total_sequence, []
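# Toy check of the merge above (illustrative token ids): two overlapping windows are stitched
# together so that the shared region is kept only once.
print(_find_longest_common_sequence([[1, 2, 3, 4], [3, 4, 5, 6]]))
# -> [1, 2, 3, 4, 5, 6]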
# Group tokens into words and attach a (start, end) timestamp to each word
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
# Combine the tokens into words using the helpers below
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
# Build one entry per word: the word text and its start/end timestamps
timings = [
{
"text": word,
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
}
for word, indices in zip(words, token_indices)
]
return timings
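# Illustrative shape of the returned list (values made up):
# [{"text": " Hello", "timestamp": (0.0, 0.52)}, {"text": " world!", "timestamp": (0.52, 1.08)}]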
# 将 tokens 按照空格或标点符号分割成单词,并进行必要的标点符号处理
def _combine_tokens_into_words(
tokenizer,
tokens: List[int],
language: str = None,
prepend_punctuations: str = "\"'“¡¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
):
"""
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id`
sequences with the tokens making up each word.
"""
# 如果未指定 language,则使用 tokenizer 的默认语言
if language is None:
language = tokenizer.language
# 如果 language 仍未指定,设置为英语
if language is None:
language = "english"
# 对于中文、日文、泰文、老挝文、缅甸文和广东话,不使用空格分割
if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
# 调用 _split_tokens_on_unicode,根据 Unicode 分割 tokens
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
else:
# 否则,调用 _split_tokens_on_spaces,按空格分割 tokens
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens)
# 合并前置和后置标点符号到单词列表中
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
return words, word_tokens, token_indices
# 将 tokens 按照 Unicode 码点进行分割成单词
def _split_tokens_on_unicode(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
# 使用 tokenizer 解码 tokens,以获取完整的解码结果
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
replacement_char = "\ufffd"
words = []
word_tokens = []
token_indices = []
current_tokens = []
current_indices = []
unicode_offset = 0
for token_idx, token in enumerate(tokens):
current_tokens.append(token)
current_indices.append(token_idx)
# 使用 tokenizer 解码当前 tokens
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
# 判断是否包含替换字符或者完全匹配替换字符
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
token_indices.append(current_indices)
current_tokens = []
current_indices = []
unicode_offset += len(decoded)
return words, word_tokens, token_indices
# 将 tokens 按照空格或标点符号进行分割成单词
def _split_tokens_on_spaces(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at whitespace and punctuation tokens."""
# 调用 _split_tokens_on_unicode,按 Unicode 码点分割 tokens
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
words = []
word_tokens = []
token_indices = []
# 遍历三个列表:subwords, subword_tokens_list, subword_indices_list,同时迭代获取对应的元素
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list):
# 检查当前子词的第一个标记是否大于或等于tokenizer.eos_token_id,判断是否为特殊标记
special = subword_tokens[0] >= tokenizer.eos_token_id
# 检查当前子词是否以空格开头
with_space = subword.startswith(" ")
# 检查当前子词去除两端空白后是否是标点符号
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# 如果满足特殊标记、以空格开头、是标点符号或者words列表为空,则将当前子词加入words列表,以及相应的标记和索引
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
token_indices.append(subword_indices)
# 否则,将当前子词连接到words列表的最后一个元素上,并扩展相应的标记和索引
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
token_indices[-1].extend(subword_indices)
# 返回处理后的words列表、word_tokens列表和token_indices列表
return words, word_tokens, token_indices
# Merge punctuation marks into the neighbouring words (in place)
def _merge_punctuations(words, tokens, indices, prepended, appended):
# First pass (right to left): attach prepended punctuation to the following word
i = len(words) - 2
j = len(words) - 1
while i >= 0:
# A prepended punctuation mark starts with a space and consists only of prepend characters
if words[i].startswith(" ") and words[i].strip() in prepended:
# Glue the punctuation onto the following word
words[j] = words[i] + words[j]
tokens[j] = tokens[i] + tokens[j]
indices[j] = indices[i] + indices[j]
# Clear the merged-away entry (word, tokens and indices)
words[i] = ""
tokens[i] = []
indices[i] = []
else:
j = i
i -= 1
# Second pass (left to right): attach appended punctuation to the preceding word
i = 0
j = 1
while j < len(words):
# An appended punctuation mark follows a word that does not end with a space
if not words[i].endswith(" ") and words[j] in appended:
# Glue the punctuation onto the preceding word
words[i] += words[j]
tokens[i] += tokens[j]
indices[i] += indices[j]
# Clear the merged-away entry (word, tokens and indices)
words[j] = ""
tokens[j] = []
indices[j] = []
else:
i = j
j += 1
# Remove the entries that are now empty
words[:] = [word for word in words if word]
tokens[:] = [token for token in tokens if token]
indices[:] = [idx for idx in indices if idx]
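# Toy check of _merge_punctuations (illustrative values): the trailing "!" is glued onto " world".
words = [" Hello", " world", "!"]
tokens = [[1], [2], [3]]
indices = [[0], [1], [2]]
_merge_punctuations(words, tokens, indices, "\"'“¡¿([{-", "\"'.。,,!!??::”)]}、")
print(words)   # [' Hello', ' world!']
print(tokens)  # [[1], [2, 3]]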