Lucidrains-系列项目源码解析-二十九-

Lucidrains 系列项目源码解析(二十九)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\attend.py

# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# 定义一个命名元组 Config,用于存储 EfficientAttention 的配置信息
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义一个辅助函数,用于检查变量是否存在
def exists(val):
    return val is not None

# 定义一个装饰器函数,确保被装饰的函数只执行一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用 once 装饰 print 函数,确保只打印一次
print_once = once(print)

# 主要的 Attend 类
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定在 cuda 和 cpu 上的高效注意力配置

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # 获取掩码
    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # Flash Attention 函数
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # 推荐的多查询单键值注意力结构
        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # 检查掩码是否存在并扩展到兼容的形状
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # 检查是否有兼容的设备用于 Flash Attention
        config = self.cuda_config if is_cuda else self.cpu_config

        # 使用 pytorch 2.0 的 Flash Attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # 定义一个前向传播函数,实现注意力机制
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取序列长度和设备信息
        n, device = q.shape[-2], q.device

        # 计算缩放因子
        scale = q.shape[-1] ** -0.5

        # 如果使用闪回注意力机制,则调用相应函数
        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # 根据输入维度确定键值对的 einsum 方程
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # 计算相似度
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # 处理键的填充掩码
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 处理因果掩码
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 计算注意力权重
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 聚合数值
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\naturalspeech2_pytorch.py

# 导入所需的库
import math
import copy
from multiprocessing import cpu_count
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple

import numpy as np

import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import torchaudio
import torchaudio.transforms as T

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from audiolm_pytorch import SoundStream, EncodecWrapper
from audiolm_pytorch.data import SoundDataset, get_dataloader

from beartype import beartype
from beartype.typing import Tuple, Union, Optional, List
from beartype.door import is_bearable

from naturalspeech2_pytorch.attend import Attend
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss
from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
from naturalspeech2_pytorch.version import __version__

from accelerate import Accelerator
from ema_pytorch import EMA

from tqdm.auto import tqdm
import pyworld as pw

# 定义常量

mlist = nn.ModuleList

def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# 辅助函数

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def divisible_by(num, den):
    return (num % den) == 0

def identity(t, *args, **kwargs):
    return t

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

# 张量辅助函数

def pad_or_curtail_to_length(t, length):
    if t.shape[-1] == length:
        return t

    if t.shape[-1] > length:
        return t[..., :length]

    return F.pad(t, (0, length - t.shape[-1]))

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

def generate_mask_from_repeats(repeats):
    repeats = repeats.int()
    device = repeats.device

    lengths = repeats.sum(dim=-1)
    max_length = lengths.amax().item()
    cumsum = repeats.cumsum(dim=-1)
    cumsum_exclusive = F.pad(cumsum, (1, -1), value=0.)

    seq = torch.arange(max_length, device=device)
    seq = repeat(seq, '... j -> ... i j', i=repeats.shape[-1])

    cumsum = rearrange(cumsum, '... i -> ... i 1')
    cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')

    lengths = rearrange(lengths, 'b -> b 1 1')
    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
    return mask

# 正弦位置嵌入

class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered

# 计算音高

def compute_pitch_pytorch(wav, sample_rate):
    # 使用 torchaudio 库中的 compute_kaldi_pitch 函数计算音高特征
    pitch_feature = torchaudio.functional.compute_kaldi_pitch(wav, sample_rate)
    pitch, nfcc = pitch_feature.unbind(dim=-1)
    return pitch

# 根据论文使用 pyworld 计算音高

def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax=640.0):
    is_tensor_input = torch.is_tensor(wav)

    if is_tensor_input:
        device = wav.device
        wav = wav.contiguous().cpu().numpy()
    # 如果音频长度可以被 hop_length 整除,则在末尾填充一半的 hop_length 长度,使用反射模式填充
    if divisible_by(len(wav), hop_length):
        wav = np.pad(wav, (0, hop_length // 2), mode="reflect")

    # 将音频数据类型转换为双精度浮点型
    wav = wav.astype(np.double)

    # 初始化一个空列表用于存储音频样本的基频值
    outs = []

    # 遍历音频样本,提取基频值
    for sample in wav:
        # 使用 dio 函数提取音频样本的基频值和时间信息
        f0, t = pw.dio(
            sample,
            fs = sample_rate,
            f0_ceil = pitch_fmax,
            frame_period = 1000 * hop_length / sample_rate,
        )

        # 使用 stonemask 函数对基频值进行修正
        f0 = pw.stonemask(sample, f0, t, sample_rate)
        # 将修正后的基频值添加到 outs 列表中
        outs.append(f0)

    # 将 outs 列表转换为 numpy 数组
    outs = np.stack(outs)

    # 如果输入是张量形式,则将 outs 转换为张量并移动到指定设备上
    if is_tensor_input:
        outs = torch.from_numpy(outs).to(device)

    # 返回提取的基频值
    return outs
def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):
    # 计算最大和最小频率对应的梅尔频率
    f0_mel_max = 1127 * torch.log(1 + torch.tensor(f0_max) / 700)
    f0_mel_min = 1127 * torch.log(1 + torch.tensor(f0_min) / 700)

    # 计算输入频率对应的梅尔频率
    f0_mel = 1127 * (1 + f0 / 700).log()
    # 对梅尔频率进行线性变换,映射到[1, f0_bin-1]的范围
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1

    # 将小于等于1的值设置为1
    f0_mel[f0_mel <= 1] = 1
    # 将大于f0_bin-1的值设置为f0_bin-1
    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
    # 对梅尔频率四舍五入取整
    f0_coarse = (f0_mel + 0.5).int()
    # 断言确保f0_coarse的取值范围在[1, 255]之间
    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
    return f0_coarse

# peripheral models

# audio to mel

class AudioToMel(nn.Module):
    def __init__(
        self,
        *,
        n_mels = 100,
        sampling_rate = 24000,
        f_max = 8000,
        n_fft = 1024,
        win_length = 640,
        hop_length = 160,
        log = True
    ):
        super().__init__()
        self.log = log
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.f_max = f_max
        self.win_length = win_length
        self.hop_length = hop_length
        self.sampling_rate = sampling_rate

    def forward(self, audio):
        # 创建STFT变换对象
        stft_transform = T.Spectrogram(
            n_fft = self.n_fft,
            win_length = self.win_length,
            hop_length = self.hop_length,
            window_fn = torch.hann_window
        )

        # 对音频进行STFT变换得到频谱图
        spectrogram = stft_transform(audio)

        # 创建梅尔频率变换对象
        mel_transform = T.MelScale(
            n_mels = self.n_mels,
            sample_rate = self.sampling_rate,
            n_stft = self.n_fft // 2 + 1,
            f_max = self.f_max
        )

        # 对频谱图进行梅尔频率变换得到梅尔频谱图
        mel = mel_transform(spectrogram)

        # 如果log为True,则将梅尔频谱图转换为对数幅度
        if self.log:
            mel = T.AmplitudeToDB()(mel)

        return mel

# phoneme - pitch - speech prompt - duration predictors

class PhonemeEncoder(nn.Module):
    def __init__(
        self,
        *,
        tokenizer: Optional[Tokenizer] = None,
        num_tokens = None,
        dim = 512,
        dim_hidden = 512,
        kernel_size = 9,
        depth = 6,
        dim_head = 64,
        heads = 8,
        conv_dropout = 0.2,
        attn_dropout = 0.,
        use_flash = False
    ):
        super().__init__()

        # 初始化模型参数
        self.tokenizer = tokenizer
        num_tokens = default(num_tokens, tokenizer.vocab_size if exists(tokenizer) else None)

        self.token_emb = nn.Embedding(num_tokens + 1, dim) if exists(num_tokens) else nn.Identity()
        self.pad_id = num_tokens

        same_padding = (kernel_size - 1) // 2

        # 定义卷积层和变换层
        self.conv = nn.Sequential(
            Rearrange('b n c -> b c n'),
            CausalConv1d(dim, dim_hidden, kernel_size),
            nn.SiLU(),
            nn.Dropout(conv_dropout),
            Rearrange('b c n -> b n c'),
        )

        self.transformer = Transformer(
            dim = dim_hidden,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            dropout = attn_dropout,
            use_flash = use_flash
        )

    @beartype
    def forward(
        self,
        x: Union[Tensor, List[str]],
        mask = None
    ):
        # 如果输入为字符串列表,则将其转换为张量
        if is_bearable(x, List[str]):
            assert exists(self.tokenizer)
            x = self.tokenizer.texts_to_tensor_ids(x)

        # 将小于0的值设置为pad_id
        is_padding = x < 0
        x = x.masked_fill(is_padding, self.pad_id)

        x = self.token_emb(x)
        x = self.conv(x)
        x = self.transformer(x, mask = mask)
        return x

class SpeechPromptEncoder(nn.Module):

    @beartype
    def __init__(
        self,
        dim_codebook,
        dims: Tuple[int] = (256, 2048, 2048, 2048, 2048, 512, 512, 512),
        *,
        depth = 6,
        heads = 8,
        dim_head = 64,
        dropout = 0.2,
        kernel_size = 9,
        padding = 4,
        use_flash_attn = True
    # 定义一个继承自 nn.Module 的类,用于实现一个包含卷积和Transformer的模型
    ):
        # 调用父类的构造函数
        super().__init__()

        # 将dim_codebook添加到dims列表的开头
        dims = [dim_codebook, *dims]

        # 设置self.dim为dims列表的第一个元素,设置self.dim_out为dims列表的最后一个元素
        self.dim, self.dim_out = dims[0], dims[-1]

        # 将dims列表中相邻的两个元素组成一对,形成一个维度对的列表
        dim_pairs = zip(dims[:-1], dims[1:])

        # 初始化一个空的模块列表
        modules = []
        # 遍历维度对列表,为每一对维度创建一个卷积层和SiLU激活函数,并添加到模块列表中
        for dim_in, dim_out in dim_pairs:
            modules.extend([
                nn.Conv1d(dim_in, dim_out, kernel_size, padding = padding),
                nn.SiLU()
            ])

        # 构建一个包含卷积层和SiLU激活函数的序列模块
        self.conv = nn.Sequential(
            Rearrange('b n c -> b c n'),
            *modules,
            Rearrange('b c n -> b n c')
        )

        # 初始化一个Transformer模块
        self.transformer = Transformer(
            dim = dims[-1],
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
            use_flash = use_flash_attn
        )

    # 定义前向传播函数
    def forward(self, x):
        # 断言输入张量x的最后一个维度与self.dim相等
        assert x.shape[-1] == self.dim

        # 将输入张量通过卷积层和Transformer模块进行前向传播
        x = self.conv(x)
        x = self.transformer(x)
        return x
# 定义一个名为 Block 的类,继承自 nn.Module
class Block(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out、卷积核大小 kernel、分组数 groups 和 dropout 概率
    def __init__(
        self,
        dim,
        dim_out,
        kernel = 3,
        groups = 8,
        dropout = 0.
    ):
        super().__init__()
        # 创建一个卷积层,将输入维度映射到输出维度
        self.proj = nn.Conv1d(dim, dim_out, kernel, padding = kernel // 2)
        # 对输出进行分组归一化
        self.norm = nn.GroupNorm(groups, dim_out)
        # 使用 SiLU 激活函数
        self.act = nn.SiLU()
        # 使用 dropout 进行正则化
        self.dropout = nn.Dropout(dropout)

    # 前向传播函数
    def forward(self, x):
        # 对输入进行卷积操作
        x = self.proj(x)
        # 对卷积结果进行分组归一化
        x = self.norm(x)
        # 使用激活函数
        x = self.act(x)
        # 使用 dropout
        x = self.dropout(x)
        return x

# 定义一个名为 ResnetBlock 的类,继承自 nn.Module
class ResnetBlock(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out、卷积核大小 kernel、dropout 概率、分组数 groups 和卷积层数 num_convs
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        dropout = 0.,
        groups = 8,
        num_convs = 2
    ):
        super().__init__()

        blocks = []
        # 循环创建 num_convs 个 Block 实例
        for ind in range(num_convs):
            is_first = ind == 0
            dim_in = dim if is_first else dim_out
            # 创建一个 Block 实例
            block = Block(
                dim_in,
                dim_out,
                kernel,
                groups = groups,
                dropout = dropout
            )
            blocks.append(block)

        # 将所有 Block 实例组合成一个序列
        self.blocks = nn.Sequential(*blocks)

        # 如果输入维度和输出维度不相等,使用 1x1 卷积进行维度匹配,否则使用恒等映射
        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # 前向传播函数
    def forward(self, x):
        # 将输入维度重新排列
        x = rearrange(x, 'b n c -> b c n')
        # 对输入进行 Block 序列操作
        h = self.blocks(x)
        # 将 Block 序列的输出与输入进行残差连接
        out = h + self.res_conv(x)
        # 将输出维度重新排列
        return rearrange(out, 'b c n -> b n c')

# 定义一个函数 ConvBlock,接受输入维度 dim、输出维度 dim_out、卷积核大小 kernel 和 dropout 概率
def ConvBlock(dim, dim_out, kernel, dropout = 0.):
    # 返回一个包含卷积、激活函数、dropout 的序列
    return nn.Sequential(
        Rearrange('b n c -> b c n'),
        nn.Conv1d(dim, dim_out, kernel, padding = kernel // 2),
        nn.SiLU(),
        nn.Dropout(dropout),
        Rearrange('b c n -> b n c'),
    )

# 定义一个名为 DurationPitchPredictorTrunk 的类,继承自 nn.Module
class DurationPitchPredictorTrunk(nn.Module):
    # 初始化函数,接受输入维度 dim、深度 depth、卷积核大小 kernel_size、上下文维度 dim_context、头数 heads、头维度 dim_head、dropout 概率、是否使用 ResNet 块 use_resnet_block、每个 ResNet 块的卷积层数 num_convs_per_resnet_block、每个块的卷积层数 num_convolutions_per_block、是否使用 Flash 注意力 use_flash_attn
    def __init__(
        self,
        dim = 512,
        depth = 10,
        kernel_size = 3,
        dim_context = None,
        heads = 8,
        dim_head = 64,
        dropout = 0.2,
        use_resnet_block = True,
        num_convs_per_resnet_block = 2,
        num_convolutions_per_block = 3,
        use_flash_attn = False,
    ):
        super().__init__()
        # 初始化一个空的模块列表
        self.layers = nn.ModuleList([])

        # 根据是否使用 ResNet 块选择卷积类
        conv_klass = ConvBlock if not use_resnet_block else partial(ResnetBlock, num_convs = num_convs_per_resnet_block)

        # 循环创建 depth 个层
        for _ in range(depth):
            # 每个层包含一个卷积序列、RMSNorm 归一化和注意力机制
            layer = nn.ModuleList([
                nn.Sequential(*[
                    conv_klass(dim, dim, kernel_size) for _ in range(num_convolutions_per_block)
                ]),
                RMSNorm(dim),
                Attention(
                    dim,
                    dim_context = dim_context,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = dropout,
                    use_flash = use_flash_attn,
                    cross_attn_include_queries = True
                )
            ])

            self.layers.append(layer)

        # 最后的预测层,包含线性层、维度重排和 ReLU 激活函数
        self.to_pred = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...'),
            nn.ReLU()
        )
    
    # 前向传播函数,接受输入 x、编码的提示信息 encoded_prompts 和提示信息的掩码 prompt_mask
    def forward(
        self,
        x,
        encoded_prompts,
        prompt_mask = None,
    ):
        # 对每个层进行操作
        for conv, norm, attn in self.layers:
            x = conv(x)
            x = attn(norm(x), encoded_prompts, mask = prompt_mask) + x

        return self.to_pred(x)

# 定义一个名为 DurationPitchPredictor 的类,继承自 nn.Module
class DurationPitchPredictor(nn.Module):
    # 初始化函数,接受维度 dim、音素标记数 num_phoneme_tokens、分词器 tokenizer、编码提示信息的维度 dim_encoded_prompts、每个块的卷积层数 num_convolutions_per_block、是否使用 ResNet 块 use_resnet_block、每个 ResNet 块的卷积层数 num_convs_per_resnet_block、深度 depth、卷积核大小 kernel_size、头数 heads、头维度 dim_head、隐藏层维度 dim_hidden、dropout 概率、是否使用 Flash 注意力 use_flash_attn
    def __init__(
        self,
        *,
        dim,
        num_phoneme_tokens = None,
        tokenizer: Optional[Tokenizer] = None,
        dim_encoded_prompts = None,
        num_convolutions_per_block = 3,
        use_resnet_block = True,
        num_convs_per_resnet_block = 2,
        depth = 10,
        kernel_size = 3,
        heads = 8,
        dim_head = 64,
        dim_hidden = 512,
        dropout = 0.2,
        use_flash_attn = False
    ):
        super().__init__()
        # 略
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化 tokenizer 属性
        self.tokenizer = tokenizer
        # 如果存在 tokenizer,则将 num_phoneme_tokens 设置为 tokenizer 的词汇表大小,否则为 None
        num_phoneme_tokens = default(num_phoneme_tokens, tokenizer.vocab_size if exists(tokenizer) else None)

        # 如果未提供 dim_encoded_prompts,则将其设置为 dim
        dim_encoded_prompts = default(dim_encoded_prompts, dim)

        # 如果存在 num_phoneme_tokens,则创建一个 num_phoneme_tokens x dim 的嵌入层,否则创建一个恒等映射
        self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim) if exists(num_phoneme_tokens) else nn.Identity()

        # 初始化 to_pitch_pred 属性为 DurationPitchPredictorTrunk 类的实例
        self.to_pitch_pred = DurationPitchPredictorTrunk(
            dim = dim_hidden,
            depth = depth,
            kernel_size = kernel_size,
            dim_context = dim_encoded_prompts,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
            use_resnet_block = use_resnet_block,
            num_convs_per_resnet_block = num_convs_per_resnet_block,
            num_convolutions_per_block = num_convolutions_per_block,
            use_flash_attn = use_flash_attn,
        )

        # 使用深拷贝创建 to_duration_pred 属性
        self.to_duration_pred = copy.deepcopy(self.to_pitch_pred)

    # 定义 forward 方法
    @beartype
    def forward(
        self,
        x: Union[Tensor, List[str]],
        encoded_prompts,
        prompt_mask = None
    ):
        # 如果 x 是 List[str] 类型,则将其转换为张量
        if is_bearable(x, List[str]):
            assert exists(self.tokenizer)
            x = self.tokenizer.texts_to_tensor_ids(x)

        # 对输入 x 进行嵌入
        x = self.phoneme_token_emb(x)

        # 使用 map 函数对 to_duration_pred 和 to_pitch_pred 进行计算
        duration_pred, pitch_pred = map(lambda fn: fn(x, encoded_prompts = encoded_prompts, prompt_mask = prompt_mask), (self.to_duration_pred, self.to_pitch_pred))

        # 返回持续时间预测和音高预测结果
        return duration_pred, pitch_pred
# 使用来自 flamingo 论文的 Perceiver Resampler,替代 "q-k-v" 注意力机制,其中 m 个查询成为网络条件的关键/值

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_context = None,
        num_latents = 64, # 论文中的 m
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        use_flash_attn = False
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std = 0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    use_flash = use_flash_attn,
                    cross_attn_include_queries = True
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = RMSNorm(dim)

    def forward(self, x, mask = None):
        batch = x.shape[0]

        x = self.proj_context(x)

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        for attn, ff in self.layers:
            latents = attn(latents, x, mask = mask) + latents
            latents = ff(latents) + latents

        return self.norm(latents)

# 模型,即 Wavenet + Transformer

class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_size, = self.kernel_size
        dilation, = self.dilation
        stride, = self.stride

        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        causal_padded_x = F.pad(x, (self.causal_padding, 0), value = 0.)
        return super().forward(causal_padded_x)

class WavenetResBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dilation,
        kernel_size = 3,
        skip_conv = False,
        dim_cond_mult = None
    ):
        super().__init__()

        self.cond = exists(dim_cond_mult)
        self.to_time_cond = None

        if self.cond:
            self.to_time_cond = nn.Linear(dim * dim_cond_mult, dim * 2)

        self.conv = CausalConv1d(dim, dim, kernel_size, dilation = dilation)
        self.res_conv = CausalConv1d(dim, dim, 1)
        self.skip_conv = CausalConv1d(dim, dim, 1) if skip_conv else None

    def forward(self, x, t = None):

        if self.cond:
            assert exists(t)
            t = self.to_time_cond(t)
            t = rearrange(t, 'b c -> b c 1')
            t_gamma, t_beta = t.chunk(2, dim = -2)

        res = self.res_conv(x)

        x = self.conv(x)

        if self.cond:
            x = x * t_gamma + t_beta

        x = x.tanh() * x.sigmoid()

        x = x + res

        skip = None
        if exists(self.skip_conv):
            skip = self.skip_conv(x)

        return x, skip


class WavenetStack(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers,
        kernel_size = 3,
        has_skip = False,
        dim_cond_mult = None
    ):
        super().__init__()
        dilations = 2 ** torch.arange(layers)

        self.has_skip = has_skip
        self.blocks = mlist([])

        for dilation in dilations.tolist():
            block = WavenetResBlock(
                dim = dim,
                kernel_size = kernel_size,
                dilation = dilation,
                skip_conv = has_skip,
                dim_cond_mult = dim_cond_mult
            )

            self.blocks.append(block)
    # 定义前向传播函数,接受输入 x 和时间 t
    def forward(self, x, t):
        # 初始化残差和跳跃连接列表
        residuals = []
        skips = []

        # 如果输入 x 是张量类型,则将其重复多次,以匹配网络块的数量
        if isinstance(x, Tensor):
            x = (x,) * len(self.blocks)

        # 遍历输入 x 和网络块,计算残差和跳跃连接
        for block_input, block in zip(x, self.blocks):
            residual, skip = block(block_input, t)

            # 将计算得到的残差和跳跃连接添加到对应的列表中
            residuals.append(residual)
            skips.append(skip)

        # 如果存在跳跃连接,则返回所有跳跃连接的张量堆叠
        if self.has_skip:
            return torch.stack(skips)

        # 否则返回所有残差的列表
        return residuals
class Wavenet(nn.Module):
    def __init__(
        self,
        dim,
        *,
        stacks,
        layers,
        init_conv_kernel = 3,
        dim_cond_mult = None
    ):
        # 初始化 Wavenet 类
        super().__init__()
        # 创建初始卷积层对象
        self.init_conv = CausalConv1d(dim, dim, init_conv_kernel)
        # 初始化堆栈列表
        self.stacks = mlist([])

        # 循环创建堆栈
        for ind in range(stacks):
            is_last = ind == (stacks - 1)

            # 创建 WavenetStack 对象
            stack = WavenetStack(
                dim,
                layers = layers,
                dim_cond_mult = dim_cond_mult,
                has_skip = is_last
            )

            # 将堆栈对象添加到堆栈列表中
            self.stacks.append(stack)

        # 创建最终卷积层对象
        self.final_conv = CausalConv1d(dim, dim, 1)

    def forward(self, x, t = None):
        # 对输入数据进行初始卷积
        x = self.init_conv(x)

        # 遍历堆栈列表,对数据进行处理
        for stack in self.stacks:
            x = stack(x, t)

        # 对处理后的数据进行最终卷积并返回结果
        return self.final_conv(x.sum(dim = 0))

class RMSNorm(nn.Module):
    def __init__(self, dim, scale = True, dim_cond = None):
        # 初始化 RMSNorm 类
        super().__init__()
        # 检查是否有条件输入
        self.cond = exists(dim_cond)
        # 根据条件初始化线性层
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None

        # 初始化缩放参数和 gamma 参数
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond = None):
        # 获取 gamma 参数
        gamma = default(self.gamma, 1)
        # 对输入数据进行归一化处理
        out = F.normalize(x, dim = -1) * self.scale * gamma

        # 如果没有条件输入,则直接返回处理后的数据
        if not self.cond:
            return out

        # 如果有条件输入,则根据条件计算 gamma 和 beta,并进行处理
        assert exists(cond)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim = -1)
        gamma, beta = map(lambda t: rearrange(t, 'b d -> b 1 d'), (gamma, beta))
        return out * gamma + beta

class ConditionableTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        ff_causal_conv = False,
        dim_cond_mult = None,
        cross_attn = False,
        use_flash = False
    ):
        # 初始化 ConditionableTransformer 类
        super().__init__()
        # 设置维度和层列表
        self.dim = dim
        self.layers = mlist([])

        # 检查是否有条件输入
        cond = exists(dim_cond_mult)

        # 根据条件初始化 RMSNorm 层
        maybe_adaptive_norm_kwargs = dict(scale = not cond, dim_cond = dim * dim_cond_mult) if cond else dict()
        rmsnorm = partial(RMSNorm, **maybe_adaptive_norm_kwargs)

        # 循环创建层
        for _ in range(depth):
            self.layers.append(mlist([
                rmsnorm(dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, use_flash = use_flash),
                rmsnorm(dim) if cross_attn else None,
                Attention(dim = dim, dim_head = dim_head, heads = heads, use_flash = use_flash) if cross_attn else None,
                rmsnorm(dim),
                FeedForward(dim = dim, mult = ff_mult, causal_conv = ff_causal_conv)
            ]))

        # 创建预测层
        self.to_pred = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim, bias = False)
        )

    def forward(
        self,
        x,
        times = None,
        context = None
    ):
        t = times

        # 遍历层列表,对输入数据进行处理
        for attn_norm, attn, cross_attn_norm, cross_attn, ff_norm, ff in self.layers:
            res = x
            x = attn_norm(x, cond = t)
            x = attn(x) + res

            # 如果有交叉注意力,则进行处理
            if exists(cross_attn):
                assert exists(context)
                res = x
                x = cross_attn_norm(x, cond = t)
                x = cross_attn(x, context = context) + res

            res = x
            x = ff_norm(x, cond = t)
            x = ff(x) + res

        # 返回预测结果
        return self.to_pred(x)

class Model(nn.Module):

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        wavenet_layers = 8,
        wavenet_stacks = 4,
        dim_cond_mult = 4,
        use_flash_attn = True,
        dim_prompt = None,
        num_latents_m = 32,   # number of latents to be perceiver resampled ('q-k-v' with 'm' queries in the paper)
        resampler_depth = 2,
        cond_drop_prob = 0.,
        condition_on_prompt= False
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化模型的维度
        self.dim = dim

        # 时间条件

        # 根据维度计算时间条件的维度
        dim_time = dim * dim_cond_mult

        # 创建时间条件的网络层
        self.to_time_cond = Sequential(
            LearnedSinusoidalPosEmb(dim),  # 学习的正弦位置编码
            nn.Linear(dim + 1, dim_time),   # 线性层,将输入维度转换为时间条件的维度
            nn.SiLU()                       # SiLU激活函数
        )

        # 提示条件

        self.cond_drop_prob = cond_drop_prob  # 用于分类器无指导的概率
        self.condition_on_prompt = condition_on_prompt
        self.to_prompt_cond = None

        if self.condition_on_prompt:
            self.null_prompt_cond = nn.Parameter(torch.randn(dim_time))  # 随机初始化空提示条件
            self.null_prompt_tokens = nn.Parameter(torch.randn(num_latents_m, dim))  # 随机初始化空提示标记

            nn.init.normal_(self.null_prompt_cond, std = 0.02)  # 使用正态分布初始化空提示条件
            nn.init.normal_(self.null_prompt_tokens, std = 0.02)  # 使用正态分布初始化空提示标记

            # 创建提示条件的网络层
            self.to_prompt_cond = Sequential(
                Reduce('b n d -> b d', 'mean'),  # 减少维度
                nn.Linear(dim_prompt, dim_time),  # 线性层,将输入维度转换为提示条件的维度
                nn.SiLU()  # SiLU激活函数
            )

            # 创建PerceiverResampler对象
            self.perceiver_resampler = PerceiverResampler(
                dim = dim,
                dim_context = dim_prompt,
                num_latents = num_latents_m,
                depth = resampler_depth,
                dim_head = dim_head,
                heads = heads,
                use_flash_attn = use_flash_attn
            )

        # 从对齐器和持续时间模块获取对齐的条件

        self.null_cond = None
        self.cond_to_model_dim = None

        if self.condition_on_prompt:
            self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)  # 一维卷积层,将提示条件转换为模型维度
            self.null_cond = nn.Parameter(torch.zeros(dim, 1))  # 初始化空条件

        # 条件包括时间和可选的提示

        dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)  # 更新条件的维度乘数

        # WaveNet

        # 创建WaveNet模型
        self.wavenet = Wavenet(
            dim = dim,
            stacks = wavenet_stacks,
            layers = wavenet_layers,
            dim_cond_mult = dim_cond_mult
        )

        # Transformer

        # 创建ConditionableTransformer模型
        self.transformer = ConditionableTransformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_causal_conv = True,
            dim_cond_mult = dim_cond_mult,
            use_flash = use_flash_attn,
            cross_attn = condition_on_prompt
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # 前向传播函数,带有条件缩放
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1.:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)

        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        times,
        prompt = None,
        prompt_mask = None,
        cond = None,
        cond_drop_prob = None
        ):
        # 获取输入张量 x 的 batch 大小
        b = x.shape[0]
        # 如果未指定条件丢弃概率,则使用默认值
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # 准备时间条件
        # 概率应该在向前移除

        # 将时间转换为条件
        t = self.to_time_cond(times)
        c = None

        # 如果存在 prompt 条件
        if exists(self.to_prompt_cond):
            assert exists(prompt)

            # 创建与 prompt 条件大小相同的概率掩码
            prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

            # 将 prompt 转换为条件
            prompt_cond = self.to_prompt_cond(prompt)

            # 根据概率掩码更新 prompt 条件
            prompt_cond = torch.where(
                rearrange(prompt_cond_drop_mask, 'b -> b 1'),
                self.null_prompt_cond,
                prompt_cond,
            )

            # 将时间条件和 prompt 条件连接起来
            t = torch.cat((t, prompt_cond), dim = -1)

            # 对 prompt 进行重采样
            resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)

            # 根据概率掩码更新 prompt tokens
            c = torch.where(
                rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
                self.null_prompt_tokens,
                resampled_prompt_tokens
            )

        # 重新排列为通道优先格式
        x = rearrange(x, 'b n d -> b d n')

        # 将对齐的条件加到输入序列中
        if exists(self.cond_to_model_dim):
            assert exists(cond)
            # 将条件转换为模型维度
            cond = self.cond_to_model_dim(cond)

            # 创建与条件大小相同的概率掩码
            cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

            # 根据概率掩码更新条件
            cond = torch.where(
                rearrange(cond_drop_mask, 'b -> b 1 1'),
                self.null_cond,
                cond
            )

            # 目前,将条件调整为潜在特征的长度
            cond = pad_or_curtail_to_length(cond, x.shape[-1])

            # 将条件加到输入张量中
            x = x + cond

        # 主要的 WaveNet 模块
        x = self.wavenet(x, t)
        x = rearrange(x, 'b d n -> b n d')

        # 使用 Transformer 模块
        x = self.transformer(x, t, context = c)
        return x
# feedforward

# GEGLU 激活函数类,用于前向传播
class GEGLU(nn.Module):
    # 前向传播函数
    def forward(self, x):
        # 将输入张量 x 按照最后一个维度分成两部分
        x, gate = x.chunk(2, dim = -1)
        # 返回 GEGLU 激活函数的结果
        return F.gelu(gate) * x

# 创建前馈神经网络层
def FeedForward(dim, mult = 4, causal_conv = False):
    # 计算内部维度
    dim_inner = int(dim * mult * 2 / 3)

    conv = None
    # 如果是因果卷积
    if causal_conv:
        # 创建因果卷积层
        conv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange('b d n -> b n d'),
        )

    return Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        conv,
        nn.Linear(dim_inner, dim)
    )

# attention

# 注意力机制类
class Attention(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        dim_context = None,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_flash = False,
        cross_attn_include_queries = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    # 前向传播函数
    def forward(self, x, context = None, mask = None):
        h, has_context = self.heads, exists(context)

        context = default(context, x)

        if has_context and self.cross_attn_include_queries:
            context = torch.cat((x, context), dim = -2)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder

# Transformer 编码器类
class Transformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        depth,
        causal = False,
        dim_head = 64,
        heads = 8,
        use_flash = False,
        dropout = 0.,
        ff_mult = 4,
        final_norm = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # 创建多层 Transformer 编码器
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                RMSNorm(dim),
                Attention(
                    dim,
                    causal = causal,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = dropout,
                    use_flash = use_flash
                ),
                RMSNorm(dim),
                FeedForward(
                    dim,
                    mult = ff_mult
                )
            ]))

        self.norm = RMSNorm(dim) if final_norm else nn.Identity()

    # 前向传播函数
    def forward(self, x, mask = None):
        for attn_norm, attn, ff_norm, ff in self.layers:
            x = attn(attn_norm(x), mask = mask) + x
            x = ff(ff_norm(x)) + x

        return self.norm(x)

# tensor helper functions

# 对数函数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# ���全除法函数
def safe_div(numer, denom):
    return numer / denom.clamp(min = 1e-10)

# 将 x 张量的维度右侧填充到与 t 张量相同维度
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# noise schedules

# 简单线性调度函数
def simple_linear_schedule(t, clip_min = 1e-9):
    return (1 - t).clamp(min = clip_min)

# 余弦调度函数
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    power = 2 * tau
    v_start = math.cos(start * math.pi / 2) ** power
    v_end = math.cos(end * math.pi / 2) ** power
    output = math.cos((t * (end - start) + start) * math.pi / 2) ** power
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(min = clip_min)

# sigmoid 调度函数
def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    # 根据起始时间和结束时间计算对应的 sigmoid 值
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    # 计算 gamma 值,用于调整时间范围
    gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    # 对 gamma 进行范围限制,确保在指定范围内
    return gamma.clamp_(min=clamp_min, max=1.)
# 将 gamma 转换为 alpha、sigma 或 logsnr

def gamma_to_alpha_sigma(gamma, scale = 1):
    # 计算 alpha 和 sigma,并乘以指定的比例
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

def gamma_to_log_snr(gamma, scale = 1, eps = 1e-5):
    # 计算 logsnr,根据给定的 gamma、比例和 eps
    return log(gamma * (scale ** 2) / (1 - gamma), eps = eps)

# 高斯扩散

class NaturalSpeech2(nn.Module):

    @beartype
    def __init__(
        self,
        model: Model,
        codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
        *,
        
        tokenizer: Optional[Tokenizer] = None,
        target_sample_hz = None,
        timesteps = 1000,
        use_ddim = True,
        noise_schedule = 'sigmoid',
        objective = 'v',
        schedule_kwargs: dict = dict(),
        time_difference = 0.,
        min_snr_loss_weight = True,
        min_snr_gamma = 5,
        train_prob_self_cond = 0.9,
        rvq_cross_entropy_loss_weight = 0., # 默认关闭,直到确定其是否有效。不确定这是否至关重要
        dim_codebook: int = 128,
        duration_pitch_dim: int = 512,
        aligner_dim_in: int = 80,
        aligner_dim_hidden: int = 512,
        aligner_attn_channels: int = 80,
        num_phoneme_tokens: int = 150,
        pitch_emb_dim: int = 256,
        pitch_emb_pp_hidden_dim: int= 512,
        calc_pitch_with_pyworld = True,     # 使用 pyworld 或 kaldi 从 torchaudio 计算音高
        mel_hop_length = 160,
        audio_to_mel_kwargs: dict = dict(),
        scale = 1., # 在训练高分辨率图像时,将此设置为 < 1 以获得更好的收敛性
        duration_loss_weight = 1.,
        pitch_loss_weight = 1.,
        aligner_loss_weight = 1.,
        aligner_bin_loss_weight = 0.
    # 初始化函数,继承父类的初始化方法
    def __init__(
        self
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 设置条件变量
        self.conditional = model.condition_on_prompt

        # 设置模型和编解码器
        self.model = model
        self.codec = codec

        # 确保编解码器存在或目标采样率存在
        assert exists(codec) or exists(target_sample_hz)

        # 设置目标采样率和序列长度的倍数
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = None

        # 如果编解码器存在,则设置目标采样率和序列长度的倍数
        if exists(codec):
            self.target_sample_hz = codec.target_sample_hz
            self.seq_len_multiple_of = codec.seq_len_multiple_of

        # 准备条件
        if self.conditional:
            # 如果目标采样率存在,则更新音频到梅尔频谱的参数
            if exists(self.target_sample_hz):
                audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)

            # 设置梅尔频谱的跳跃长度
            self.mel_hop_length = mel_hop_length

            # 创建音频到梅尔频谱的转换器
            self.audio_to_mel = AudioToMel(
                n_mels = aligner_dim_in,
                hop_length = mel_hop_length,
                **audio_to_mel_kwargs
            )

            # 设置是否使用 PyWorld 计算音高
            self.calc_pitch_with_pyworld = calc_pitch_with_pyworld

            # 初始化音素编码器、语音提示编码器、持续时间和音高预测器、对齐器、音高嵌入层等
            self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
            self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
            self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
            self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
            self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)

            # 初始化对齐器损失和二值损失
            self.aligner_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()
            self.aligner_bin_loss_weight = aligner_bin_loss_weight

        # 其余的 DDPM

        # 确保编解码器维度与模型维度相等
        assert not exists(codec) or model.dim == codec.codebook_dim, f'transformer model dimension {model.dim} must be equal to codec dimension {codec.codebook_dim}'

        # 设置维度
        self.dim = codec.codebook_dim if exists(codec) else model.dim

        # 确保目标是 'x0', 'eps', 'v' 中的一个
        assert objective in {'x0', 'eps', 'v'}, 'objective must be either predict x0 or noise'
        self.objective = objective

        # 根据噪声调度设置 gamma 调度
        if noise_schedule == "linear":
            self.gamma_schedule = simple_linear_schedule
        elif noise_schedule == "cosine":
            self.gamma_schedule = cosine_schedule
        elif noise_schedule == "sigmoid":
            self.gamma_schedule = sigmoid_schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # 设置缩放比例
        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale

        # 设置 gamma 调度的参数
        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)

        # 设置时间步长和是否使用 DDIM
        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # 提出的方法,将时间差加到下一个时间步长,以修复自我条件不足和在采样时间步长小于 400 时降低 FID
        self.time_difference = time_difference

        # 训练时自我条件的概率
        self.train_prob_self_cond = train_prob_self_cond

        # 最小 SNR 损失权重
        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

        # 持续时间和音高的损失权重
        self.duration_loss_weight = duration_loss_weight
        self.pitch_loss_weight = pitch_loss_weight
        self.aligner_loss_weight = aligner_loss_weight

    # 设备属性
    @property
    def device(self):
        return next(self.model.parameters()).device

    # 打印方法
    def print(self, s):
        return self.accelerator.print(s)
    # 获取采样时间步长
    def get_sampling_timesteps(self, batch, *, device):
        # 在设备上创建一个从1到0的时间序列
        times = torch.linspace(1., 0., self.timesteps + 1, device=device)
        # 将时间序列重复batch次
        times = repeat(times, 't -> b t', b=batch)
        # 将时间序列拆分成相邻时间步长的对
        times = torch.stack((times[:, :-1], times[:, 1:]), dim=0)
        times = times.unbind(dim=-1)
        return times

    # 生成DDPM采样
    @torch.no_grad()
    def ddpm_sample(self, shape, prompt=None, time_difference=None, cond_scale=1., cond=None):
        batch, device = shape[0], self.device

        # 设置时间差
        time_difference = default(time_difference, self.time_difference)

        # 获取采样时间对
        time_pairs = self.get_sampling_timesteps(batch, device=device)

        # 生成随机音频
        audio = torch.randn(shape, device=device)

        x_start = None
        last_latents = None

        # 遍历时间对
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', total=self.timesteps):

            # 添加时间延迟
            time_next = (time_next - self.time_difference).clamp(min=0.)

            noise_cond = time

            # 获取预测的x0
            model_output = self.model.forward_with_cond_scale(audio, noise_cond, prompt=prompt, cond_scale=cond_scale, cond=cond)

            # 获取log(snr)
            gamma = self.gamma_schedule(time)
            gamma_next = self.gamma_schedule(time_next)
            gamma, gamma_next = map(partial(right_pad_dims_to, audio), (gamma, gamma_next))

            # 获取alpha和sigma
            alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)

            # 计算x0和噪声
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(audio - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * audio - sigma * model_output

            # 推导后验均值和方差
            log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (audio * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # 获取噪声
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(audio),
                torch.zeros_like(audio)
            )

            # 更新音频
            audio = mean + (0.5 * log_variance).exp() * noise

        return audio

    @torch.no_grad()
    # 生成一个指定形状的样本,可以设置时间差异、条件比例和条件
    def ddim_sample(self, shape, prompt = None, time_difference = None, cond_scale = 1., cond = None):
        # 获取批次大小和设备
        batch, device = shape[0], self.device

        # 设置时间差异
        time_difference = default(time_difference, self.time_difference)

        # 获取采样时间步
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # 生成随机噪声
        audio = torch.randn(shape, device = device)

        x_start = None
        last_latents = None

        # 遍历时间步
        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # 获取时间和噪声水平
            gamma = self.gamma_schedule(times)
            gamma_next = self.gamma_schedule(times_next)

            # 填充时间和噪声水平
            padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, audio), (gamma, gamma_next))

            # 将噪声水平转换为 alpha 和 sigma
            alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)

            # 添加时间延迟
            times_next = (times_next - time_difference).clamp(min = 0.)

            # 预测 x0
            model_output = self.model.forward_with_cond_scale(audio, times, prompt = prompt, cond_scale = cond_scale, cond = cond)

            # 计算 x0 和噪声
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(audio - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * audio - sigma * model_output

            # 获取预测噪声
            pred_noise = safe_div(audio - alpha * x_start, sigma)

            # 计算下一个 x
            audio = x_start * alpha_next + pred_noise * sigma_next

        return audio

    # 处理提示信息
    def process_prompt(self, prompt = None):
        if not exists(prompt):
            return None

        assert self.model.condition_on_prompt

        is_raw_prompt = prompt.ndim == 2
        assert not (is_raw_prompt and not exists(self.codec)), 'codec must be passed in if one were to train on raw prompt'

        if is_raw_prompt:
            with torch.no_grad():
                self.codec.eval()
                prompt, _, _ = self.codec(prompt, curtail_from_left = True, return_encoded = True)

        return prompt

    # 扩展编码
    def expand_encodings(self, phoneme_enc, attn, pitch):
        expanded_dur = einsum('k l m n, k j m -> k j n', attn, phoneme_enc)
        pitch_emb = self.pitch_emb(rearrange(f0_to_coarse(pitch), 'b 1 t -> b t'))
        pitch_emb = rearrange(pitch_emb, 'b t d -> b d t')
        expanded_pitch = einsum('k l m n, k j m -> k j n', attn, pitch_emb)
        expanded_encodings = expanded_dur + expanded_pitch
        return expanded_encodings

    # 生成样本
    @torch.no_grad()
    def sample(
        self,
        *,
        length,
        prompt = None,
        batch_size = 1,
        cond_scale = 1.,
        text = None,
        text_lens = None,
    ):
        # 如果不使用 DDIM,则使用 DDPM 进行采样
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample

        prompt_enc = cond = None

        # 如果是有条件的生成
        if self.conditional:
            # 确保 prompt 和 text 存在
            assert exists(prompt) and exists(text)
            # 处理 prompt
            prompt = self.process_prompt(prompt)
            # 对 prompt 进行编码
            prompt_enc = self.prompt_enc(prompt)
            # 对 text 进行音素编码
            phoneme_enc = self.phoneme_enc(text)

            # 计算音频的持续时间和音高
            duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
            # 重新排列 pitch 的维度
            pitch = rearrange(pitch, 'b n -> b 1 n')

            # 生成基于重复的掩码
            aln_mask = generate_mask_from_repeats(duration).float()

            # 对编码进行扩展
            cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

        # 如果 prompt 存在
        if exists(prompt):
            # 获取批量大小
            batch_size = prompt.shape[0]

        # 生成音频
        audio = sample_fn(
            (batch_size, length, self.dim),
            prompt = prompt_enc,
            cond = cond,
            cond_scale = cond_scale
        )

        # 如果存在编解码器
        if exists(self.codec):
            # 解码音频
            audio = self.codec.decode(audio)

            # 如果音频维度为 3
            if audio.ndim == 3:
                # 重新排列音频的维度
                audio = rearrange(audio, 'b 1 n -> b n')

        # 返回音频
        return audio

    def forward(
        self,
        audio,
        text = None,
        text_lens = None,
        mel = None,
        mel_lens = None,
        codes = None,
        prompt = None,
        pitch = None,
        *args,
        **kwargs
# trainer

# 定义一个循环生成器函数,用于循环遍历数据集
def cycle(dl):
    while True:
        for data in dl:
            yield data

# Trainer 类,用于训练模型
class Trainer(object):
    def __init__(
        self,
        diffusion_model: NaturalSpeech2,
        *,
        dataset: Optional[Dataset] = None,
        folder = None,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 1,
        results_folder = './results',
        amp = False,
        mixed_precision_type = 'fp16',
        use_ema = True,
        split_batches = True,
        dataloader = None,
        data_max_length = None,
        data_max_length_seconds = 2,
        sample_length = None
    ):
        super().__init__()

        # accelerator

        # 初始化加速器,用于加速训练过程
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no'
        )

        # model

        # 设置模型为扩散模型
        self.model = diffusion_model
        assert exists(diffusion_model.codec)

        self.dim = diffusion_model.dim

        # training hyperparameters

        # 设置训练超参数
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        # dataset and dataloader

        dl = dataloader

        if not exists(dl):
            assert exists(dataset) or exists(folder)

            if exists(dataset):
                self.ds = dataset
            elif exists(folder):
                # create dataset

                if exists(data_max_length_seconds):
                    assert not exists(data_max_length)
                    data_max_length = int(data_max_length_seconds * diffusion_model.target_sample_hz)
                else:
                    assert exists(data_max_length)

                # 创建数据集
                self.ds = SoundDataset(
                    folder,
                    max_length = data_max_length,
                    target_sample_hz = diffusion_model.target_sample_hz,
                    seq_len_multiple_of = diffusion_model.seq_len_multiple_of
                )

                dl = DataLoader(
                    self.ds,
                    batch_size = train_batch_size,
                    shuffle = True,
                    pin_memory = True,
                    num_workers = cpu_count()
                )

        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer

        # 初始化优化器
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # for logging results in a folder periodically

        self.use_ema = use_ema
        self.ema = None

        if self.accelerator.is_main_process and use_ema:
            # make sure codec is not part of the EMA
            # encodec seems to be not deepcopyable, so this is a necessary hack

            codec = diffusion_model.codec
            diffusion_model.codec = None

            # 初始化指数移动平均模型
            self.ema = EMA(
                diffusion_model,
                beta = ema_decay,
                update_every = ema_update_every,
                ignore_startswith_names = set(['codec.'])
            ).to(self.device)

            diffusion_model.codec = codec
            self.ema.ema_model.codec = codec

        # sampling hyperparameters

        # 设置采样超参数
        self.sample_length = default(sample_length, data_max_length)
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        # results folder

        # 设置结果保存文件夹
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)

        # step counter state

        # 设置步数计数器
        self.step = 0

        # prepare model, dataloader, optimizer with accelerator

        # 使用加速器准备模型、数据加载器和优化器
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # 打印函数
    def print(self, msg):
        return self.accelerator.print(msg)

    @property
    # 返回未包装的模型
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)
    
    # 返回设备加速器的设备
    @property
    def device(self):
        return self.accelerator.device

    # 保存训练里程碑的模型状态
    def save(self, milestone):
        # 如果不是本地主进程,则返回
        if not self.accelerator.is_local_main_process:
            return

        # 构建保存的数据字典
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
            'version': __version__
        }

        # 保存数据到文件
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    # 加载训练里程碑的模型状态
    def load(self, milestone):
        accelerator = self.accelerator
        device = accelerator.device

        # 从文件加载数据
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        # 解包模型并加载状态
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data["ema"])

        # 打印加载的版本信息
        if 'version' in data:
            print(f"loading from version {data['version']}")

        # 如果存在加速器的缩放器和数据中的缩放器,则加载缩放器状态
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    # 训练模型
    def train(self):
        accelerator = self.accelerator
        device = accelerator.device

        # 使用 tqdm 显示训练进度
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                # 累积梯度并更新模型
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                self.step += 1

                # 如果是主进程,更新指数移动平均模型并保存模型
                if accelerator.is_main_process:
                    self.ema.update()

                    if divisible_by(self.step, self.save_and_sample_every):
                        milestone = self.step // self.save_and_sample_every

                        models = [(self.unwrapped_model, str(self.step))]

                        if self.use_ema:
                            models.append((self.ema.ema_model, f'{self.step}.ema'))

                        for model, label in models:
                            model.eval()

                            with torch.no_grad():
                                generated = model.sample(
                                    batch_size=self.num_samples,
                                    length=self.sample_length
                                )

                            for ind, t in enumerate(generated):
                                filename = str(self.results_folder / f'sample_{label}.flac')
                                t = rearrange(t, 'n -> 1 n')
                                torchaudio.save(filename, t.cpu().detach(), self.unwrapped_model.target_sample_hz)

                        self.print(f'{self.step}: saving to {str(self.results_folder)}')

                        self.save(milestone)

                pbar.update(1)

        self.print('training complete')

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\cleaner.py

import re
from pathlib import Path
from naturalspeech2_pytorch.utils.expand.abbreviations import AbbreviationExpander
from naturalspeech2_pytorch.utils.expand.number_norm import NumberNormalizer
from naturalspeech2_pytorch.utils.expand.time_norm import TimeExpander

CURRENT_DIR = Path(__file__).resolve().parent

class TextProcessor:
    def __init__(self, lang="en"):
        self.lang = lang
        self._whitespace_re = re.compile(r"\s+")
        # 实例化缩写展开器对象
        self.ab_expander = AbbreviationExpander(str(CURRENT_DIR / 'expand/abbreviations.csv'))
        # 实例化时间展开器对象
        self.time_expander = TimeExpander()
        # 实例化数字归一化器对象
        self.num_normalizer = NumberNormalizer()
        # 添加货币转换率
        symbol = '$'
        conversion_rates ={0.01: "cent", 0.02: "cents", 1: "dollar", 2: "dollars" }
        self.num_normalizer.add_currency(symbol, conversion_rates)
    def lowercase(self, text):
        return text.lower()

    def collapse_whitespace(self, text):
        return re.sub(self._whitespace_re, " ", text).strip()

    def remove_aux_symbols(self, text):
        # 移除文本中的辅助符号
        text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
        return text

    def phoneme_cleaners(self, text, language = 'en'):
        # 展开时间表达式
        text = self.time_expander.expand_time(text, language=language)
        # 归一化数字
        text = self.num_normalizer.normalize_numbers(text, language=language)
        # 替换文本中的缩写
        text = self.ab_expander.replace_text_abbreviations(text, language=language)
        # 移除辅助符号
        text = self.remove_aux_symbols(text)
        # 合并多余空格
        text = self.collapse_whitespace(text)
        return text

if __name__ == "__main__":
    # 创建英语实例
    text_processor_en = TextProcessor(lang="en")

    # 处理英语文本
    english_text = "Hello, Mr. Example, this is 9:30 am and  my number is 30."
    processed_english_text = text_processor_en.phoneme_cleaners(english_text, language='en')
    print(processed_english_text)

    # 创建西班牙语实例
    text_processor_es = TextProcessor(lang="es")

    # 处理西班牙语文本
    spanish_text = "Hola, Sr. Ejemplo, son las 9:30 am y mi número es el 30."
    processed_spanish_text = text_processor_es.phoneme_cleaners(spanish_text, language='es')
    print(processed_spanish_text)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\expand\abbreviations.py

# 导入 csv 和 re 模块
import csv
import re

# 定义一个缩写扩展类
class AbbreviationExpander:
    # 初始化方法,接收缩写文件作为参数
    def __init__(self, abbreviations_file):
        # 初始化缩写字典和模式字典
        self.abbreviations = {}
        self.patterns = {}
        # 载入缩写文件
        self.load_abbreviations(abbreviations_file)

    # 载入缩写文件的方法
    def load_abbreviations(self, abbreviations_file):
        # 打开缩写文件
        with open(abbreviations_file, 'r') as file:
            # 读取文件内容
            reader = csv.DictReader(file)
            # 遍历文件中的每一行
            for row in reader:
                # 获取缩写、扩展和语言信息
                abbreviation = row['abbreviation']
                expansion = row['expansion']
                language = row['language'].lower()
                # 将缩写和扩展信息存入缩写字典中
                self.abbreviations.setdefault(language, {})[abbreviation] = expansion

                # 如果语言不在模式字典中,则创建一个正则表达式模式
                if language not in self.patterns:
                    self.patterns[language] = re.compile(
                        r"\b(" + "|".join(re.escape(key) for key in self.abbreviations[language].keys()) + r")\b",
                        re.IGNORECASE
                    )

    # 替换缩写的方法
    def replace_abbreviations(self, match, language):
        return self.abbreviations[language][match.group(0).lower()]

    # 替换文本中的缩写的方法
    def replace_text_abbreviations(self, text, language):
        # 如果语言在模式字典中,则使用正则表达式替换缩写
        if language.lower() in self.patterns:
            return self.patterns[language.lower()].sub(
                lambda match: self.replace_abbreviations(match, language.lower()),
                text
            )
        else:
            return text

# 如果该脚本被直接执行
if __name__ == "__main__":
    # 创建一个 AbbreviationExpander 实例,载入缩写文件
    expander = AbbreviationExpander('abbreviations.csv')

    # 示例用法
    text_en = "Hello, Mr. Example. How are you today? I work at Intl. Corp."
    # 替换英文文本中的缩写
    replaced_text_en = expander.replace_text_abbreviations(text_en, 'en')
    print(replaced_text_en)

    text_fr = "Bonjour, Sr. Example. Comment ça va aujourd'hui? Je travaille chez Intl. Corp."
    # 替换法文文本中的缩写
    replaced_text_fr = expander.replace_text_abbreviations(text_fr, 'fr')
    print(replaced_text_fr)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\expand\number_norm.py

import re
import inflect
from num2words import num2words
from num_to_words import num_to_word

# 创建一个数字标准化类
class NumberNormalizer:
    def __init__(self):
        # 初始化 inflect 引擎
        self._inflect = inflect.engine()
        # 编译正则表达式,用于匹配数字
        self._number_re = re.compile(r"-?[0-9]+")
        # 编译正则表达式,用于匹配货币
        self._currency_re = re.compile(r"([$€£¥₹])([0-9\,\.]*[0-9]+)")
        # 存储货币转换率的字典
        self._currencies = {}

    # 添加货币转换率
    def add_currency(self, symbol, conversion_rates):
        self._currencies[symbol] = conversion_rates

    # 标准化文本中的数字
    def normalize_numbers(self, text, language='en'):
        self._inflect = inflect.engine()
        self._set_language(language)
        # 替换文本中的货币
        text = re.sub(self._currency_re, self._expand_currency, text)
        # 替换文本中的数字
        text = re.sub(self._number_re, lambda match: self._expand_number(match, language), text)
        return text

    # 设置语言
    def _set_language(self, language):
        if language == 'en':
            self._inflect = inflect.engine()
        else:
            self._inflect = inflect.engine()
            # 在这里添加对其他语言的支持

    # 扩展货币
    def _expand_currency(self, match):
        unit = match.group(1)
        currency = self._currencies.get(unit)
        if currency:
            value = match.group(2)
            return self._expand_currency_value(value, currency)
        return match.group(0)

    # 扩展货币值
    def _expand_currency_value(self, value, inflection):
        parts = value.replace(",", "").split(".")
        if len(parts) > 2:
            return f"{value} {inflection[2]}"  # 意外的格式
        text = []
        integer = int(parts[0]) if parts[0] else 0
        if integer > 0:
            integer_unit = inflection.get(integer, inflection[2])
            text.append(f"{integer} {integer_unit}")
        fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
        if fraction > 0:
            fraction_unit = inflection.get(fraction / 100, inflection[0.02])
            text.append(f"{fraction} {fraction_unit}")
        if not text:
            return f"zero {inflection[2]}"
        return " ".join(text)

    # 扩展数字
    def _expand_number(self, match, language: str) -> str:
        num = int(match.group(0))
        if 1000 < num < 3000:
            if num == 2000:
                return self._number_to_words(num, language)
            if 2000 < num < 2010:
                return f"{self._number_to_words(2000, language)} {self._number_to_words(num % 100, language)}"
            if num % 100 == 0:
                return f"{self._number_to_words(num // 100, language)} {self._get_word('hundred')}"
            return self._number_to_words(num, language)
        return self._number_to_words(num, language)

    # 将数字转换为单词
    def _number_to_words(self, n: int, language: str) -> str:
        try:
            if language == 'en':
                return self._inflect.number_to_words(n)
            else:
                return num2words(n, lang=language)
        except:
            try:
                return num_to_word(n, lang=language)
            except:
                raise NotImplementedError("language not implemented")

    # 获取单词
    def _get_word(self, word):
        return word

# 如果作为主程序运行
if __name__ == "__main__":
    # 创建 NumberNormalizer 的实例
    normalizer = NumberNormalizer()
    # 添加货币转换率
    symbol = '$'
    conversion_rates ={
            0.01: "cent",
            0.02: "cents",
            1: "dollar",
            2: "dollars",
        }
    normalizer.add_currency(symbol, conversion_rates)
    # 示例 1:英语(en)语言
    text_en = "I have $1,000 and 5 apples."
    normalized_text_en = normalizer.normalize_numbers(text_en, language='en')
    print(normalized_text_en)
    # 输出: "I have one thousand dollars and five apples."

    # 示例 2:西班牙语(es)语言
    text_es = "Tengo $1.000 y 5 manzanas."
    normalized_text_es = normalizer.normalize_numbers(text_es, language='es')
    print(normalized_text_es)
    # 输出: "Tengo mil dólares y cinco manzanas."

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\expand\time_norm.py

import re
import inflect
from num2words import num2words
from num_to_words import num_to_word

# 定义一个时间扩展器类
class TimeExpander:
    def __init__(self):
        # 初始化 inflect 引擎
        self._inflect = inflect.engine()
        # 获取时间正则表达式
        self._time_re = self._get_time_regex()

    # 获取时间正则表达式的私有方法
    def _get_time_regex(self):
        return re.compile(
            r"""\b
            ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3]))  # hours
            :
            ([0-5][0-9])                            # minutes
            \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm
            \b""",
            re.IGNORECASE | re.X,
        )

    # 将数字扩展为单词的私有方法
    def _expand_num(self, n: int, language: str) -> str:
        try:
            if language == 'en':
                return self._inflect.number_to_words(n)
            else:
                return num2words(n, lang=language)
        except:
            try:
                return num_to_word(n, lang=language)
            except:
                raise NotImplementedError("language not implemented")

    # 扩展时间的私有方法
    def _expand_time(self, match: "re.Match", language: str) -> str:
        hour = int(match.group(1))
        past_noon = hour >= 12
        time = []
        if hour > 12:
            hour -= 12
        elif hour == 0:
            hour = 12
            past_noon = True
        time.append(self._expand_num(hour, language))

        minute = int(match.group(6))
        if minute > 0:
            if minute < 10:
                time.append("oh")
            time.append(self._expand_num(minute, language))

        am_pm = match.group(7)
        if am_pm is not None:
            time.extend(list(am_pm.replace(".", "")))

        return " ".join(time)

    # 扩展时间的公共方法
    def expand_time(self, text: str, language: str) -> str:
        return re.sub(self._time_re, lambda match: self._expand_time(match, language), text)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\expand\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\phonemizers\base.py

""" from https://github.com/coqui-ai/TTS/"""
# 导入必要的模块
import abc
from typing import List, Tuple

from naturalspeech2_pytorch.utils.phonemizers.punctuation import Punctuation

# 定义一个抽象基类 BasePhonemizer
class BasePhonemizer(abc.ABC):
    """Base phonemizer class

    Phonemization follows the following steps:
        1. Preprocessing:
            - remove empty lines
            - remove punctuation
            - keep track of punctuation marks

        2. Phonemization:
            - convert text to phonemes

        3. Postprocessing:
            - join phonemes
            - restore punctuation marks

    Args:
        language (str):
            Language used by the phonemizer.

        punctuations (List[str]):
            List of punctuation marks to be preserved.

        keep_puncs (bool):
            Whether to preserve punctuation marks or not.
    """

    def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
        # ensure the backend is installed on the system
        if not self.is_available():
            raise RuntimeError("{} not installed on your system".format(self.name()))  # pragma: nocover

        # ensure the backend support the requested language
        self._language = self._init_language(language)

        # setup punctuation processing
        self._keep_puncs = keep_puncs
        self._punctuator = Punctuation(punctuations)

    def _init_language(self, language):
        """Language initialization

        This method may be overloaded in child classes (see Segments backend)

        """
        if not self.is_supported_language(language):
            raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
        return language

    @property
    def language(self):
        """The language code configured to be used for phonemization"""
        return self._language

    @staticmethod
    @abc.abstractmethod
    def name():
        """The name of the backend"""
        ...

    @classmethod
    @abc.abstractmethod
    def is_available(cls):
        """Returns True if the backend is installed, False otherwise"""
        ...

    @classmethod
    @abc.abstractmethod
    def version(cls):
        """Return the backend version as a tuple (major, minor, patch)"""
        ...

    @staticmethod
    @abc.abstractmethod
    def supported_languages():
        """Return a dict of language codes -> name supported by the backend"""
        ...

    def is_supported_language(self, language):
        """Returns True if `language` is supported by the backend"""
        return language in self.supported_languages()

    @abc.abstractmethod
    def _phonemize(self, text, separator):
        """The main phonemization method"""

    def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
        """Preprocess the text before phonemization

        1. remove spaces
        2. remove punctuation

        Override this if you need a different behaviour
        """
        text = text.strip()
        if self._keep_puncs:
            # a tuple (text, punctuation marks)
            return self._punctuator.strip_to_restore(text)
        return [self._punctuator.strip(text)], []

    def _phonemize_postprocess(self, phonemized, punctuations) -> str:
        """Postprocess the raw phonemized output

        Override this if you need a different behaviour
        """
        if self._keep_puncs:
            return self._punctuator.restore(phonemized, punctuations)[0]
        return phonemized[0]
    # 定义一个方法,将文本转换为给定语言的音素表示
    def phonemize(self, text: str, separator="|", language: str = None) -> str:  # pylint: disable=unused-argument
        """Returns the `text` phonemized for the given language

        Args:
            text (str):
                Text to be phonemized.

            separator (str):
                string separator used between phonemes. Default to '_'.

        Returns:
            (str): Phonemized text
        """
        # 对文本进行预处理,获取文本和标点符号
        text, punctuations = self._phonemize_preprocess(text)
        phonemized = []
        # 遍历文本中的每个字符
        for t in text:
            # 将每个字符转换为音素表示,并使用指定的分隔符
            p = self._phonemize(t, separator)
            phonemized.append(p)
        # 对音素表示进行后处理,恢复标点符号
        phonemized = self._phonemize_postprocess(phonemized, punctuations)
        # 返回音素表示的文本
        return phonemized

    # 打印日志信息,包括音素语言和后端信息
    def print_logs(self, level: int = 0):
        # 根据缩进级别生成缩进字符串
        indent = "\t" * level
        # 打印音素语言信息
        print(f"{indent}| > phoneme language: {self.language}")
        # 打印后端信息
        print(f"{indent}| > phoneme backend: {self.name()}")

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\phonemizers\espeak_wrapper.py

""" from https://github.com/coqui-ai/TTS/"""
# 导入所需的模块
import logging
import re
import subprocess
from typing import Dict, List

from packaging.version import Version

from naturalspeech2_pytorch.utils.phonemizers.base import BasePhonemizer
from naturalspeech2_pytorch.utils.phonemizers.punctuation import Punctuation

# 检查系统中是否存在指定的可执行程序
def is_tool(name):
    from shutil import which

    return which(name) is not None

# 使用正则表达式模式匹配 espeak 版本号
espeak_version_pattern = re.compile(r"text-to-speech:\s(?P<version>\d+\.\d+(\.\d+)?)")


# 获取 espeak 版本号
def get_espeak_version():
    output = subprocess.getoutput("espeak --version")
    match = espeak_version_pattern.search(output)

    return match.group("version")

# 获取 espeak-ng 版本号
def get_espeakng_version():
    output = subprocess.getoutput("espeak-ng --version")
    return output.split()[3]

# 优先使用 espeak-ng,其次使用 espeak
if is_tool("espeak-ng"):
    _DEF_ESPEAK_LIB = "espeak-ng"
    _DEF_ESPEAK_VER = get_espeakng_version()
elif is_tool("espeak"):
    _DEF_ESPEAK_LIB = "espeak"
    _DEF_ESPEAK_VER = get_espeak_version()
else:
    _DEF_ESPEAK_LIB = None
    _DEF_ESPEAK_VER = None

# 运行 espeak 命令行工具
def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
    """Run espeak with the given arguments."""
    cmd = [
        espeak_lib,
        "-q",
        "-b",
        "1",  # UTF8 text encoding
    ]
    cmd.extend(args)
    logging.debug("espeakng: executing %s", repr(cmd))

    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as p:
        res = iter(p.stdout.readline, b"")
        if not sync:
            p.stdout.close()
            if p.stderr:
                p.stderr.close()
            if p.stdin:
                p.stdin.close()
            return res
        res2 = []
        for line in res:
            res2.append(line)
        p.stdout.close()
        if p.stderr:
            p.stderr.close()
        if p.stdin:
            p.stdin.close()
        p.wait()
    return res2

# ESpeak 类,用于调用 espeak 或 espeak-ng 执行 G2P
class ESpeak(BasePhonemizer):
    """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P

    Args:
        language (str):
            Valid language code for the used backend.

        backend (str):
            Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically
            prefering `espeak-ng` over `espeak`. Defaults to None.

        punctuations (str):
            Characters to be treated as punctuation. Defaults to Punctuation.default_puncs().

        keep_puncs (bool):
            If True, keep the punctuations after phonemization. Defaults to True.

    Example:
        >>> phonemizer = ESpeak("tr")
        >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|")
        'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.'

    """

    _ESPEAK_LIB = _DEF_ESPEAK_LIB
    _ESPEAK_VER = _DEF_ESPEAK_VER

    def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True):
        if self._ESPEAK_LIB is None:
            raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.")
        self.backend = self._ESPEAK_LIB

        # band-aid for backwards compatibility
        if language == "en":
            language = "en-us"
        if language == "zh-cn":
            language = "cmn"

        super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
        if backend is not None:
            self.backend = backend

    @property
    def backend(self):
        return self._ESPEAK_LIB

    @property
    def backend_version(self):
        return self._ESPEAK_VER

    @backend.setter
    # 设置后端引擎
    def backend(self, backend):
        # 检查后端引擎是否为有效值
        if backend not in ["espeak", "espeak-ng"]:
            raise Exception("Unknown backend: %s" % backend)
        # 设置 ESPEAK_LIB 为指定的后端引擎
        self._ESPEAK_LIB = backend
        # 根据后端引擎设置 ESPEAK_VER
        self._ESPEAK_VER = get_espeakng_version() if backend == "espeak-ng" else get_espeak_version()

    # 自动设置 espeak 库
    def auto_set_espeak_lib(self) -> None:
        # 检查是否存在 espeak-ng 工具
        if is_tool("espeak-ng"):
            self._ESPEAK_LIB = "espeak-ng"
            self._ESPEAK_VER = get_espeakng_version()
        # 检查是否存在 espeak 工具
        elif is_tool("espeak"):
            self._ESPEAK_LIB = "espeak"
            self._ESPEAK_VER = get_espeak_version()
        else:
            raise Exception("Cannot set backend automatically. espeak-ng or espeak not found")

    # 返回引擎名称
    @staticmethod
    def name():
        return "espeak"

    # 将输入文本转换为音素
    def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str:
        """Convert input text to phonemes.

        Args:
            text (str):
                Text to be converted to phonemes.

            tie (bool, optional) : When True use a '͡' character between
                consecutive characters of a single phoneme. Else separate phoneme
                with '_'. This option requires espeak>=1.49. Default to False.
        """
        # 设置参数
        args = ["-v", f"{self._language}"]
        # 根据 tie 参数选择不同的音素分隔方式
        if tie:
            # 在音素之间使用 '͡'
            if self.backend == "espeak":
                args.append("--ipa=1")
            else:
                args.append("--ipa=3")
        else:
            # 使用 '_' 分隔音素
            if self.backend == "espeak":
                if Version(self.backend_version) >= Version("1.48.15"):
                    args.append("--ipa=1")
                else:
                    args.append("--ipa=3")
            else:
                args.append("--ipa=1")
        if tie:
            args.append("--tie=%s" % tie)

        args.append('"' + text + '"')
        # 计算音素
        phonemes = ""
        for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
            logging.debug("line: %s", repr(line))
            ph_decoded = line.decode("utf8").strip()
            # 处理 espeak ��� espeak-ng 返回的文本
            ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:]
            # 移除 espeak-ng 返回文本中的语言标记
            ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded)
            phonemes += ph_decoded.strip()
        return phonemes.replace("_", separator)

    # 调用 phonemize_espeak 方法,设置 tie 参数为 False
    def _phonemize(self, text, separator=None):
        return self.phonemize_espeak(text, separator, tie=False)

    # 返回支持的语言字典
    @staticmethod
    def supported_languages() -> Dict:
        """Get a dictionary of supported languages.

        Returns:
            Dict: Dictionary of language codes.
        """
        if _DEF_ESPEAK_LIB is None:
            return {}
        args = ["--voices"]
        langs = {}
        count = 0
        for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True):
            line = line.decode("utf8").strip()
            if count > 0:
                cols = line.split()
                lang_code = cols[1]
                lang_name = cols[3]
                langs[lang_code] = lang_name
            logging.debug("line: %s", repr(line))
            count += 1
        return langs
    # 返回当前使用的后端的版本号
    def version(self) -> str:
        """Get the version of the used backend.

        Returns:
            str: Version of the used backend.
        """
        # 定义参数列表,包含获取版本信息的参数
        args = ["--version"]
        # 遍历执行 espeak_exe 函数返回的结果,同步执行
        for line in _espeak_exe(self.backend, args, sync=True):
            # 解码行内容为 UTF-8 格式,去除空格并按空格分割,获取版本号
            version = line.decode("utf8").strip().split()[2]
            # 记录调试信息
            logging.debug("line: %s", repr(line))
            # 返回版本号
            return version

    @classmethod
    # 检查 ESpeak 是否可用,可用返回 True,否则返回 False
    def is_available(cls):
        """Return true if ESpeak is available else false"""
        # 检查是否存在 espeak 或 espeak-ng 工具
        return is_tool("espeak") or is_tool("espeak-ng")
# 如果当前脚本作为主程序运行
if __name__ == "__main__":
    # 创建一个 ESpeak 对象,指定语言为英语
    e = ESpeak(language="en-us")
    # 打印支持的语言列表
    print(e.supported_languages())
    # 打印 ESpeak 的版本信息
    print(e.version())
    # 打印 ESpeak 对象的语言属性
    print(e.language)
    # 打印 ESpeak 对象的名称
    print(e.name())
    # 打印 ESpeak 对象是否可用
    print(e.is_available())

    # 创建一个 ESpeak 对象,指定语言为英语,不保留标点符号
    e = ESpeak(language="en-us", keep_puncs=False)
    # 打印使用 ESpeak 对象将文本转换为音素的结果,加上反引号
    print("`" + e.phonemize("hello how are you today?") + "`")

    # 创建一个 ESpeak 对象,指定语言为英语,保留标点符号
    e = ESpeak(language="en-us", keep_puncs=True)
    # 打印使用 ESpeak 对象将文本转换为音素的结果,加上反引号
    print("`" + e.phonemize("hello how are you today?") + "`")

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\phonemizers\punctuation.py

""" from https://github.com/coqui-ai/TTS/"""
# 导入所需的库
import collections
import re
from enum import Enum

import six

# 默认的标点符号
_DEF_PUNCS = ';:,.!?¡¿—…"«»“”'

# 命名元组,用于表示标点符号和位置
_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])

# 枚举类,表示标点符号的位置
class PuncPosition(Enum):
    """Enum for the punctuations positions"""
    BEGIN = 0
    END = 1
    MIDDLE = 2
    ALONE = 3

# 处理文本中的标点符号
class Punctuation:
    """Handle punctuations in text.

    Just strip punctuations from text or strip and restore them later.

    Args:
        puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.

    Example:
        >>> punc = Punctuation()
        >>> punc.strip("This is. example !")
        'This is example'

        >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
        >>> ' '.join(text_striped)
        'This is example'

        >>> text_restored = punc.restore(text_striped, punc_map)
        >>> text_restored[0]
        'This is. example !'
    """

    def __init__(self, puncs: str = _DEF_PUNCS):
        self.puncs = puncs

    @staticmethod
    def default_puncs():
        """Return default set of punctuations."""
        return _DEF_PUNCS

    @property
    def puncs(self):
        return self._puncs

    @puncs.setter
    def puncs(self, value):
        if not isinstance(value, six.string_types):
            raise ValueError("[!] Punctuations must be of type str.")
        self._puncs = "".join(list(dict.fromkeys(list(value))))  # remove duplicates without changing the oreder
        self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")

    def strip(self, text):
        """Remove all the punctuations by replacing with `space`.

        Args:
            text (str): The text to be processed.

        Example::

            "This is. example !" -> "This is example "
        """
        return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()

    def strip_to_restore(self, text):
        """Remove punctuations from text to restore them later.

        Args:
            text (str): The text to be processed.

        Examples ::

            "This is. example !" -> [["This is", "example"], [".", "!"]]

        """
        text, puncs = self._strip_to_restore(text)
        return text, puncs

    def _strip_to_restore(self, text):
        """Auxiliary method for Punctuation.preserve()"""
        matches = list(re.finditer(self.puncs_regular_exp, text))
        if not matches:
            return [text], []
        # the text is only punctuations
        if len(matches) == 1 and matches[0].group() == text:
            return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
        # build a punctuation map to be used later to restore punctuations
        puncs = []
        for match in matches:
            position = PuncPosition.MIDDLE
            if match == matches[0] and text.startswith(match.group()):
                position = PuncPosition.BEGIN
            elif match == matches[-1] and text.endswith(match.group()):
                position = PuncPosition.END
            puncs.append(_PUNC_IDX(match.group(), position))
        # convert str text to a List[str], each item is separated by a punctuation
        splitted_text = []
        for idx, punc in enumerate(puncs):
            split = text.split(punc.punc)
            prefix, suffix = split[0], punc.punc.join(split[1:])
            splitted_text.append(prefix)
            # if the text does not end with a punctuation, add it to the last item
            if idx == len(puncs) - 1 and len(suffix) > 0:
                splitted_text.append(suffix)
            text = suffix
        return splitted_text, puncs

    @classmethod
    # 从给定文本中恢复标点符号
    def restore(cls, text, puncs):
        """Restore punctuation in a text.

        Args:
            text (str): The text to be processed.
            puncs (List[str]): The list of punctuations map to be used for restoring.

        Examples ::

            ['This is', 'example'], ['.', '!'] -> "This is. example!"

        """
        # 调用内部方法 _restore() 来执行标点符号的恢复
        return cls._restore(text, puncs, 0)

    @classmethod
    def _restore(cls, text, puncs, num):  # pylint: disable=too-many-return-statements
        """Auxiliary method for Punctuation.restore()"""
        # 如果没有标点符号,则直接返回文本
        if not puncs:
            return text

        # 如果文本为空,则返回标点符号列表
        if not text:
            return ["".join(m.punc for m in puncs)]

        # 获取当前处理的标点符号
        current = puncs[0]

        # 如果当前标点符号在句子开头
        if current.position == PuncPosition.BEGIN:
            return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)

        # 如果当前标点符号在句子结尾
        if current.position == PuncPosition.END:
            return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)

        # 如果当前标点符号独立存在
        if current.position == PuncPosition.ALONE:
            return [current.mark] + cls._restore(text, puncs[1:], num + 1)

        # 如果当前标点符号在句子中间
        if len(text) == 1:  # pragma: nocover
            # 一个特殊情况,中间标点符号的最后部分未被处理
            return cls._restore([text[0] + current.punc], puncs[1:], num)

        return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
# 如果当前脚本作为主程序执行
if __name__ == "__main__":
    # 创建一个标点符号处理对象
    punc = Punctuation()
    # 定义一个包含标点符号的文本
    text = "This is. This is, example!"

    # 打印去除标点符号后的文本
    print(punc.strip(text))

    # 将文本分割成不包含标点符号的部分和标点符号部分
    split_text, puncs = punc.strip_to_restore(text)
    print(split_text, " ---- ", puncs)

    # 恢复文本,将不包含标点符号的部分和标点符号部分合并
    restored_text = punc.restore(split_text, puncs)
    print(restored_text)

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\phonemizers\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\tokenizer.py

# 导入 torch 库
import torch
# 从 torch 库中导入 Tensor 类型
from torch import Tensor
# 从 typing 模块中导入 Callable, List, Optional, Tuple 类型
from typing import Callable, List, Optional, Tuple

# 从 torch.nn.utils.rnn 模块中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence

# 从 naturalspeech2_pytorch.utils.cleaner 模块中导入 TextProcessor 类
from naturalspeech2_pytorch.utils.cleaner import TextProcessor
# 从 naturalspeech2_pytorch.utils.phonemizers.espeak_wrapper 模块中导入 ESpeak 类

from naturalspeech2_pytorch.utils.phonemizers.espeak_wrapper import ESpeak

# 默认的音素集合

_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "'̃ˈˌːˑ. ,-"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics

# 默认的语言映射

LANGUAGE_MAP = {
    'en-us': 'en',
    'fr-fr': 'es',
    'hi': 'hi'
}

# 函数

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# 主类

class Tokenizer:
    def __init__(
        self,
        vocab = _phonemes,
        text_cleaner: Optional[Callable] = None,
        phonemizer: Optional[Callable] = None,
        default_lang = "en-us",
        add_blank: bool = False,
        use_eos_bos = False,
        pad_id = -1
    ):
        # 初始化 Tokenizer 类
        self.text_cleaner = default(text_cleaner, TextProcessor().phoneme_cleaners)
        self.add_blank = add_blank
        self.use_eos_bos = use_eos_bos
        self.pad_id = pad_id

        self.vocab = vocab
        self.vocab_size = len(vocab)

        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self.id_to_char = {idx: char for idx, char in enumerate(self.vocab)}

        self.phonemizer = phonemizer
        if not exists(self.phonemizer):
            self.phonemizer = ESpeak(language = default_lang)

        self.language = self.phonemizer.language
        self.not_found_characters = []

    @property
    def espeak_language(self):
        return LANGUAGE_MAP.get(self.language, None)

    def encode(self, text: str) -> List[int]:
        """Encodes a string of text as a sequence of IDs."""
        token_ids = []
        for char in text:
            try:
                idx = self.char_to_id[char]
                token_ids.append(idx)
            except KeyError:
                # 丢弃但存储未找到的字符
                if char not in self.not_found_characters:
                    self.not_found_characters.append(char)
                    print(text)
                    print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decodes a sequence of IDs to a string of text."""
        text = ""
        for token_id in token_ids:
            text += self.id_to_char[token_id]
        return text

    def text_to_ids(
        self,
        text: str,
        language: str = None
    ) -> Tuple[List[int], str, str]:
        """Converts a string of text to a sequence of token IDs.

        Args:
            text(str):
                The text to convert to token IDs.

            language(str):
                The language code of the text. Defaults to None.

        TODO:
            - Add support for language-specific processing.

        1. Text normalizatin
        2. Phonemization (if use_phonemes is True)
        3. Add blank char between characters
        4. Add BOS and EOS characters
        5. Text to token IDs
        """

        language = default(language, self.espeak_language)

        cleaned_text = None
        if self.text_cleaner is not None:
            text = self.text_cleaner(text, language=language)
            cleaned_text = text
        phonemized = self.phonemizer.phonemize(text, separator="", language=language)
        if self.add_blank:
            phonemized = self.intersperse_blank_char(phonemized, True)
        if self.use_eos_bos:
            phonemized = self.pad_with_bos_eos(phonemized)

        return self.encode(phonemized), cleaned_text, phonemized
    # 将文本转换为张量的 ID 序列
    def texts_to_tensor_ids(self, texts: List[str], language: str = None) -> Tensor:
        # 存储所有文本的 ID 序列
        all_ids = []

        # 遍历每个文本
        for text in texts:
            # 调用 text_to_ids 方法将文本转换为 ID 序列
            ids, *_ = self.text_to_ids(text, language=language)
            # 将 ID 序列转换为张量并添加到 all_ids 中
            all_ids.append(torch.tensor(ids))

        # 使用 pad_sequence 函数对所有 ID 序列进行填充,并返回结果张量
        return pad_sequence(all_ids, batch_first=True, padding_value=self.pad_id)

    # 将 ID 序列转换为文本
    def ids_to_text(self, id_sequence: List[int]) -> str:
        """Converts a sequence of token IDs to a string of text."""
        # 调用 decode 方法将 ID 序列转换为文本
        return self.decode(id_sequence)

    # 在字符序列前后添加特殊的 BOS 和 EOS 字符
    def pad_with_bos_eos(self, char_sequence: List[str]):
        """Pads a sequence with the special BOS and EOS characters."""
        # 在字符序列前后分别添加 BOS 和 EOS 字符
        return [self.characters.bos] + list(char_sequence) + [self.characters.eos]

    # 在字符序列中插入空白字符
    def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
        """Intersperses the blank character between characters in a sequence.

        Use the ```py```py character if defined else use the ```pad```py character.
        """
        # 根据 use_blank_char 决定使用 blank 或 pad 字符
        char_to_use = self.characters.blank if use_blank_char else self.characters.pad
        # 创建一个新列表,将空白字符插入到字符序列中
        result = [char_to_use] * (len(char_sequence) * 2 + 1)
        result[1::2] = char_sequence
        return result
# 如果当前脚本被直接执行,则执行以下代码
if __name__ == "__main__":
    # 创建一个文本处理器对象
    txt_cleaner = TextProcessor()
    # 创建一个分词器对象,指定词汇表、文本清洁器和音素合成器
    tokenizer = Tokenizer(vocab = _phonemes, text_cleaner = txt_cleaner.phoneme_cleaners, phonemizer = ESpeak(language="en-us"))
    # 将文本转换为对应的 ID 序列,并打印结果
    print(tokenizer.text_to_ids("Hello, Mr. Example, this is 9:30 am and  my number is 30.", language="en"))
    # 创建另一个分词器对象,指定不同的语言的音素合成器
    tokenizer = Tokenizer(vocab = _phonemes, text_cleaner = txt_cleaner.phoneme_cleaners, phonemizer = ESpeak(language="fr-fr"))
    # 将文本转换为对应的 ID 序列,并打印结果
    print(tokenizer.text_to_ids("Hola, Sr. Ejemplo, son las 9:30 am y mi número es el 30.", language="es"))
    # 创建另一个分词器对象,指定不同的语言的音素合成器
    tokenizer = Tokenizer(vocab = _phonemes, text_cleaner = txt_cleaner.phoneme_cleaners, phonemizer = ESpeak(language="hi"))
    # 将文本转换为对应的 ID 序列,并打印结果
    print(tokenizer.text_to_ids("हैलो, मिस्टर उदाहरण, यह सुबह 9:30 बजे है और मेरा नंबर 30 है।", language="hi"))

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\utils.py

import torch
from einops import repeat, rearrange

def average_over_durations(values, durs):
    """
        - in:
            - values: B, 1, T_de
            - durs: B, T_en
        - out:
            - avg: B, 1, T_en
    """
    # 计算累积持续时间的结束位置
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    # 计算累积持续时间的开始位置
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
    # 计算非零值的累积
    values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0))
    # 计算值的累积
    values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))

    bs, l = durs_cums_ends.size()
    n_formants = values.size(1)
    # 重复持续时间的开始位置
    dcs = repeat(durs_cums_starts, 'bs l -> bs n l', n=n_formants)
    # 重复持续时间的结束位置
    dce = repeat(durs_cums_ends, 'bs l -> bs n l', n=n_formants)

    # 计算值的总和
    values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).to(values.dtype)
    # 计算值的元素个数
    values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).to(values.dtype)

    # 计算平均值
    avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems).to(values.dtype)
    return avg

def create_mask(sequence_length, max_len):
    dtype, device = sequence_length.dtype, sequence_length.device
    # 创建一个序列范围
    seq_range = torch.arange(max_len, dtype=dtype, device=device)
    sequence_length = rearrange(sequence_length, 'b -> b 1')
    seq_range = rearrange(seq_range, 't -> 1 t')
    return seq_range < sequence_length

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\utils\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\version.py

# 定义变量 __version__,表示当前代码的版本号为 '0.1.8'
__version__ = '0.1.8'

.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\__init__.py

# 导入 torch 库
import torch
# 导入版本比较模块 version
from packaging import version

# 检查 torch 库的版本是否大于等于 '2.0.0',如果是则执行以下代码
if version.parse(torch.__version__) >= version.parse('2.0.0'):
    # 从 einops._torch_specific 模块中导入 allow_ops_in_compiled_graph 函数
    from einops._torch_specific import allow_ops_in_compiled_graph
    # 调用 allow_ops_in_compiled_graph 函数

# 从 naturalspeech2_pytorch.naturalspeech2_pytorch 模块中导入以下类
from naturalspeech2_pytorch.naturalspeech2_pytorch import (
    NaturalSpeech2,
    Transformer,
    Wavenet,
    Model,
    Trainer,
    PhonemeEncoder,
    DurationPitchPredictor,
    SpeechPromptEncoder,
    Tokenizer,
    ESpeak
)

# 从 audiolm_pytorch 模块中导入以下类
from audiolm_pytorch import (
    SoundStream,
    EncodecWrapper
)

Natural Speech 2 - Pytorch (wip)

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch

NaturalSpeech 2 is a TTS system that leverages a neural audio codec with continuous latent vectors and a latent diffusion model with non-autoregressive generation to enable natural and zero-shot text-to-speech synthesis

This repository will use denoising diffusion rather than score-based SDE, and may potentially offer elucidated version as well. It will also offer improvements for the attention / transformer components wherever applicable.

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • 🤗 Huggingface for the amazing accelerate library

  • Manmay for submitting the initial code for phoneme, pitch, duration, and speech prompt encoders as well as the multilingual phonemizer and phoneme aligner!

  • Manmay for wiring up the complete end-to-end conditioning of the diffusion network!

  • You? If you are an aspiring ML / AI engineer or work in the TTS field and would like to contribute to open sourcing state-of-the-art, jump right in!

Install

$ pip install naturalspeech2-pytorch

Usage

import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
    Model,
    NaturalSpeech2
)

# use encodec as an example

codec = EncodecWrapper()

model = Model(
    dim = 128,
    depth = 6
)

# natural speech diffusion model

diffusion = NaturalSpeech2(
    model = model,
    codec = codec,
    timesteps = 1000
).cuda()

# mock raw audio data

raw_audio = torch.randn(4, 327680).cuda()

loss = diffusion(raw_audio)
loss.backward()

# do the above in a loop for a lot of raw audio data...
# then you can sample from your generative model as so

generated_audio = diffusion.sample(length = 1024) # (1, 327680)

With conditioning

ex.

import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
    Model,
    NaturalSpeech2,
    SpeechPromptEncoder
)

# use encodec as an example

codec = EncodecWrapper()

model = Model(
    dim = 128,
    depth = 6,
    dim_prompt = 512,
    cond_drop_prob = 0.25,                  # dropout prompt conditioning with this probability, for classifier free guidance
    condition_on_prompt = True
)

# natural speech diffusion model

diffusion = NaturalSpeech2(
    model = model,
    codec = codec,
    timesteps = 1000
)

# mock raw audio data

raw_audio = torch.randn(4, 327680)
prompt = torch.randn(4, 32768)               # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically

text = torch.randint(0, 100, (4, 100))
text_lens = torch.tensor([100, 50 , 80, 100])

# forwards and backwards

loss = diffusion(
    audio = raw_audio,
    text = text,
    text_lens = text_lens,
    prompt = prompt
)

loss.backward()

# after much training

generated_audio = diffusion.sample(
    length = 1024,
    text = text,
    prompt = prompt
) # (1, 327680)

Or if you want a Trainer class to take care of the training and sampling loop, just simply do

from naturalspeech2_pytorch import Trainer

trainer = Trainer(
    diffusion_model = diffusion,     # diffusion model + codec from above
    folder = '/path/to/speech',
    train_batch_size = 16,
    gradient_accumulate_every = 2,
)

trainer.train()

Todo

Citations

@inproceedings{Shen2023NaturalSpeech2L,
    title   = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
    author  = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
    year    = {2023}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@article{Alayrac2022FlamingoAV,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
    journal  = {ArXiv},
    year     = {2022},
    volume   = {abs/2204.14198}
}
@article{Badlani2021OneTA,
    title   = {One TTS Alignment to Rule Them All},
    author  = {Rohan Badlani and Adrian Lancucki and Kevin J. Shih and Rafael Valle and Wei Ping and Bryan Catanzaro},
    journal = {ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    year    = {2021},
    pages   = {6092-6096},
    url     = {https://api.semanticscholar.org/CorpusID:237277973}
}

.\lucidrains\naturalspeech2-pytorch\setup.py

# 导入设置安装和查找包的函数
from setuptools import setup, find_packages

# 执行版本文件中的代码,将版本信息导入当前环境
exec(open('naturalspeech2_pytorch/version.py').read())

# 设置包的元数据
setup(
  name = 'naturalspeech2-pytorch', # 包名
  packages = find_packages(exclude=[]), # 查找所有包
  version = __version__, # 版本号
  license='MIT', # 许可证
  description = 'Natural Speech 2 - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  long_description_content_type = 'text/markdown', # 长描述内容类型
  include_package_data = True, # 包含包数据
  url = 'https://github.com/lucidrains/naturalspeech2-pytorch', # 项目链接
  keywords = [ # 关键词
    'artificial intelligence',
    'deep learning',
    'latent diffusion',
    'speech synthesis'
  ],
  install_requires=[ # 安装依赖
    'accelerate',
    'audiolm-pytorch>=0.30.2',
    'beartype',
    'einops>=0.6.1',
    'ema-pytorch',
    'indic-num2words',
    'inflect',
    'local-attention',
    'num2words',
    'pyworld',
    'pydantic<2.0',
    'torch>=1.6',
    'tqdm',
    'vector-quantize-pytorch>=1.4.1'
  ],
  classifiers=[ # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\neural-plexer-pytorch\neural_plexer_pytorch\neural_plexer_pytorch.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

.\lucidrains\neural-plexer-pytorch\neural_plexer_pytorch\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

NeuralPlexer - Pytorch (wip)

Implementation of Nvidia's NeuralPlexer, for end-to-end differentiable design of functional small-molecules and ligand-binding proteins, in Pytorch

Citations

@article{Qiao2022DynamicBackbonePS,
    title   = {Dynamic-Backbone Protein-Ligand Structure Prediction with Multiscale Generative Diffusion Models},
    author  = {Zhuoran Qiao and Weili Nie and Arash Vahdat and Thomas F. Miller and Anima Anandkumar},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.15171}
}

.\lucidrains\neural-plexer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'neural-plexer-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.1',
  # 许可证
  license='MIT',
  # 描述
  description = 'Neural Plexer',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/neural-plexer-pytorch',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'drug design'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  # 分类标签
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Neural Sequence Chunkers - Pytorch (wip)

Implementation of the Neural Sequence Chunker, Schmidhuber paper back from 1991, in the context of Attention and Transformers. Someone had anonymously requested it to be built here.

Citations

@TECHREPORT{Schmidhuber91neuralsequence,
    author = {Jürgen Schmidhuber},
    title = {Neural Sequence Chunkers},
    year = {1991}
}

nim-genetic-algorithm

a simple genetic algorithm written in Nim

running

$ nim c -r ga.nim

Nim Tokenizer (wip)

Implementation of a simple BPE tokenizer, but in Nim. May contain BPE Dropout too

Todo

Citations

@inproceedings{Wang2019NeuralMT,
    title   = {Neural Machine Translation with Byte-Level Subwords},
    author  = {Changhan Wang and Kyunghyun Cho and Jiatao Gu},
    booktitle = {AAAI Conference on Artificial Intelligence},
    year    = {2019}
}
@inproceedings{provilkov-etal-2020-bpe,
    title   = "{BPE}-Dropout: Simple and Effective Subword Regularization",
    author  = "Provilkov, Ivan  and Emelianenko, Dmitrii  and Voita, Elena",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://aclanthology.org/2020.acl-main.170",
    doi     = "10.18653/v1/2020.acl-main.170",
    pages   = "1882--1892",
}

.\lucidrains\nuwa-pytorch\nuwa_pytorch\image_utils.py

# 导入 torch 库
import torch
# 导入 torchvision.transforms 库并重命名为 T
import torchvision.transforms as T
# 从 PIL 库中导入 Image 类
from PIL import Image

# 定义常量

# 通道数到模式的映射关系
CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

# 遍历所有图像帧
def seek_all_images(img, channels = 3):
    # 检查通道数是否在 CHANNELS_TO_MODE 中
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    # 获取对应通道数的图像模式
    mode = CHANNELS_TO_MODE[channels]

    # 初始化帧数为 0
    i = 0
    # 循环直到遇到异常
    while True:
        try:
            # 尝试定位到第 i 帧
            img.seek(i)
            # 将图像转换为指定模式
            yield img.convert(mode)
        except EOFError:
            # 遇到文件结尾异常时退出循环
            break
        # 帧数加一
        i += 1

# 将张量转换为 GIF 图像
def video_tensor_to_gif(tensor, path, duration = 80, loop = 0, optimize = True):
    # 将张量中的每一帧转换为 PIL 图像
    images = map(T.ToPILImage(), tensor.unbind(0))
    # 获取第一帧图像和剩余图像
    first_img, *rest_imgs = images
    # 保存 GIF 图像到指定路径
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    # 返回图像列表
    return images

# 将 GIF 图像转换为张量 (帧数, 通道数, 高度, 宽度)
def gif_to_tensor(path, channels = 3):
    # 打开 GIF 图像
    img = Image.open(path)
    # 获取图像中的每一帧并转换为张量
    tensors = tuple(map(T.ToTensor(), seek_all_images(img, channels = channels)))
    # 沿着第 0 维度堆叠张量
    return torch.stack(tensors, dim = 0)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\nuwa_pytorch.py

# 导入 functools 模块
import functools
# 从 functools 模块中导入 partial 函数
from functools import partial

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn、einsum 模块
from torch import nn, einsum
# 从 torch 模块中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 导入 einops 模块中的 rearrange、reduce、repeat 函数
from einops import rearrange, reduce, repeat
# 从 einops.layers.torch 模块中导入 Rearrange、Reduce 类
from einops.layers.torch import Rearrange, Reduce

# 导入 nuwa_pytorch 模块中的 ReversibleSequence、DualModalityReversibleSequence 类
from nuwa_pytorch.reversible import ReversibleSequence
from nuwa_pytorch.reversible_video_audio import DualModalityReversibleSequence

# 导入 unfoldNd 模块
from unfoldNd import unfoldNd

# 导入 tqdm 模块
from tqdm import tqdm

# 常量定义

# 定义 MList 为 nn.ModuleList 类
MList = nn.ModuleList

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 将变量转换为元组的函数
def cast_tuple(val, size = 1):
    return val if isinstance(val, tuple) else (val,) * size

# 计算相同填充的函数
def calc_same_padding(kernel_size, dilation = 1):
    return dilation * (kernel_size - 1) // 2

# 将填充值调整为倍数的函数
def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder

# 装饰器

# 评估装饰器函数
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 张量辅助函数

# 对数函数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# sigmoid 函数
def sigmoid(t):
    return torch.where(t >= 0, 1 / (1 + torch.exp(-t)), t.exp() / (1 + t.exp()))

# 生成 Gumbel 噪声的函数
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 生成 Gumbel 采样的函数
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# 安全除法函数
def safe_div(numer, denom, eps = 1e-6):
    return numer / (denom + eps)

# 生成概率掩码的函数
def prob_mask_like(shape, prob, device):
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# 批处理函数
def batch_process(t, fn, chunks = 10, dim = 0):
    chunks = [fn(t_chunk) for t_chunk in t.chunk(chunks, dim = dim)]
    return torch.cat(chunks, dim = dim)

# 多重归约函数
def mult_reduce(arr):
    return functools.reduce(lambda x, y: x * y, arr, 1)

# 梯度控制

# 分数梯度函数
def frac_gradient(t, frac):
    return t * frac + t.detach() * (1 - frac)

# 标准化

# 稳定的 LayerNorm 类
class StableLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x / x.amax(dim = -1, keepdim = True).detach()
        return self.norm(x)

# 预标准化类
class PreNorm(nn.Module):
    def __init__(
        self,
        *,
        dim,
        fn
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 三明治标准化类
class SandwichNorm(nn.Module):
    def __init__(
        self,
        *,
        dim,
        fn
    ):
        super().__init__()
        self.prenorm = nn.LayerNorm(dim)
        self.postnorm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.prenorm(x)
        x = self.fn(x, **kwargs)
        x = self.postnorm(x)
        return x

# 相对位置嵌入(旋转)

# 旋转嵌入类
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device):
        inv_freq = self.inv_freq
        t = torch.arange(seq_len, device = device).type_as(inv_freq)
        freqs = torch.einsum('i , j -> i j', t, inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

# 旋转半个维度的函数
def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

# 应用旋转位置嵌入的函数
def apply_rotary_pos_emb(freqs, t):
    rot_dim = freqs.shape[-1]
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim = -1)

# 辅助类

# 移位音频令牌类
class ShiftAudioTokens(nn.Module):
    # 初始化函数,设置音频每个时间步的音频标记数
    def __init__(
        self,
        fn,
        audio_tokens_per_timestep = 1
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 保存文件名和音频每个时间步的音频标记数
        self.fn = fn
        self.audio_tokens_per_timestep = audio_tokens_per_timestep

    # 前向传播函数
    def forward(self, x, **kwargs):
        # 获取输入张量的第二个维度大小
        n = x.shape[1]

        # 填充到最近的时间步

        # 计算需要填充的数量
        padding = self.audio_tokens_per_timestep - (n % self.audio_tokens_per_timestep)
        # 在第二维度上进行填充
        x = F.pad(x, (0, 0, 0, padding), value = 0.)

        # 沿着时间轴进行移动

        # 将输入张量分成两部分
        x_shift, x = x.chunk(2, dim = -1)
        # 在第二维度上进行填充
        x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
        # 拼接两部分张量
        x = torch.cat((x_shift, x), dim = -1)

        # 如果需要,移除填充

        # 返回处理后的结果
        return self.fn(x[:, :n], **kwargs)
class ShiftVideoTokens(nn.Module):
    # 定义 ShiftVideoTokens 类,用于处理视频序列的移位操作
    def __init__(
        self,
        fn,
        image_size,
        shift_space = True,
        shift_time = False
    ):
        # 初始化函数,接收函数 fn、图像大小 image_size、是否移位空间 shift_space、是否移位时间 shift_time 作为参数
        super().__init__()
        self.fn = fn
        self.image_size = image_size

        self.shift_time = shift_time
        self.shift_space = shift_space

    def forward(self, x, **kwargs):
        # 前向传播函数,接收输入 x 和其他关键字参数 kwargs
        if not self.shift_time and not self.shift_space:
            return self.fn(x, **kwargs)

        image_size = self.image_size
        img_seq_len = image_size ** 2

        x_bos, x_video = x[:, :1], x[:, 1:]
        n = x_video.shape[1]

        # pad to nearest frame
        # 填充到最近的帧

        padding = img_seq_len - (n % img_seq_len)
        x_video = F.pad(x_video, (0, 0, 0, padding), value = 0.)

        # reshape to video
        # 重塑为视频

        x_video = rearrange(x_video, 'b (f h w) d -> b f h w d', h = image_size, w = image_size)

        x_image_h = x_image_w = x_frame = None

        # chunk depending on whether shifting time, space, or both
        # 根据是否移位时间、空间或两者来分块

        if self.shift_space and self.shift_time:
            x_frame, x_image_h, x_image_w, *x_rest = x_video.chunk(5, dim = -1)
        elif self.shift_space:
            x_image_h, x_image_w, *x_rest = x_video.chunk(4, dim = -1)
        elif self.shift_time:
            x_frame, *x_rest = x_video.chunk(3, dim = -1)

        # shifts
        # 移位操作

        if self.shift_space:
            x_image_h = F.pad(x_image_h, (0, 0, 0, 0, 1, -1))
            x_image_w = F.pad(x_image_w, (0, 0, 1, -1))

        if self.shift_time:
            x_frame = F.pad(x_frame, (0, 0, 0, 0, 0, 0, 1, -1))

        # concat
        # 连接操作

        x_shifted = [x_frame, x_image_h, x_image_w, *x_rest]
        x_shifted = list(filter(exists, x_shifted))

        x_video = torch.cat(x_shifted, dim = -1)

        # merge text and image sequence back together
        # 将文本和图像序列合并在一起

        x_video = rearrange(x_video, 'b f h w d -> b (f h w) d')
        x_video = x_video[:, :n]

        x = torch.cat((x_bos, x_video), dim = 1)
        return self.fn(x, **kwargs)

class GEGLU(nn.Module):
    # 定义 GEGLU 类,用于实现 Gated Linear Unit 激活函数
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

class FeedForward(nn.Module):
    # 定义 FeedForward 类,用于实现前馈神经网络
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        dropout = 0.,
        chunk_size = None,  # chunk size to process feedforward, along sequence length, from Reformer paper. None means do not chunk
    ):
        super().__init__()
        inner_dim = (dim * mult * 2) // 3
        self.chunk_size = chunk_size

        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2, bias = False),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim, bias = False)
        )

    def forward(self, x):
        if not exists(self.chunk_size):
            return self.net(x)

        x_chunks = x.split(self.chunk_size, dim = -2)
        out_chunks = [self.net(c) for c in x_chunks]
        return torch.cat(out_chunks, dim = -2)

# attention classes

class Attention(nn.Module):
    # 定义 Attention 类,用于实现注意力机制
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5

        self.null_k = nn.Parameter(torch.randn(heads, 1, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, 1, dim_head))

        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None,
        rotary_pos_emb = None
        ):
        # 获取输入张量 x 的 batch 大小、头数、设备信息
        b, h, device = x.shape[0], self.heads, x.device

        # 检查是否存在上下文信息
        has_context = exists(context)
        # 如果存在上下文信息,则将上下文信息作为键值对输入
        kv_input = context if has_context else x

        # 将输入张量 x 转换为查询向量 q,键值对转换为 k 和 v
        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
        # 将查询 q、键 k、值 v 重排为 batch、头数、序列长度、维度的形式
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # 如果不存在上下文信息且存在旋转位置嵌入,则应用旋转位置嵌入
        if not has_context and exists(rotary_pos_emb):
            apply_rotary = partial(apply_rotary_pos_emb, rotary_pos_emb)
            q, k, v = map(apply_rotary, (q, k, v))

        # 添加空键/值,用于条件丢弃
        null_k = repeat(self.null_k, 'h 1 d -> b h 1 d', b = b)
        null_v = repeat(self.null_v, 'h 1 d -> b h 1 d', b = b)

        # 将空键值与原始键值连接起来
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        # 缩放
        q = q * self.scale

        # 相似度计算
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 掩码值
        mask_value = -torch.finfo(x.dtype).max

        # 如果存在键掩码,则对相似度矩阵进行掩码处理
        key_mask = mask if not has_context else context_mask
        if exists(key_mask):
            key_mask = F.pad(key_mask, (1, 0), value = True) # 始终注意空键/值
            key_mask = rearrange(key_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~key_mask, mask_value)

        # 如果是因果注意力,则对相似度矩阵进行掩码处理
        if self.causal:
            i, j = sim.shape[-2:]
            mask = torch.ones(i, j, device = device, dtype = torch.bool).triu_(j - i + 1)
            sim = sim.masked_fill(mask, mask_value)

        # 注意力权重计算
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = self.talking_heads(attn)
        attn = self.dropout(attn)

        # 聚合、合并和组合头
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 定义一个名为 Sparse3DNA 的神经网络模块
class Sparse3DNA(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        video_shape,
        kernel_size = 3,
        dilation = 1,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        query_num_frames_chunk = None,
        rel_pos_bias = False
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 计算内部维度
        inner_dim = dim_head * heads
        # 设置头数和缩放因子
        self.heads = heads
        self.scale = dim_head ** -0.5

        # 初始化 dropout 层和线性变换层
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # 初始化 talking heads 和输出层
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # 转换为元组并确保卷积核大小为奇数
        self.dilation = cast_tuple(dilation, size = 3)
        self.kernel_size = cast_tuple(kernel_size, size = 3)
        assert all(map(lambda n: n % 2 == 1, self.kernel_size)), 'kernel size must be odd'

        # 计算卷积核元素数量
        self.kernel_numel = mult_reduce(self.kernel_size)

        # 如果需要,为每个头计算相对位置偏置
        self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if rel_pos_bias else None

        # 计算填充
        self.padding_frame = calc_same_padding(self.kernel_size[0], self.dilation[0])
        self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
        self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])

        # 根据是否是因果卷积使用不同的填充
        if causal:
            self.video_padding = (self.padding_width * 2, 0, self.padding_height * 2, 0, self.padding_frame * 2, 0)
        else:
            self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)

        # 保存视频形状并计算最大令牌数量
        self.video_shape = video_shape
        max_frames, fmap_size, _ = video_shape
        max_num_tokens = torch.empty(video_shape).numel()
        self.max_num_tokens = max_num_tokens

        # 限制内存使用,一次处理多少查询令牌
        self.query_num_frames_chunk = default(query_num_frames_chunk, max_frames)

        # 预先计算因果掩码
        ones = torch.ones((max_num_tokens,))
        ones = rearrange(ones, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
        ones = F.pad(ones, self.video_padding, value = 0.)
        ones = unfoldNd(ones, kernel_size = self.kernel_size, dilation = self.dilation)
        ones = rearrange(ones, '1 k n -> n k')

        # 掩盖填充
        padding_mask = ones == 0.

        # bos 令牌永远不会被掩盖
        mask = F.pad(padding_mask, (1, 0), value = False)
        self.register_buffer('mask', mask)

# 定义一个名为 SparseCausal2DNA 的神经网络模块
class SparseCausal2DNA(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim,
        height = 1,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        kernel_size = 5,
        dilation = 1,
        rel_pos_bias = False
    # 定义 Transformer 层
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 计算每个头的维度
        inner_dim = heads * dim_head
        self.heads = heads
        # 缩放因子
        self.scale = dim_head ** -0.5

        # 定义用于交互的卷积层
        self.talking_heads = nn.Conv3d(heads, heads, 1, bias = False)
        # Dropout 层
        self.dropout = nn.Dropout(dropout)
        # 线性变换,将输入维度转换为内部维度的三倍
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # 线性变换,将内部维度转换为输出维度
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # 处理用于展开的变量

        # 高度信息,宽度为序列长度,时间轴 - (batch, seq) -> (batch, time, height)
        self.height = height
        # 卷积核大小
        self.kernel_size = (kernel_size, height)
        # 膨胀率
        self.dilation = (dilation, 1)
        # 因果填充
        self.causal_padding = (0, 0, calc_same_padding(kernel_size, dilation) * 2, 0)
        # 相对位置偏置
        self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if exists(rel_pos_bias) else None

        # 因果掩码

        # 注册缓冲区变量 mask
        self.register_buffer('mask', None, persistent = False)

    # 获取掩码
    def get_mask(self, t):
        # 如果 mask 存在且 mask 的倒数第三维与 t 的倒数第三维相同,则返回 mask
        if exists(self.mask) and self.mask.shape[-3] == t.shape[-3]:
            return self.mask

        device, seq_len = t.device, t.shape[-3] * self.height

        # 创建全为 1 的张量
        ones = torch.ones((seq_len,), device = device)
        # 重排张量维度
        ones = rearrange(ones, '(n m) -> 1 1 n m', m = self.height)

        # 对全为 1 的张量进行填充
        ones = F.pad(ones, self.causal_padding, value = 0.)
        # 展开张量
        ones = unfoldNd(ones, kernel_size = self.kernel_size, dilation = self.dilation)
        ones = rearrange(ones, '1 d n -> n d')

        # 创建填充掩码
        padding_mask = rearrange(ones, 'n j -> n 1 j') == 0.
        mask = F.pad(padding_mask, (1, 0), value = False)

        # 注册缓冲区变量 mask
        self.register_buffer('mask', mask, persistent = False)
        return mask

    # 前向传播方法
    def forward(
        self,
        x,
        **kwargs
        ):
            # 获取输入张量的维度信息
            b, n, h, device = x.shape[0], x.shape[1], self.heads, x.device

            # 计算每个时间步的标记数和卷积核元素数
            tokens_per_timestep = self.height
            kernel_numel = self.kernel_size[0] * self.kernel_size[1]

            # 填充到正确的长度

            bos_only = n == 1
            seq_pad = padding_to_multiple_of(n - 1, tokens_per_timestep)

            # 为视频中的最后一个标记进行填充

            padded_x = F.pad(x, (0, 0, 0, seq_pad), value = 0.) if seq_pad > 0 else x

            # 推导查询、键、值

            q, k, v = self.to_qkv(padded_x).chunk(3, dim = -1)

            # 处理仅有 bos 的情况

            if bos_only:
                return self.to_out(v)

            out_bos = v[:, :1]

            # 分割头部

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

            # 缩放

            q = q * self.scale

            # 处理 bos

            (q_bos, q), (k_bos, k), (v_bos, v) = map(lambda t: (t[:, :, 0], t[:, :, 1:]), (q, k, v))

            # 重塑键/值以进行展开

            k, v = map(lambda t: rearrange(t, 'b h (x y) d -> (b h) d x y ', y = tokens_per_timestep), (k, v))
            k, v = map(lambda t: F.pad(t, self.causal_padding), (k, v))
            k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, dilation = self.dilation), (k, v))
            k, v = map(lambda t: rearrange(t, '(b h f) (d j) i -> b h i (f j) d', b = b, h = h, j = kernel_numel), (k, v))

            # 添加 bos

            k_bos_repeated, v_bos_repeated = map(lambda t: repeat(t, 'b h d -> b h i 1 d', i = k.shape[-3]), (k_bos, v_bos))
            k = torch.cat((k_bos_repeated, k), dim = -2)
            v = torch.cat((v_bos_repeated, v), dim = -2)

            q = rearrange(q, 'b h (x y) d -> b h x y d', y = tokens_per_timestep)

            sim = einsum('b h n i d, b h n j d -> b h n i j', q, k)

            # 相对位置偏置

            if exists(self.rel_pos_bias):
                rel_pos_bias = self.rel_pos_bias()
                rel_pos_bias = rearrange(rel_pos_bias, 'j h -> h 1 1 j')
                rel_pos_bias = F.pad(rel_pos_bias, (1, 0), value = 0.)
                sim = sim + rel_pos_bias

            # 因果 + 填充掩码

            mask_value = -torch.finfo(x.dtype).max
            mask = self.get_mask(sim)
            sim = sim.masked_fill(mask, mask_value)

            # 注意力

            attn = sim.softmax(dim = -1, dtype = torch.float32)
            attn = self.talking_heads(attn)
            attn = self.dropout(attn)

            # 聚合、合并和组合头部

            out = einsum('b h n i j, b h n j d -> b h n i d', attn, v)
            out = rearrange(out, 'b h x y d -> b (x y) (h d)')

            # 将 bos 的输出添加回去

            out = torch.cat((out_bos, out), dim = -2)

            return self.to_out(out[:, :n])
# 定义一个名为 SparseCross2DNA 的神经网络模块
class SparseCross2DNA(nn.Module):
    # 初始化函数,接受一些参数
    def __init__(
        self,
        *,
        dim,  # 输入维度
        image_size,  # 图像大小
        heads = 8,  # 多头注意力机制中的头数,默认为8
        dim_head = 64,  # 每个头的维度,默认为64
        dropout = 0.,  # Dropout 概率,默认为0
        kernel_size = 3,  # 卷积核大小,默认为3
        dilation = 1,  # 膨胀率,默认为1
    ):
        super().__init__()  # 调用父类的初始化函数
        inner_dim = heads * dim_head  # 内部维度为头数乘以每个头的维度
        self.heads = heads  # 多头注意力机制中的头数
        self.scale = dim_head ** -0.5  # 缩放因子

        # 初始化可学习参数 null_k 和 null_v
        self.null_k = nn.Parameter(torch.randn(heads, 1, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, 1, dim_head))

        # 初始化可学习参数 talking_heads,用于多头注意力机制
        self.talking_heads = nn.Conv3d(heads, heads, 1, bias = False)
        self.dropout = nn.Dropout(dropout)  # 初始化 Dropout 层
        self.to_q = nn.Linear(dim, inner_dim, bias = False)  # 输入到查询向量的线性变换
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)  # 输入到键值对的线性变换
        self.to_out = nn.Linear(inner_dim, dim, bias = False)  # 输出线性变换

        # 处理 2D 展开的变量

        self.image_size = image_size  # 图像大小
        self.kernel_size = kernel_size  # 卷积核大小
        self.dilation = dilation  # 膨胀率
        self.padding = calc_same_padding(kernel_size, dilation)  # 计算填充大小

    # 前向传播函数,接受输入 x 和一些关键字参数
    def forward(
        self,
        x,
        *,
        context,  # 上下文信息
        context_mask = None,  # 上下文掩码,默认为 None
        **kwargs  # 其他关键字参数
        ):
            # 获取输入张量的维度信息
            b, n, h, device = x.shape[0], x.shape[1], self.heads, x.device

            # 获取模型参数的相关信息
            fmap_size, kernel_size, dilation, padding = self.image_size, self.kernel_size, self.dilation, self.padding

            # 计算上下文长度、每帧的标记数、卷积核元素数
            context_len = context.shape[-2]
            tokens_per_frame = fmap_size * fmap_size
            kernel_numel = kernel_size * kernel_size

            # 如果上下文掩码不存在,则创建一个全为 True 的掩码
            if not exists(context_mask):
                context_mask = torch.ones((b, context_len), dtype=torch.bool, device=device)

            # 初始化掩码值
            mask_value = -torch.finfo(x.dtype).max

            # 生成查询、键、值
            qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

            # 缩放查询
            q = q * self.scale

            # 处理 bos
            q_bos, q = q[:, :, 0], q[:, :, 1:]

            null_k_for_bos = repeat(self.null_k, 'h 1 d -> b h 1 d', b=b)
            null_v_for_bos = repeat(self.null_v, 'h 1 d -> b h 1 d', b=b)

            k_for_bos = torch.cat((null_k_for_bos, k), dim=-2)
            v_for_bos = torch.cat((null_v_for_bos, v), dim=-2)

            sim_bos = einsum('b h d, b h j d -> b h j', q_bos, k_for_bos)

            bos_context_mask = rearrange(context_mask, 'b j -> b 1 j')
            bos_context_mask = F.pad(bos_context_mask, (1, 0), value=True)
            sim_bos = sim_bos.masked_fill(~bos_context_mask, mask_value)

            attn_bos = sim_bos.softmax(dim=-1, dtype=torch.float32)
            out_bos = einsum('b h j, b h j d -> b h d', attn_bos, v_for_bos)
            out_bos = rearrange(out_bos, 'b h d -> b 1 (h d)')

            # 如果只有一个标记,则直接返回结果
            if n == 1:
                return self.to_out(out_bos)

            # 重塑键/值以进行展开
            k, v = map(lambda t: rearrange(t, 'b h (f x y) d -> (b h f) d x y', x=fmap_size, y=fmap_size), (k, v))
            k, v = map(lambda t: F.unfold(t, kernel_size=kernel_size, dilation=dilation, padding=padding), (k, v))
            k, v = map(lambda t: rearrange(t, '(b h f) (d j) i -> b h i (f j) d', b=b, h=h, j=kernel_numel), (k, v))

            # 添加空键/值,用于条件丢弃
            null_k = repeat(self.null_k, 'h 1 d -> b h i 1 d', b=b, i=tokens_per_frame)
            null_v = repeat(self.null_v, 'h 1 d -> b h i 1 d', b=b, i=tokens_per_frame)

            k = torch.cat((null_k, k), dim=-2)
            v = torch.cat((null_v, v), dim=-2)

            # 将查询填充到最近的帧
            q_padding = padding_to_multiple_of(q.shape[-2], tokens_per_frame)
            q = F.pad(q, (0, 0, 0, q_padding), value=0.)

            q = rearrange(q, 'b h (f i) d -> b h f i d', i=tokens_per_frame)

            # 计算相似度
            sim = einsum('b h f i d, b h i j d -> b h f i j', q, k)

            # 掩码
            context_mask = rearrange(context_mask, 'b (f x y) -> (b f) 1 x y', x=fmap_size, y=fmap_size)
            context_mask = F.unfold(context_mask.float(), kernel_size=kernel_size, dilation=dilation, padding=padding)
            context_mask = context_mask == 1.
            context_mask = rearrange(context_mask, '(b f) j i -> b 1 1 i (f j)', b=b, j=kernel_numel)
            context_mask = F.pad(context_mask, (1, 0), value=True)  # 总是关注空键/值

            sim = sim.masked_fill(~context_mask, mask_value)

            # 注意力
            attn = sim.softmax(dim=-1, dtype=torch.float32)
            attn = self.talking_heads(attn)
            attn = self.dropout(attn)

            # 聚合、合并和组合头
            out = einsum('b h f i j, b h i j d -> b h f i d', attn, v)
            out = rearrange(out, 'b h f n d -> b (f n) (h d)')

            # 将 bos 的输出添加回去
            out = torch.cat((out_bos, out), dim=1)

            return self.to_out(out[:, :n])
"""
用于实现高效的音频 <-> 视频注意力机制
主要灵感来源于 https://arxiv.org/abs/2112.04426 中的块交叉注意力机制
"""

class CrossModalityCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 输入维度
        chunk_size,  # 块大小
        context_chunk_size,  # 上下文块大小
        heads = 8,  # 头数
        dim_head = 64,  # 每个头的维度
        context_dim = None,  # 上下文维度,默认为None
        has_start_token = True,  # 是否有起始标记
        context_has_start_token = True,  # 上下文是否有起始标记
        norm = False,  # 是否进行归一化
        norm_context = False,  # 上下文是否进行归一化
        dropout = 0.  # 丢弃概率
    ):
        super().__init__()
        context_dim = default(context_dim, dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim  = dim_head * heads

        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()  # 归一化层
        self.context_norm = nn.LayerNorm(context_dim) if norm_context else nn.Identity()  # 上下文归一化层

        self.to_q = nn.Linear(dim, inner_dim, bias = False)  # 查询线性层
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)  # 键值线性层
        self.to_out = nn.Linear(inner_dim, dim, bias = False)  # 输出线性层

        self.null_k = nn.Parameter(torch.randn(heads, dim_head))  # 空键参数
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))  # 空值参数

        self.talking_heads = nn.Conv3d(heads, heads, 1)  # 三维卷积层
        self.dropout = nn.Dropout(dropout)  # 丢弃层

        self.has_start_token = has_start_token  # 是否有起始标记
        self.context_has_start_token = context_has_start_token  # 上下文是否有起始标记

        self.chunk_size = chunk_size  # 块大小
        self.context_chunk_size = context_chunk_size  # 上下文块大小

    def forward(
        self,
        seq,  # 序列输入
        context,  # 上下文输入
        mask = None,  # 掩码
        context_mask = None  # 上下文掩码

# transformer

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 输入维度
        depth,  # 深度
        causal = False,  # 是否因果
        heads = 8,  # 头数
        dim_head = 64,  # 每个头的维度
        ff_mult = 4,  # FeedForward 层的倍增因子
        cross_attend = False,  # 是否跨模态注意力
        attn_dropout = 0.,  # 注意力丢弃概率
        ff_dropout = 0.,  # FeedForward 层的丢弃概率
        ff_chunk_size = None,  # FeedForward 层的块大小
        cross_2dna_attn = False,  # 是否跨 2DNA 注意力
        cross_2dna_image_size = None,  # 跨 2DNA 图像大小
        cross_2dna_kernel_size = 3,  # 跨 2DNA 卷积核大小
        cross_2dna_dilations = (1,),  # 跨 2DNA 膨胀率
        sparse_3dna_attn = False,  # 是否稀疏 3DNA 注意力
        sparse_3dna_kernel_size = 3,  # 稀疏 3DNA 卷积核大小
        sparse_3dna_video_shape = None,  # 稀疏 3DNA 视频形状
        sparse_3dna_query_num_frames_chunk = None,  # 稀疏 3DNA 查询帧块数
        sparse_3dna_dilations = (1,),  # 稀疏 3DNA 膨胀率
        sparse_3dna_rel_pos_bias = False,  # 稀疏 3DNA 相对位置偏置
        shift_video_tokens = False,  # 是否移动视频标记
        rotary_pos_emb = False  # 是否使用旋转位置嵌入
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 断言条件,如果不满足则抛出异常
        assert not (sparse_3dna_attn and not exists(sparse_3dna_video_shape)), 'sparse_3dna_video_shape must be defined if turned on'
        assert not (cross_2dna_attn and not exists(cross_2dna_image_size)), 'cross_2dna_image_size must be defined'

        # 初始化层列表
        self.layers = MList([])

        # 循环创建多个层
        for ind in range(depth):
            if sparse_3dna_attn:
                # 如果启用了稀疏3DNA注意力机制
                dilation = sparse_3dna_dilations[ind % len(sparse_3dna_dilations)]

                self_attn = Sparse3DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    kernel_size = sparse_3dna_kernel_size,
                    dilation = dilation,
                    video_shape = sparse_3dna_video_shape,
                    query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
                    rel_pos_bias = sparse_3dna_rel_pos_bias,
                )
            else:
                # 否则使用普通的注意力机制
                self_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    dropout = attn_dropout
                )

            cross_attn = None

            if cross_attend:
                if cross_2dna_attn:
                    # 如果启用了交叉2DNA注意力机制
                    dilation = cross_2dna_dilations[ind % len(cross_2dna_dilations)]

                    cross_attn = SparseCross2DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        image_size = cross_2dna_image_size,
                        kernel_size = cross_2dna_kernel_size,
                        dilation = dilation
                    )

                else:
                    # 否则使用普通的注意力机制
                    cross_attn = Attention(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout
                    )

            # 创建前馈神经网络层
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)

            if sparse_3dna_attn and shift_video_tokens:
                # 如果启用了稀疏3DNA注意���机制并且需要移动视频标记
                fmap_size = sparse_3dna_video_shape[-1]
                self_attn = ShiftVideoTokens(self_attn, image_size = fmap_size)
                ff        = ShiftVideoTokens(ff, image_size = fmap_size)

            # 将当前层的各个组件添加到层列表中
            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = self_attn),
                SandwichNorm(dim = dim, fn = cross_attn) if cross_attend else None,
                SandwichNorm(dim = dim, fn = ff)
            ]))

        # 初始化稳定层归一化
        self.norm = StableLayerNorm(dim)

    # 前向传播方法
    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None
    ):
        # 遍历所有层
        for attn, cross_attn, ff in self.layers:
            # 使用自注意力机制更新输入
            x = attn(x, mask = mask) + x

            # 如果存在交叉注意力机制
            if exists(cross_attn):
                # 使用交叉注意力机制更新输入
                x = cross_attn(x, context = context, mask = mask, context_mask = context_mask) + x

            # 使用前馈神经网络更新输入
            x = ff(x) + x

        # 对输出进行稳定层归一化
        return self.norm(x)
# 定义一个可逆的 Transformer 模型类
class ReversibleTransformer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim,  # 模型维度
        depth,  # 模型深度
        causal = False,  # 是否使用因果注意力
        heads = 8,  # 多头注意力的头数
        dim_head = 64,  # 每个头的维度
        ff_mult = 4,  # FeedForward 层的倍数
        cross_attend = False,  # 是否使用跨层注意力
        attn_dropout = 0.,  # 注意力层的 dropout 概率
        ff_dropout = 0.,  # FeedForward 层的 dropout 概率
        ff_chunk_size = None,  # FeedForward 层的分块大小
        cross_2dna_attn = False,  # 是否使用跨 2D 和 1D 注意力
        cross_2dna_image_size = None,  # 跨 2D 和 1D 注意力的图像大小
        cross_2dna_kernel_size = 3,  # 跨 2D 和 1D 注意力的卷积核大小
        cross_2dna_dilations = (1,),  # 跨 2D 和 1D 注意力的膨胀系数
        sparse_3dna_attn = False,  # 是否使用稀疏 3D 和 1D 注意力
        sparse_3dna_kernel_size = 3,  # 稀疏 3D 和 1D 注意力的卷积核大小
        sparse_3dna_video_shape = None,  # 稀疏 3D 和 1D 注意力的视频形状
        sparse_3dna_query_num_frames_chunk = None,  # 稀疏 3D 和 1D 注意力的查询帧数块大小
        sparse_3dna_dilations = (1,),  # 稀疏 3D 和 1D 注意力的膨胀系数
        sparse_3dna_rel_pos_bias = False,  # 稀疏 3D 和 1D 注意力是否使用相对位置偏置
        shift_video_tokens = False,  # 是否对视频 token 进行位移
        rotary_pos_emb = False  # 是否使用旋转位置编码
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言条件,如果不满足则抛出异常
        assert not (sparse_3dna_attn and not exists(sparse_3dna_video_shape)), 'sparse_3dna_video_shape must be defined if turned on'
        assert not (cross_2dna_attn and not exists(cross_2dna_image_size)), 'cross_2dna_image_size must be defined'

        # 初始化层列表
        self.layers = MList([])

        # 循环创建网络层
        for ind in range(depth):
            if sparse_3dna_attn:
                # 获取稀疏3DNA注意力机制的参数
                dilation = sparse_3dna_dilations[ind % len(sparse_3dna_dilations)]
                image_size = sparse_3dna_video_shape[-1]

                # 创建稀疏3DNA自注意力层
                self_attn = Sparse3DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    kernel_size = sparse_3dna_kernel_size,
                    dilation = dilation,
                    video_shape = sparse_3dna_video_shape,
                    query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
                    rel_pos_bias = sparse_3dna_rel_pos_bias,
                )
            else:
                image_size = None

                # 创建普通自注意力层
                self_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    dropout = attn_dropout
                )

            # 创建包装函数
            wrapper_fn = partial(ShiftVideoTokens, image_size = image_size, shift_space = sparse_3dna_attn and shift_video_tokens)

            # 添加自注意力层和前馈网络层到层列表
            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = wrapper_fn(self_attn)),
                SandwichNorm(dim = dim, fn = wrapper_fn(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)))
            ]))

            # 如果不需要交叉注意力,则继续下一轮循环
            if not cross_attend:
                continue

            if cross_2dna_attn:
                # 获取交叉2DNA注意力机制的参数
                dilation = cross_2dna_dilations[ind % len(cross_2dna_dilations)]

                # 创建交叉2DNA注意力层
                cross_attn = SparseCross2DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = attn_dropout,
                    image_size = cross_2dna_image_size,
                    kernel_size = cross_2dna_kernel_size,
                    dilation = dilation
                )
            else:
                # 创建普通交叉注意力层
                cross_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = attn_dropout
                )

            # 添加交叉注意力层和前馈网络层到层列表
            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = cross_attn),
                SandwichNorm(dim = dim, fn = wrapper_fn(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)))
            ]))

        # 设置注意力上下文层和路由
        attn_context_layer = ((True, False),) if cross_attend else tuple()
        route_attn = ((True, False), *attn_context_layer) * depth
        route_context = ((False, False), *attn_context_layer) * depth

        # 设置上下文路由映射和注意力路由映射
        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}

        # 创建可逆序列网络
        self.net = ReversibleSequence(self.layers, args_route = {**context_route_map, **attn_route_map})
        # 创建稳定层归一化
        self.norm = StableLayerNorm(dim)

    # 前向传播函数
    def forward(
        self,
        x,
        **kwargs
    ):
        # 使用网络进行前向传播
        x = self.net(x, **kwargs)
        # 对结果进行归一化处理
        return self.norm(x)
# 双模态解码器(用于视频和音频合成)

class DualModalityDecoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_audio_tokens_per_video_frame,
        num_video_tokens_per_frame,
        sparse_3dna_video_shape,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilations = (1,),
        sparse_3dna_rel_pos_bias = False,
        sparse_2dna_kernel_size = 7,
        sparse_2dna_dilation = (1,),
        sparse_2dna_rel_pos_bias = False,
        shift_video_tokens = False,
        shift_audio_tokens = False,
        audio_tokens_per_timestep = 1,
        cross_modality_attn_every = 3
    # 定义前向传播函数
    def forward(
        self,
        video,
        audio,
        *,
        context,
        audio_mask = None,
        video_mask = None,
        context_mask = None,
        **kwargs
    ):
        # 遍历每个块和层类型
        for blocks, layer_type in zip(self.layers, self.layer_types):
            # 如果层类型为'intra_modality'
            if layer_type == 'intra_modality':
                # 解压块
                (video_self_attn, video_cross_attn, video_ff), (audio_self_attn, audio_cross_attn, audio_ff) = blocks

                # 视频自注意力机制
                video_ = video_self_attn(video, mask = video_mask) + video
                video_ = video_cross_attn(video_, context = context, mask = video_mask, context_mask = context_mask) + video_
                video_ = video_ff(video_) + video_

                # 音频自注意力机制
                audio_ = audio_self_attn(audio, mask = audio_mask) + audio
                audio_ = audio_cross_attn(audio_, context = context, mask = audio_mask, context_mask = context_mask) + audio_
                audio_ = audio_ff(audio_) + audio_

            # 如果层类型为'inter_modality'
            elif layer_type == 'inter_modality':
                # 解压块
                (video_to_audio_attn, video_ff), (audio_to_video_attn, audio_ff) = blocks

                # 视频到音频的注意力机制
                video_ = video_to_audio_attn(
                    video,
                    context = audio,
                    mask = video_mask,
                    context_mask = audio_mask
                ) + video

                # 音频到视频的注意力机制
                audio_ = audio_to_video_attn(
                    audio,
                    context = video,
                    mask = audio_mask,
                    context_mask = video_mask
                ) + audio

                video_ = video_ff(video_) + video_
                audio_ = audio_ff(audio_) + audio_
            else:
                raise ValueError(f'unknown layer type {layer_type}')

            video, audio = video_, audio_

        # 返回视频和音频的归一化结果
        return self.video_norm(video), self.audio_norm(audio)

# 可逆双模态解码器

class ReversibleDualModalityDecoder(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_audio_tokens_per_video_frame,
        num_video_tokens_per_frame,
        sparse_3dna_video_shape,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilations = (1,),
        sparse_3dna_rel_pos_bias = False,
        sparse_2dna_kernel_size = 7,
        sparse_2dna_dilation = (1,),
        sparse_2dna_rel_pos_bias = False,
        shift_video_tokens = False,
        shift_audio_tokens = False,
        audio_tokens_per_timestep = 1,
        cross_modality_attn_every = 3
    # 定义前向传播函数
    def forward(
        self,
        video,
        audio,
        *,
        context,
        audio_mask = None,
        video_mask = None,
        context_mask = None,
        **kwargs
    ):
        # 调用网络进行前向传播
        video, audio = self.net(
            video,
            audio,
            context = context,
            audio_mask = audio_mask,
            video_mask = video_mask,
            context_mask = context_mask
        )

        # 返回视频和音频的归一化结果
        return self.video_norm(video), self.audio_norm(audio)

# 嵌入
# 定义一个名为 Embedding 的神经网络模块类
class Embedding(nn.Module):
    # 初始化函数,接受形状参数和梯度分数参数
    def __init__(self, *shape, frac_gradient = 1.):
        super().__init__()
        # 设置梯度分数参数
        self.frac_gradient = frac_gradient
        # 创建 Embedding 层
        self.embed = nn.Embedding(*shape)

    # 前向传播函数
    def forward(self, x):
        # 将输入 x 传入 Embedding 层
        x = self.embed(x)

        # 如果处于训练状态且梯度分数小于1,则对 x 进行梯度分数处理
        if self.training and self.frac_gradient < 1:
            x = frac_gradient(x, self.frac_gradient)

        return x

# positional embedding

# 定义一个名为 AxialPositionalEmbedding 的神经网络模块类
class AxialPositionalEmbedding(nn.Module):
    # 初始化函数,接受维度参数和形状参数
    def __init__(
        self,
        dim,
        *,
        shape
    ):
        super().__init__()
        # 过滤形状参数中大于1的值,形成新的形状参数
        shape = tuple(filter(lambda t: t > 1, shape))

        # 设置维度、形状和轴数
        self.dim = dim
        self.shape = shape
        self.num_axials = len(shape)

        # 为每个轴创建随机参数
        for axial_ind, axial_len in enumerate(shape):
            axial_pos = nn.Parameter(torch.randn(axial_len, dim))
            setattr(self, f'axial{axial_ind + 1}', axial_pos)

    # 前向传播函数,接受 flatten 参数
    def forward(self, *, flatten = True):
        positions = None

        # 遍历每个轴
        for axial_ind in range(self.num_axials):
            axial_pos = getattr(self, f'axial{axial_ind + 1}')

            # 如果 positions 为空,则将当前轴位置赋给 positions
            if not exists(positions):
                positions = axial_pos
                continue

            # 对 positions 进行重排列,并加上当前轴位置
            positions = rearrange(positions, '... d -> ... 1 d')
            positions = positions + axial_pos

        # 如果 flatten 为 True,则对 positions 进行重排列
        if flatten:
            positions = rearrange(positions, '... d -> (...) d')

        return positions

# sampling helpers

# 定义一个名为 top_k 的函数,接受 logits 和阈值参数
def top_k(logits, thres = 0.5):
    # 获取 logits 的最后一个维度大小
    num_logits = logits.shape[-1]
    # 计算 k 值
    k = max(int((1 - thres) * num_logits), 1)
    # 获取前 k 个最大值的索引和值
    val, ind = torch.topk(logits, k)
    # 创建与 logits 相同大小的全为负无穷的张量
    probs = torch.full_like(logits, float('-inf'))
    # 根据索引将值填充到 probs 中
    probs.scatter_(1, ind, val)
    return probs

# main class

# 定义一个名为 NUWA 的神经网络模块类
class NUWA(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim,
        vae = None,
        image_size = None,
        max_video_frames = 5,
        text_num_tokens = 49408,
        text_max_seq_len = 256,
        text_enc_depth = 6,
        text_enc_dim_head = 64,
        text_enc_heads = 8,
        text_rotary_pos_emb = True,
        enc_reversible = False,
        dec_depth = 6,
        dec_dim_head = 64,
        dec_heads = 8,
        dec_reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        embed_gradient_frac = 0.2,
        shift_video_tokens = True,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilation = 1,
        sparse_3dna_rel_pos_bias = False
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言 VAE 或图像大小必须被指定
        assert exists(vae) ^ exists(image_size), 'either VAE or image size must be specified'

        self.vae = None
        # 如果存在 VAE,则复制一个用于评估的 VAE,并设置图像大小为 VAE 的图像大小
        if exists(vae):
            self.vae = vae.copy_for_eval()
            image_size = vae.image_size

        # 获取 VAE 的层数和图像 token 数量
        vae_num_layers = vae.num_layers
        num_image_tokens = vae.codebook_size

        self.text_max_seq_len = text_max_seq_len
        # 创建文本嵌入层
        self.text_embedding = Embedding(text_num_tokens, dim, frac_gradient = embed_gradient_frac)

        # 为文本创建位置嵌入
        self.text_abs_pos_emb = Embedding(text_max_seq_len, dim)  if not text_rotary_pos_emb else None
        self.text_rotary_pos_emb = RotaryEmbedding(dim = min(32, text_enc_dim_head)) if text_rotary_pos_emb else None

        # 根据是否可逆选择编码器的类型
        enc_transformer_klass = Transformer if not enc_reversible else ReversibleTransformer

        # 创建文本变换器
        self.text_transformer = enc_transformer_klass(
            dim = dim,
            depth = text_enc_depth,
            heads = text_enc_heads,
            dim_head = text_enc_dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            rotary_pos_emb = text_rotary_pos_emb
        )

        # 创建视频的开始 token
        self.video_bos = nn.Parameter(torch.randn(dim))
        # 创建图像嵌入层
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient = embed_gradient_frac)

        # 计算特征图大小
        fmap_size = image_size // (2 ** vae_num_layers)

        self.video_fmap_size = fmap_size
        self.max_video_frames = max_video_frames
        video_shape = (max_video_frames, fmap_size, fmap_size)

        # 为视频创建位置嵌入
        self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)

        # 设置稀疏 3D 邻近注意力的循环扩张
        sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1)) if not isinstance(sparse_3dna_dilation, (list, tuple)) else sparse_3dna_dilation

        # 根据是否可逆选择解码器的类型
        dec_transformer_klass = Transformer if not dec_reversible else ReversibleTransformer

        # 创建视频变换器
        self.video_transformer = dec_transformer_klass(
            dim = dim,
            depth = dec_depth,
            heads = dec_heads,
            dim_head = dec_dim_head,
            causal = True,
            cross_attend = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape,
            sparse_3dna_attn = True,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
            sparse_3dna_rel_pos_bias = sparse_3dna_rel_pos_bias
        )

        # 创建输出层
        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)

    def embed_text(self, text, mask = None):
        # 获取文本的批量大小、序列长度和设备
        batch, seq_len, device = *text.shape, text.device
        # 断言序列长度不超过文本最大序列长度
        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'

        # 对文本进行嵌入
        tokens = self.text_embedding(text)

        if exists(self.text_abs_pos_emb):
            # 添加绝对位置嵌入
            pos_emb = self.text_abs_pos_emb(torch.arange(seq_len, device = device))
            tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')

        rotary_pos_emb = None
        if exists(self.text_rotary_pos_emb):
            # 如果存在旋转位置嵌入,则获取旋转位置嵌入
            rotary_pos_emb = self.text_rotary_pos_emb(seq_len, device = device)

        # 返回文本变换器的结果
        return self.text_transformer(
            tokens,
            mask = mask,
            rotary_pos_emb = rotary_pos_emb
        )

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        *,
        text,
        filter_thres = 0.9,
        temperature = 1.,
        decode_max_batchsize = 10,
        cond_scale = 2.,
        num_frames = None
    ):
        # 解包文本张量的形状和设备信息
        batch, seq_len, device = *text.shape, text.device

        # 创建文本掩码,将文本张量中非零元素标记为True
        text_mask = text != 0
        # 使用文本嵌入层对文本进行嵌入处理,同时传入文本掩码
        text_embeds = self.embed_text(text, mask = text_mask)

        # 重复视频起始符号,形状为(batch, 1, d),其中d为视频特征维度
        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)

        # 创建空的视频索引张量,形状为(batch, 0),设备为指定设备
        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)

        # 计算每帧视频的标记数量
        num_tokens_per_frame = self.video_fmap_size ** 2

        # 设置视频帧数,默认为最大视频帧数
        num_frames = default(num_frames, self.max_video_frames)
        total_video_tokens =  num_tokens_per_frame * num_frames
        max_video_tokens = num_tokens_per_frame * self.max_video_frames

        # 获取视频位置编码
        pos_emb = self.video_pos_emb()

        # 遍历视频标记总数
        for ind in tqdm(range(total_video_tokens)):
            # 备份视频索引输入
            video_indices_input = video_indices

            # 获取当前视频标记数量
            num_video_tokens = video_indices.shape[1]
            # 如果视频标记数量超过最大视频标记数量
            if num_video_tokens > max_video_tokens:
                # 计算当前帧标记数量
                curr_frame_tokens = num_video_tokens % num_tokens_per_frame
                # 计算回溯标记数量
                lookback_tokens = (self.max_video_frames - (0 if curr_frame_tokens == 0 else 1)) * num_tokens_per_frame + curr_frame_tokens
                # 更新视频索引输入为最近的标记
                video_indices_input = video_indices[:, -lookback_tokens:]

            # 获取帧嵌入
            frame_embeddings = self.image_embedding(video_indices_input)
            # 添加位置编码到帧嵌入
            frame_embeddings = pos_emb[:frame_embeddings.shape[1]] + frame_embeddings
            # 拼接起始符号和帧嵌入
            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)

            # 使用视频Transformer处理帧嵌入和文本嵌入
            frame_embeddings = self.video_transformer(
                frame_embeddings,
                context = text_embeds,
                context_mask = text_mask
            )

            # 获取输出logits
            logits = self.to_logits(frame_embeddings)

            # 如果条件缩放不为1
            if cond_scale != 1:
                # 使用视频Transformer处理帧嵌入和文本嵌入,但文本掩码为全零
                uncond_frame_embeddings = self.video_transformer(
                    frame_embeddings,
                    context = text_embeds,
                    context_mask = torch.zeros_like(text_mask).bool()
                )

                # 获取无条件logits
                uncond_logits = self.to_logits(uncond_frame_embeddings)
                # 更新logits为无条件logits加上条件缩放后的值
                logits = uncond_logits + (logits - uncond_logits) * cond_scale

            # 选择最后一个标记的logits
            logits = logits[:, -1, :]

            # 对logits进行筛选,保留前k个值
            filtered_logits = top_k(logits, thres = filter_thres)
            # 使用Gumbel采样获取样本
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
            # 重新排列样本的形状
            sample = rearrange(sample, 'b -> b 1')
            # 拼接样本到视频索引
            video_indices = torch.cat((video_indices, sample), dim = 1)

        # 根据视频索引获取VAE的代码簿
        codes = self.vae.codebook[video_indices]
        # 重新排列代码的形状
        codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.video_fmap_size, w = self.video_fmap_size)

        # 批处理代码,通过VAE解码获取图像重构
        image_reconstructions = batch_process(codes, self.vae.decode, chunks = decode_max_batchsize)
        # 重新排列图像重构的形状
        video = rearrange(image_reconstructions, '(b f) d h w -> b f d h w', b = batch)
        # 返回视频
        return video

    # 前向传播函数
    def forward(
        self,
        *,
        text,
        video = None,
        return_loss = False,
        cond_dropout_prob = 0.2
        # 从输入的张量形状中获取批次大小、序列长度、帧数和设备信息
        batch, seq_len, frames, device = *text.shape, video.shape[1], text.device

        # 创建文本掩码,将文本中非零元素标记为True
        text_mask = text != 0
        # 使用文本嵌入模型对文本进行嵌入处理,同时应用文本掩码
        text_embeds = self.embed_text(text, mask = text_mask)

        # 如果视频数据类型为torch.long,则直接使用视频帧索引
        if video.dtype == torch.long:
            frame_indices = video
        else:
            # 否则,确保视频帧数与最大视频帧数相同,并且需要传入VAE模型以自动将视频编码为ids
            assert frames == self.max_video_frames, f'you must give the full video frames ({self.max_video_frames}) during training'
            assert exists(self.vae), 'VAE must be passed in if you wish for video to be encoded to ids automatically'
            frame_indices = self.vae.get_video_indices(video)

        # 重新排列视频帧索引的形状
        frame_indices = rearrange(frame_indices, 'b ... -> b (...)')
        # 如果不需要返回损失,则将视频帧索引的最后一帧排除在外
        frame_indices_input = frame_indices[:, :-1] if return_loss else frame_indices

        # 使用图像嵌入模型对视频帧索引进行嵌入处理
        frame_embeddings = self.image_embedding(frame_indices_input)
        # 添加视频位置编码到帧嵌入中
        frame_embeddings = self.video_pos_emb()[:-1] + frame_embeddings

        # 在帧嵌入的开头添加起始符号
        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
        frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)

        # 如果处于训练状态且条件丢弃概率大于0,则随机丢弃条件
        if self.training and cond_dropout_prob > 0:
            # 随机生成与文本掩码相同形状的无条件掩码
            uncond_mask = prob_mask_like((batch,), cond_dropout_prob, device = device)
            # 将无条件掩码应用到文本掩码上
            text_mask *= rearrange(~uncond_mask, 'b -> b 1')

        # 使用视频变换器模型处理帧嵌入和文本嵌入
        frame_embeddings = self.video_transformer(
            frame_embeddings,
            context = text_embeds,
            context_mask = text_mask
        )

        # 将帧嵌入转换为logits
        logits = self.to_logits(frame_embeddings)

        # 如果不需要返回损失,则直接返回logits
        if not return_loss:
            return logits

        # 计算交叉熵损失
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), frame_indices)
        return loss
# 定义一个名为NUWAVideoAudio的类,继承自nn.Module
class NUWAVideoAudio(nn.Module):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        vae,  # 视频和音频编码器
        dim,  # 模型维度
        image_size,  # 图像尺寸
        num_audio_tokens,  # 音频标记数量
        num_audio_tokens_per_video_frame,  # 每个视频帧的音频标记数量
        audio_tokens_per_timestep = 1,  # 每个时间步的音频标记数量
        max_video_frames = 5,  # 最大视频帧数
        text_num_tokens = 49408,  # 文本标记数量
        text_max_seq_len = 256,  # 文本最大序列长度
        text_enc_depth = 6,  # 文本编码器深度
        text_enc_dim_head = 64,  # 文本编码器头维度
        text_enc_heads = 8,  # 文本编码器头数
        text_rotary_pos_emb = False,  # 是否使用旋转位置嵌入
        enc_reversible = False,  # 编码器是否可逆
        dec_reversible = True,  # 解码器是否可逆
        dec_depth = 6,  # 解码器深度
        dec_dim_head = 64,  # 解码器头维度
        dec_heads = 8,  # 解码器头数
        attn_dropout = 0.,  # 注意力机制的dropout
        ff_dropout = 0.,  # 前馈网络的dropout
        ff_chunk_size = None,  # 前馈网络的分块大小
        embed_gradient_frac = 0.2,  # 嵌入梯度比例
        shift_video_tokens = True,  # 是否移动视频标记
        shift_audio_tokens = True,  # 是否移动音频标记
        sparse_3dna_kernel_size = 3,  # 稀疏3D卷积核大小
        sparse_3dna_query_num_frames_chunk = None,  # 稀疏3D卷积查询帧块数
        sparse_3dna_dilation = 1,  # 稀疏3D卷积膨胀率
        sparse_3dna_rel_pos_bias = True,  # 稀疏3D卷积相对位置偏置
        sparse_2dna_kernel_size = 7,  # 稀疏2D卷积核大小
        sparse_2dna_dilation = 1,  # 稀疏2D卷积膨胀率
        sparse_2dna_rel_pos_bias = True,  # 稀疏2D卷积相对位置偏置
        audio_loss_weight = 1.,  # 音频损失权重
        cross_modality_attn_every = 3  # 跨模态注意力的频率
        ):
        # 调用父类的构造函数
        super().__init__()
        # 复制 VAE 模型用于评估
        self.vae = vae.copy_for_eval()
        # 获取 VAE 模型的层数和图像编码数量
        vae_num_layers = vae.num_layers
        num_image_tokens = vae.codebook_size

        # 设置文本相关参数
        self.text_max_seq_len = text_max_seq_len
        self.text_embedding = Embedding(text_num_tokens, dim, frac_gradient = embed_gradient_frac)

        # 根据是否使用旋转位置编码来选择文本绝对位置编码或旋转位置编码
        self.text_abs_pos_emb = Embedding(text_max_seq_len, dim) if not text_rotary_pos_emb else None
        self.text_rotary_pos_emb = RotaryEmbedding(dim = min(32, text_enc_dim_head)) if text_rotary_pos_emb else None

        # 根据是否使用可逆编码器来选择编码器类型
        enc_transformer_klass = Transformer if not enc_reversible else ReversibleTransformer

        # 创建文本变换器
        self.text_transformer = enc_transformer_klass(
            dim = dim,
            depth = text_enc_depth,
            heads = text_enc_heads,
            dim_head = text_enc_dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # 视频相关参数

        # 初始化视频的开始符号
        self.video_bos = nn.Parameter(torch.randn(dim))
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient = embed_gradient_frac)

        # 计算特征图大小
        fmap_size = image_size // (2 ** vae_num_layers)

        self.video_fmap_size = fmap_size
        self.max_video_frames = max_video_frames
        video_shape = (max_video_frames, fmap_size, fmap_size)

        # 创建视频位置编码
        self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)

        # 音频相关参数

        # 初始化音频的开始符号
        self.audio_bos = nn.Parameter(torch.randn(dim))
        self.audio_embedding = Embedding(num_audio_tokens, dim, frac_gradient = embed_gradient_frac)

        # 计算每帧音频序列的最大长度
        max_audio_seq_len = num_audio_tokens_per_video_frame * max_video_frames
        self.audio_pos_emb = AxialPositionalEmbedding(dim, shape = (num_audio_tokens // audio_tokens_per_timestep, audio_tokens_per_timestep))

        self.audio_loss_weight = audio_loss_weight

        # 每帧视频的标记数量

        self.num_video_tokens_per_frame = fmap_size ** 2
        self.num_audio_tokens_per_video_frame = num_audio_tokens_per_video_frame

        # 稀疏3D邻近注意力的循环扩张

        sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1)) if not isinstance(sparse_3dna_dilation, (list, tuple)) else sparse_3dna_dilation

        sparse_2dna_dilation = tuple(range(1, sparse_2dna_dilation + 1)) if not isinstance(sparse_2dna_dilation, (list, tuple)) else sparse_2dna_dilation

        # 根据是否使用可逆解码器来选择解码器类型
        decoder_klass = ReversibleDualModalityDecoder if dec_reversible else DualModalityDecoder

        # 创建视频音频变换器
        self.video_audio_transformer = decoder_klass(
            dim = dim,
            depth = dec_depth,
            heads = dec_heads,
            dim_head = dec_dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            audio_tokens_per_timestep = audio_tokens_per_timestep,
            shift_audio_tokens = shift_audio_tokens,
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
            sparse_3dna_rel_pos_bias = sparse_3dna_rel_pos_bias,
            num_audio_tokens_per_video_frame = num_audio_tokens_per_video_frame,
            num_video_tokens_per_frame = fmap_size * fmap_size,
            cross_modality_attn_every = cross_modality_attn_every,
            sparse_2dna_kernel_size = sparse_2dna_kernel_size,
            sparse_2dna_dilation = sparse_2dna_dilation,
            sparse_2dna_rel_pos_bias = sparse_2dna_rel_pos_bias
        )

        # 线性层将维度映射到图像标记数量
        self.to_video_logits = nn.Linear(dim, num_image_tokens, bias = False)
        # 线性层将维度映射到音频标记数量
        self.to_audio_logits = nn.Linear(dim, num_audio_tokens, bias = False)
    # 将文本嵌入到模型中
    def embed_text(self, text, mask = None):
        # 获取文本的批次、序列长度和设备信息
        batch, seq_len, device = *text.shape, text.device
        # 断言文本序列长度不超过预设的最大长度
        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'

        # 对文本进行嵌入
        tokens = self.text_embedding(text)

        # 如果存在绝对位置嵌入,则添加到嵌入的文本中
        if exists(self.text_abs_pos_emb):
            pos_emb = self.text_abs_pos_emb(torch.arange(seq_len, device = device))
            tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')

        rotary_pos_emb = None
        # 如果存在旋转位置嵌入,则获取旋转位置嵌入
        if exists(self.text_rotary_pos_emb):
            rotary_pos_emb = self.text_rotary_pos_emb(seq_len, device = device)

        # 返回经过文本变换器处理后的结果
        return self.text_transformer(
            tokens,
            mask = mask,
            rotary_pos_emb = rotary_pos_emb
        )

    # 生成文本
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        *,
        text,
        filter_thres = 0.9,
        temperature = 1.,
        decode_max_batchsize = 10,
        cond_scale = 2.,
        num_frames = None
    # 前向传播
    def forward(
        self,
        *,
        text,
        video,
        audio,
        return_loss = False,
        cond_dropout_prob = 0.2
    ):
        # 获取文本、视频、音频的批次、序列长度、帧数和设备信息
        batch, seq_len, frames, device = *text.shape, video.shape[1], text.device

        # 创建文本的掩码
        text_mask = text != 0
        # 对文本进行嵌入
        text_embeds = self.embed_text(text, mask = text_mask)

        # 准备视频表示

        # 如果视频的数据类型为整数,则直接使用视频帧索引
        if video.dtype == torch.long:
            frame_indices = video
        else:
            # 断言视频帧数与最大视频帧数相同
            assert frames == self.max_video_frames, f'you must give the full video frames ({self.max_video_frames}) during training'
            # 断言存在 VAE 模型
            assert exists(self.vae), 'VAE must be passed in if you wish for video to be encoded to ids automatically'
            # 获取视频帧索引
            frame_indices = self.vae.get_video_indices(video)

        # 重排视频帧索引的维度
        frame_indices = rearrange(frame_indices, 'b ... -> b (...)')
        frame_indices_input = frame_indices[:, :-1] if return_loss else frame_indices

        # 对视频帧进行嵌入
        frame_embeddings = self.image_embedding(frame_indices_input)
        frame_embeddings = self.video_pos_emb()[:-1] + frame_embeddings

        # 在视频帧前添加特殊标记
        video_bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
        frame_embeddings = torch.cat((video_bos, frame_embeddings), dim = 1)

        # 准备音频表示

        audio_indices_input = audio[:, :-1] if return_loss else audio

        # 对音频进行嵌入
        audio_embeddings = self.audio_embedding(audio_indices_input)
        audio_pos_emb = self.audio_pos_emb()[:audio_embeddings.shape[1]]
        audio_embeddings = audio_embeddings + rearrange(audio_pos_emb, 'n d -> 1 n d')

        # 在音频前添加特殊标记
        audio_bos = repeat(self.audio_bos, 'd -> b 1 d', b = batch)
        audio_embeddings = torch.cat((audio_bos, audio_embeddings), dim = 1)

        # 空条件,用于超级条件

        if self.training and cond_dropout_prob > 0:
            # 随机丢弃条件
            # 参考:https://openreview.net/forum?id=qw8AKxfYbI
            uncond_mask = prob_mask_like((batch,), cond_dropout_prob, device = device)
            text_mask *= rearrange(~uncond_mask, 'b -> b 1')

        # 视频和音频的双重注意力塔,具有高效的分块跨模态注意力

        frame_embeddings, audio_embeddings = self.video_audio_transformer(
            frame_embeddings,
            audio_embeddings,
            context = text_embeds,
            context_mask = text_mask
        )

        # 获取视频和音频的逻辑回归结果
        video_logits = self.to_video_logits(frame_embeddings)
        audio_logits = self.to_audio_logits(audio_embeddings)

        # 如果不需要计算损失,则直接返回逻辑回归结果
        if not return_loss:
            return video_logits, audio_logits

        # 计算视频和音频的损失
        video_loss = F.cross_entropy(rearrange(video_logits, 'b n c -> b c n'), frame_indices)
        audio_loss = F.cross_entropy(rearrange(audio_logits, 'b n c -> b c n'), audio)

        # 返回视频和音频的损失之和
        return video_loss + audio_loss * self.audio_loss_weight
# 主要用于学习素描的主类

class NUWASketch(nn.Module):
    def __init__(
        self,
        *,
        vae,  # VAE 模型
        sketch_vae,  # 素描 VAE 模型
        dim,  # 维度
        image_size,  # 图像大小
        max_video_frames = 5,  # 最大视频帧数
        sketch_max_video_frames = 2,  # 素描最大视频帧数
        sketch_enc_depth = 6,  # 素描编码器深度
        sketch_enc_dim_head = 64,  # 素描编码器头维度
        sketch_enc_heads = 8,  # 素描编码器头数
        sketch_enc_use_sparse_3dna = False,  # 是否使用稀疏 3DNA
        enc_reversible = False,  # 编码器是否可逆
        dec_depth = 6,  # 解码器深度
        dec_dim_head = 64,  # 解码器头维度
        dec_heads = 8,  # 解码器头数
        dec_reversible = False,  # 解码器是否可逆
        attn_dropout = 0.,  # 注意力机制的 dropout
        ff_dropout = 0.,  # FeedForward 层的 dropout
        ff_chunk_size = None,  # FeedForward 层的块大小
        embed_gradient_frac = 0.2,  # 嵌入梯度比例
        shift_video_tokens = True,  # 是否移动视频 token
        cross_2dna_kernel_size = 3,  # 交叉 2DNA 的卷积核大小
        cross_2dna_dilation = 1,  # 交叉 2DNA 的膨胀率
        sparse_3dna_kernel_size = 3,  # 稀疏 3DNA 的卷积核大小
        sparse_3dna_dilation = 1,  # 稀疏 3DNA 的膨胀率
        sparse_3dna_query_num_frames_chunk = None,  # 稀疏 3DNA 查询的帧块数
        ):
        # 调用父类的构造函数
        super().__init__()
        # 设置图像大小
        self.image_size = image_size

        # 设置sketch_vae属性
        self.sketch_vae = sketch_vae
        # 获取sketch_vae的层数
        sketch_vae_num_layers = sketch_vae.num_layers
        # 获取sketch_vae的编码本大小
        sketch_num_image_tokens = sketch_vae.codebook_size
        # 计算sketch的特征图大小
        sketch_fmap_size = image_size // (2 ** sketch_vae_num_layers)

        # 定义sketch的形状
        sketch_shape = (sketch_max_video_frames, sketch_fmap_size, sketch_fmap_size)

        # 设置sketch_max_video_frames属性
        self.sketch_max_video_frames = sketch_max_video_frames
        # 创建sketch的嵌入层
        self.sketch_embedding = Embedding(sketch_num_image_tokens, dim, frac_gradient = embed_gradient_frac)
        # 创建sketch的位置嵌入
        self.sketch_pos_emb = AxialPositionalEmbedding(dim, shape = sketch_shape)

        # sparse 3dna kwargs

        # 设置稀疏3dna的膨胀
        sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1)) if not isinstance(sparse_3dna_dilation, (list, tuple)) else sparse_3dna_dilation

        # encoder

        # 根据enc_reversible选择不同的Transformer类
        enc_transformer_klass = Transformer if not enc_reversible else ReversibleTransformer

        # 创建sketch_transformer
        self.sketch_transformer = enc_transformer_klass(
            dim = dim,
            depth = sketch_enc_depth,
            heads = sketch_enc_heads,
            dim_head = sketch_enc_dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = sketch_shape,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
            sparse_3dna_attn = sketch_enc_use_sparse_3dna
        )

        # decoder parameters

        # 复制vae用于评估
        self.vae = vae.copy_for_eval()

        # 获取vae的层数和编码本大小
        vae_num_layers = vae.num_layers
        num_image_tokens = vae.codebook_size

        # 创建video_bos参数
        self.video_bos = nn.Parameter(torch.randn(dim))
        # 创建图像嵌入层
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient = embed_gradient_frac)

        # 计算特征图大小
        fmap_size = image_size // (2 ** vae_num_layers)

        # 断言特征图大小相等
        assert fmap_size == sketch_fmap_size, 'feature map size of video must be equal to the feature map size of sketches (VAEs must have same number of layers)'

        # 设置video_fmap_size属性
        self.video_fmap_size = fmap_size
        # 设置最大视频帧数
        self.max_video_frames = max_video_frames
        # 定义video的形状
        video_shape = (max_video_frames, fmap_size, fmap_size)

        # 创建video的位置嵌入
        self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)

        # cycle dilation for sparse 3d-nearby attention

        # 设置cross_2dna_dilations
        cross_2dna_dilations = tuple(range(1, cross_2dna_dilation + 1)) if not isinstance(cross_2dna_dilation, (list, tuple)) else cross_2dna_dilation
        # 根据dec_reversible选择不同的Transformer类
        dec_transformer_klass = Transformer if not dec_reversible else ReversibleTransformer

        # 创建video_transformer
        self.video_transformer = dec_transformer_klass(
            dim = dim,
            depth = dec_depth,
            heads = dec_heads,
            dim_head = dec_dim_head,
            causal = True,
            cross_attend = True,
            cross_2dna_attn = True,
            cross_2dna_image_size = fmap_size,
            cross_2dna_kernel_size = cross_2dna_kernel_size,
            cross_2dna_dilations = cross_2dna_dilations,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
            sparse_3dna_attn = True
        )

        # 创建输出层
        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)
    # 将 sketch 的形状解构为 batch, frames, channels, image_size, _,并获取设备信息
    batch, frames, channels, image_size, _, device = *sketch.shape, sketch.device

    # 如果存在 mask,则确保 mask 的形状为 (batch, frames)
    if exists(mask):
        assert mask.shape[:2] == (batch, frames), 'sketch mask must be in shape of (batch x frame)'

    # 获取 sketch 的索引
    sketch_indices = self.sketch_vae.get_video_indices(sketch)
    # 重新排列 sketch_indices 的形状
    sketch_indices = rearrange(sketch_indices, 'b ... -> b (...)')

    # 使用 sketch_indices 获取 sketch_tokens
    sketch_tokens = self.sketch_embedding(sketch_indices)

    # 获取 sketch_tokens 的数量
    num_tokens = sketch_tokens.shape[1]

    # 获取 sketch 的位置编码
    sketch_pos_emb = self.sketch_pos_emb()
    sketch_pos_emb = sketch_pos_emb[:num_tokens]

    # 将 sketch_tokens 与 sketch_pos_emb 相加
    sketch_tokens = sketch_tokens + sketch_pos_emb

    # 如果存在 mask,则重复 mask,使其形状为 (batch, num_tokens)
    if exists(mask):
        mask = repeat(mask, 'b f -> b (f n)', n = (num_tokens // frames)
    else:
        # 如果不存在 mask,则创建全为 True 的 mask
        mask = torch.ones((batch, num_tokens), dtype = torch.bool, device = device)

    # 使用 sketch_transformer 对 sketch_tokens 进行嵌入
    embed = self.sketch_transformer(sketch_tokens, mask = mask)
    return embed, mask

@torch.no_grad()
@eval_decorator
def generate(
    self,
    *,
    sketch,
    sketch_mask = None,
    filter_thres = 0.9,
    temperature = 1.,
    decode_max_batchsize = 10,
    cond_scale = 2.,
    num_frames = None
    # 获取批次大小和设备信息
    batch, device = sketch.shape[0], sketch.device

    # 对草图进行嵌入处理,并生成解码器上下文掩码
    sketch_embeds, decoder_context_mask = self.embed_sketch(sketch, mask = sketch_mask)

    # 创建起始符号
    bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)

    # 创建空的视频索引张量
    video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)

    # 计算每帧的标记数量
    num_tokens_per_frame = self.video_fmap_size ** 2

    # 设置视频帧数和总标记数量
    num_frames = default(num_frames, self.max_video_frames)
    total_video_tokens =  num_tokens_per_frame * num_frames
    max_video_tokens = num_tokens_per_frame * self.max_video_frames

    # 获取位置编码
    pos_emb = self.video_pos_emb()

    # 遍历视频标记
    for ind in tqdm(range(total_video_tokens)):
        # 复制视频索引输入
        video_indices_input = video_indices

        # 获取当前视频标记数量
        num_video_tokens = video_indices.shape[1]
        if num_video_tokens > max_video_tokens:
            # 计算回溯标记数量
            curr_frame_tokens = num_video_tokens % num_tokens_per_frame
            lookback_tokens = (self.max_video_frames - (0 if curr_frame_tokens == 0 else 1)) * num_tokens_per_frame + curr_frame_tokens
            video_indices_input = video_indices[:, -lookback_tokens:]

        # 获取帧嵌入
        frame_embeddings = self.image_embedding(video_indices_input)
        frame_embeddings = pos_emb[:frame_embeddings.shape[1]] + frame_embeddings
        frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)

        # 使用视频变换器处理帧嵌入
        frame_embeddings = self.video_transformer(
            frame_embeddings,
            context = sketch_embeds,
            context_mask = decoder_context_mask
        )

        # 获取逻辑回归结果
        logits = self.to_logits(frame_embeddings)

        if cond_scale != 1:
            # 根据条件比例对逻辑回归结果进行调整
            uncond_frame_embeddings = self.video_transformer(
                frame_embeddings,
                context = sketch_embeds,
                context_mask = torch.zeros_like(decoder_context_mask).bool()
            )

            uncond_logits = self.to_logits(uncond_frame_embeddings)
            logits = uncond_logits + (logits - uncond_logits) * cond_scale

        logits = logits[:, -1, :]

        # 过滤逻辑回归结果并进行采样
        filtered_logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
        sample = rearrange(sample, 'b -> b 1')
        video_indices = torch.cat((video_indices, sample), dim = 1)

    # 获取代码本和重构图像
    codes = self.vae.codebook[video_indices]
    codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.video_fmap_size, w = self.video_fmap_size)

    image_reconstructions = batch_process(codes, self.vae.decode, chunks = decode_max_batchsize)
    video = rearrange(image_reconstructions, '(b f) d h w -> b f d h w', b = batch)
    return video

# 定义前向传播函数
def forward(
    self,
    *,
    sketch,
    sketch_mask = None,
    video = None,
    return_loss = False,
    cond_dropout_prob = 0.2
        # 处理一个草图的过程

        # 如果草图的维度为4,则重新排列成 'b c h w -> b 1 c h w'
        if sketch.ndim == 4:
            sketch = rearrange(sketch, 'b c h w -> b 1 c h w')

        # 获取一系列变量

        # 解包sketch的形状,得到batch, sketch_frames, sketch_channels, sketch_image_size, _, frames, device
        batch, sketch_frames, sketch_channels, sketch_image_size, _, frames, device = *sketch.shape, video.shape[1], sketch.device

        # 断言

        # 断言sketch_image_size必须等于self.image_size
        assert sketch_image_size == self.image_size, 'sketch image size must be equal'
        # 断言sketch_frames必须小于等于self.sketch_max_video_frames
        assert sketch_frames <= self.sketch_max_video_frames, 'sketch frames must be less than max sketch video frames'

        # 获取草图嵌入和计算掩码(暂时假设没有填充)

        # 获取草图嵌入和解码器上下文掩码
        sketch_embeds, decoder_context_mask = self.embed_sketch(sketch, mask=sketch_mask)

        # 断言

        # 断言frames必须等于self.max_video_frames
        assert frames == self.max_video_frames, f'you must give the full video frames ({self.max_video_frames}) during training'

        # 获取视频帧索引
        frame_indices = self.vae.get_video_indices(video)
        # 重新排列帧索引的形状,变为 'b ... -> b (...)'
        frame_indices = rearrange(frame_indices, 'b ... -> b (...)')
        # 如果不需要返回损失,则将帧索引输入设为frame_indices的前n-1个元素
        frame_indices_input = frame_indices[:, :-1] if not return_loss else frame_indices

        # 获取帧嵌入
        frame_embeddings = self.image_embedding(frame_indices_input)
        # 添加视频位置编码
        frame_embeddings = self.video_pos_emb()[:-1] + frame_embeddings

        # 在帧嵌入前添加开始标记
        bos = repeat(self.video_bos, 'd -> b 1 d', b=batch)
        frame_embeddings = torch.cat((bos, frame_embeddings), dim=1)

        # 如果处于训练状态且cond_dropout_prob大于0
        if self.training and cond_dropout_prob > 0:
            # 随机丢弃条件
            # 参考:https://openreview.net/forum?id=qw8AKxfYbI
            uncond_mask = prob_mask_like((batch,), cond_dropout_prob, device=device)
            sketch_mask *= rearrange(~uncond_mask, 'b -> b 1')

        # 使用视频变换器处理帧嵌入
        frame_embeddings = self.video_transformer(
            frame_embeddings,
            context=sketch_embeds,
            context_mask=decoder_context_mask
        )

        # 获取logits
        logits = self.to_logits(frame_embeddings)

        # 如果不需要返回损失,则返回logits
        if not return_loss:
            return logits

        # 计算损失
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), frame_indices)
        return loss
posted @ 2024-06-28 14:01  绝不原创的飞龙  阅读(108)  评论(0)    收藏  举报