diffusers 源码解析(十)
.\diffusers\models\resnet.py
# 版权声明,指定版权持有者及保留所有权利
# `TemporalConvLayer` 的版权,指定相关团队及保留所有权利
#
# 根据 Apache License 2.0 版本授权使用本文件;
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件按“原样”提供,不提供任何形式的担保或条件。
# 查看许可证以获取特定权限和限制的信息。
# 从 functools 模块导入 partial 函数,用于部分应用
from functools import partial
# 导入类型提示的 Optional、Tuple 和 Union
from typing import Optional, Tuple, Union
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的功能性操作模块
import torch.nn.functional as F
# 从工具模块导入 deprecate 装饰器
from ..utils import deprecate
# 从激活函数模块导入获取激活函数的工具
from .activations import get_activation
# 从注意力处理模块导入空间归一化类
from .attention_processor import SpatialNorm
# 从下采样模块导入相关的下采样类和函数
from .downsampling import (  # noqa
    Downsample1D,  # 一维下采样类
    Downsample2D,  # 二维下采样类
    FirDownsample2D,  # FIR 二维下采样类
    KDownsample2D,  # K 下采样类
    downsample_2d,  # 二维下采样函数
)
# 从归一化模块导入自适应组归一化类
from .normalization import AdaGroupNorm
# 从上采样模块导入相关的上采样类和函数
from .upsampling import (  # noqa
    FirUpsample2D,  # FIR 二维上采样类
    KUpsample2D,  # K 上采样类
    Upsample1D,  # 一维上采样类
    Upsample2D,  # 二维上采样类
    upfirdn2d_native,  # 原生的二维上采样函数
    upsample_2d,  # 二维上采样函数
)
# 定义一个使用条件归一化的 ResNet 块类,继承自 nn.Module
class ResnetBlockCondNorm2D(nn.Module):
    r"""
    使用包含条件信息的归一化层的 Resnet 块。
    # 参数说明
        Parameters:
            # 输入通道的数量
            in_channels (`int`): The number of channels in the input.
            # 第一层 conv2d 的输出通道数量,默认为 None 表示与输入通道相同
            out_channels (`int`, *optional*, default to be `None`):
                The number of output channels for the first conv2d layer. If None, same as `in_channels`.
            # 使用的 dropout 概率,默认为 0.0
            dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
            # 时间步嵌入的通道数量,默认为 512
            temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            # 第一层归一化使用的组数量,默认为 32
            groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
            # 第二层归一化使用的组数量,默认为 None 表示与 groups 相同
            groups_out (`int`, *optional*, default to None):
                The number of groups to use for the second normalization layer. if set to None, same as `groups`.
            # 归一化使用的 epsilon 值,默认为 1e-6
            eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
            # 使用的激活函数类型,默认为 "swish"
            non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
            # 时间嵌入的归一化层,当前只支持 "ada_group" 或 "spatial"
            time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
                The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
            # FIR 滤波器,见相关文档
            kernel (`torch.Tensor`, optional, default to None): FIR filter, see
                [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
            # 输出的缩放因子,默认为 1.0
            output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
            # 如果为 True,则添加 1x1 的 nn.conv2d 层作为跳跃连接
            use_in_shortcut (`bool`, *optional*, default to `True`):
                If `True`, add a 1x1 nn.conv2d layer for skip-connection.
            # 如果为 True,则添加上采样层
            up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
            # 如果为 True,则添加下采样层
            down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
            # 如果为 True,则为 `conv_shortcut` 输出添加可学习的偏置
            conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
                `conv_shortcut` output.
            # 输出的通道数量,默认为 None 表示与输出通道相同
            conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
                If None, same as `out_channels`.
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels
        # 如果没有指定输出通道数,则设置为输入通道数
        out_channels = in_channels if out_channels is None else out_channels
        # 保存输出通道数
        self.out_channels = out_channels
        # 保存卷积快捷方式的使用状态
        self.use_conv_shortcut = conv_shortcut
        # 保存上采样的标志
        self.up = up
        # 保存下采样的标志
        self.down = down
        # 保存输出缩放因子
        self.output_scale_factor = output_scale_factor
        # 保存时间嵌入的归一化方式
        self.time_embedding_norm = time_embedding_norm
        # 如果没有指定输出组数,则设置为输入组数
        if groups_out is None:
            groups_out = groups
        # 根据时间嵌入归一化方式选择不同的归一化层
        if self.time_embedding_norm == "ada_group":  # ada_group
            # 使用 AdaGroupNorm 进行归一化
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            # 使用 SpatialNorm 进行归一化
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            # 如果归一化方式不支持,抛出错误
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
        # 创建第一层卷积,输入通道数为 in_channels,输出通道数为 out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 根据时间嵌入归一化方式选择第二个归一化层
        if self.time_embedding_norm == "ada_group":  # ada_group
            # 使用 AdaGroupNorm 进行归一化
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":  # spatial
            # 使用 SpatialNorm 进行归一化
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            # 如果归一化方式不支持,抛出错误
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
        # 创建 dropout 层以防止过拟合
        self.dropout = torch.nn.Dropout(dropout)
        # 如果没有指定 2D 卷积的输出通道数,则设置为输出通道数
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        # 创建第二层卷积,输入通道数为 out_channels,输出通道数为 conv_2d_out_channels
        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
        # 获取激活函数
        self.nonlinearity = get_activation(non_linearity)
        # 初始化上采样和下采样的变量
        self.upsample = self.downsample = None
        # 如果需要上采样,则创建上采样层
        if self.up:
            self.upsample = Upsample2D(in_channels, use_conv=False)
        # 如果需要下采样,则创建下采样层
        elif self.down:
            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
        # 判断是否使用输入快捷方式,默认根据通道数决定
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
        # 初始化卷积快捷方式
        self.conv_shortcut = None
        # 如果使用输入快捷方式,则创建对应的卷积层
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )
    # 定义前向传播方法,接收输入张量和时间嵌入,返回输出张量
        def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # 检查是否有额外的位置参数或关键字参数中的 scale
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 设置弃用消息,提醒用户 scale 参数已弃用并将来会引发错误
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 调用弃用函数,记录 scale 的弃用情况
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 将输入张量赋值给隐藏状态
            hidden_states = input_tensor
    
            # 对隐藏状态进行归一化处理,使用时间嵌入
            hidden_states = self.norm1(hidden_states, temb)
    
            # 应用非线性激活函数
            hidden_states = self.nonlinearity(hidden_states)
    
            # 检查是否存在上采样操作
            if self.upsample is not None:
                # 如果批次大小大于等于 64,确保输入张量和隐藏状态是连续的
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                # 对输入张量和隐藏状态进行上采样
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
    
            # 检查是否存在下采样操作
            elif self.downsample is not None:
                # 对输入张量和隐藏状态进行下采样
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)
    
            # 对隐藏状态进行卷积操作
            hidden_states = self.conv1(hidden_states)
    
            # 再次对隐藏状态进行归一化处理,使用时间嵌入
            hidden_states = self.norm2(hidden_states, temb)
    
            # 应用非线性激活函数
            hidden_states = self.nonlinearity(hidden_states)
    
            # 应用 dropout 操作,防止过拟合
            hidden_states = self.dropout(hidden_states)
            # 再次对隐藏状态进行卷积操作
            hidden_states = self.conv2(hidden_states)
    
            # 如果存在 shortcut 卷积,则对输入张量进行 shortcut 卷积
            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)
    
            # 将输入张量和隐藏状态相加,并按输出缩放因子进行缩放
            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
    
            # 返回输出张量
            return output_tensor
# 定义一个名为 ResnetBlock2D 的类,继承自 nn.Module
class ResnetBlock2D(nn.Module):
    r"""
    一个 Resnet 块的文档字符串。
    参数:
        in_channels (`int`): 输入的通道数。
        out_channels (`int`, *可选*, 默认为 `None`):
            第一个 conv2d 层的输出通道数。如果为 None,则与 `in_channels` 相同。
        dropout (`float`, *可选*, 默认为 `0.0`): 使用的 dropout 概率。
        temb_channels (`int`, *可选*, 默认为 `512`): 时间步嵌入的通道数。
        groups (`int`, *可选*, 默认为 `32`): 第一个归一化层使用的组数。
        groups_out (`int`, *可选*, 默认为 None):
            第二个归一化层使用的组数。如果设为 None,则与 `groups` 相同。
        eps (`float`, *可选*, 默认为 `1e-6`): 用于归一化的 epsilon。
        non_linearity (`str`, *可选*, 默认为 `"swish"`): 使用的激活函数。
        time_embedding_norm (`str`, *可选*, 默认为 `"default"` ): 时间缩放平移配置。
            默认情况下,通过简单的平移机制应用时间步嵌入条件。选择 "scale_shift" 以获得
            更强的条件作用,包含缩放和平移。
        kernel (`torch.Tensor`, *可选*, 默认为 None): FIR 滤波器,见
            [`~models.resnet.FirUpsample2D`] 和 [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *可选*, 默认为 `1.0`): 输出使用的缩放因子。
        use_in_shortcut (`bool`, *可选*, 默认为 `True`):
            如果为 `True`,为跳跃连接添加 1x1 的 nn.conv2d 层。
        up (`bool`, *可选*, 默认为 `False`): 如果为 `True`,添加一个上采样层。
        down (`bool`, *可选*, 默认为 `False`): 如果为 `True`,添加一个下采样层。
        conv_shortcut_bias (`bool`, *可选*, 默认为 `True`): 如果为 `True`,为
            `conv_shortcut` 输出添加可学习的偏置。
        conv_2d_out_channels (`int`, *可选*, 默认为 `None`): 输出的通道数。
            如果为 None,则与 `out_channels` 相同。
    """
    # 定义初始化方法,接收各参数以设置 Resnet 块的属性
    def __init__(
        self,
        *,
        in_channels: int,  # 输入通道数
        out_channels: Optional[int] = None,  # 输出通道数,默认为 None
        conv_shortcut: bool = False,  # 是否使用卷积快捷连接
        dropout: float = 0.0,  # dropout 概率
        temb_channels: int = 512,  # 时间步嵌入的通道数
        groups: int = 32,  # 归一化层的组数
        groups_out: Optional[int] = None,  # 第二个归一化层的组数
        pre_norm: bool = True,  # 是否在激活之前进行归一化
        eps: float = 1e-6,  # 归一化使用的 epsilon
        non_linearity: str = "swish",  # 使用的激活函数类型
        skip_time_act: bool = False,  # 是否跳过时间激活
        time_embedding_norm: str = "default",  # 时间嵌入的归一化方式
        kernel: Optional[torch.Tensor] = None,  # FIR 滤波器
        output_scale_factor: float = 1.0,  # 输出缩放因子
        use_in_shortcut: Optional[bool] = None,  # 是否在快捷连接中使用
        up: bool = False,  # 是否添加上采样层
        down: bool = False,  # 是否添加下采样层
        conv_shortcut_bias: bool = True,  # 是否添加可学习的偏置
        conv_2d_out_channels: Optional[int] = None,  # 输出通道数
    # 前向传播方法,接受输入张量和时间嵌入,返回输出张量
        def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # 检查是否有额外参数或已弃用的 scale 参数
            if len(args) > 0 or kwargs.get("scale", None) is not None:
                # 生成弃用信息提示
                deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
                # 调用弃用函数记录弃用信息
                deprecate("scale", "1.0.0", deprecation_message)
    
            # 将输入张量赋值给隐藏状态
            hidden_states = input_tensor
    
            # 对隐藏状态进行规范化
            hidden_states = self.norm1(hidden_states)
            # 应用非线性激活函数
            hidden_states = self.nonlinearity(hidden_states)
    
            # 如果存在上采样层
            if self.upsample is not None:
                # 当批量大小较大时,确保张量连续存储
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                # 对输入和隐藏状态进行上采样
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            # 如果存在下采样层
            elif self.downsample is not None:
                # 对输入和隐藏状态进行下采样
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)
    
            # 对隐藏状态进行卷积操作
            hidden_states = self.conv1(hidden_states)
    
            # 如果存在时间嵌入投影层
            if self.time_emb_proj is not None:
                # 如果不跳过时间激活
                if not self.skip_time_act:
                    # 对时间嵌入应用非线性激活
                    temb = self.nonlinearity(temb)
                # 进行时间嵌入投影,并增加维度
                temb = self.time_emb_proj(temb)[:, :, None, None]
    
            # 根据时间嵌入的规范化方式处理隐藏状态
            if self.time_embedding_norm == "default":
                if temb is not None:
                    # 将时间嵌入加到隐藏状态上
                    hidden_states = hidden_states + temb
                # 对隐藏状态进行第二次规范化
                hidden_states = self.norm2(hidden_states)
            elif self.time_embedding_norm == "scale_shift":
                # 如果时间嵌入为 None,抛出错误
                if temb is None:
                    raise ValueError(
                        f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                    )
                # 将时间嵌入分割为缩放和偏移
                time_scale, time_shift = torch.chunk(temb, 2, dim=1)
                # 对隐藏状态进行第二次规范化
                hidden_states = self.norm2(hidden_states)
                # 应用缩放和偏移
                hidden_states = hidden_states * (1 + time_scale) + time_shift
            else:
                # 直接对隐藏状态进行第二次规范化
                hidden_states = self.norm2(hidden_states)
    
            # 应用非线性激活函数
            hidden_states = self.nonlinearity(hidden_states)
    
            # 应用 dropout 以增加正则化
            hidden_states = self.dropout(hidden_states)
            # 对隐藏状态进行第二次卷积操作
            hidden_states = self.conv2(hidden_states)
    
            # 如果存在卷积快捷连接
            if self.conv_shortcut is not None:
                # 对输入进行快捷卷积
                input_tensor = self.conv_shortcut(input_tensor)
    
            # 计算输出张量,结合输入和隐藏状态,并按输出缩放因子归一化
            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
    
            # 返回输出张量
            return output_tensor
# unet_rl.py
# 定义一个函数,用于重新排列张量的维度
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    # 如果张量的维度是 2,则在最后添加一个新维度
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    # 如果张量的维度是 3,则在第二维后添加一个新维度
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    # 如果张量的维度是 4,则取出第三维的第一个元素
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    # 如果维度不在 2, 3 或 4 之间,则抛出错误
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
# unet_rl.py
# 定义一个卷积块类,包含 1D 卷积、分组归一化和激活函数
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    Parameters:
        inp_channels (`int`): 输入通道数。
        out_channels (`int`): 输出通道数。
        kernel_size (`int` or `tuple`): 卷积核的大小。
        n_groups (`int`, default `8`): 将通道分成的组数。
        activation (`str`, defaults to `mish`): 激活函数的名称。
    """
    # 初始化函数,定义卷积块的各个层
    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()
        # 创建 1D 卷积层,设置填充以保持输出尺寸
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        # 创建分组归一化层
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        # 获取指定的激活函数
        self.mish = get_activation(activation)
    # 前向传播函数,定义数据流经网络的方式
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 通过卷积层处理输入
        intermediate_repr = self.conv1d(inputs)
        # 重新排列维度
        intermediate_repr = rearrange_dims(intermediate_repr)
        # 通过分组归一化处理
        intermediate_repr = self.group_norm(intermediate_repr)
        # 再次重新排列维度
        intermediate_repr = rearrange_dims(intermediate_repr)
        # 应用激活函数
        output = self.mish(intermediate_repr)
        # 返回最终输出
        return output
# unet_rl.py
# 定义一个残差时序块类,包含时序卷积
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.
    Parameters:
        inp_channels (`int`): 输入通道数。
        out_channels (`int`): 输出通道数。
        embed_dim (`int`): 嵌入维度。
        kernel_size (`int` or `tuple`): 卷积核的大小。
        activation (`str`, defaults `mish`): 可以选择合适的激活函数。
    """
    # 初始化函数,定义残差块的各个层
    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        # 创建输入卷积块
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        # 创建输出卷积块
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
        # 获取指定的激活函数
        self.time_emb_act = get_activation(activation)
        # 创建线性层,将嵌入维度映射到输出通道数
        self.time_emb = nn.Linear(embed_dim, out_channels)
        # 创建残差卷积,如果输入通道数不等于输出通道数,则使用 1x1 卷积
        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )
    # 定义前向传播函数,接收输入张量和时间嵌入张量,返回输出张量
    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # 参数说明:inputs是输入数据,t是时间嵌入
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
    
        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        # 对时间嵌入应用激活函数
        t = self.time_emb_act(t)
        # 将时间嵌入进行进一步处理
        t = self.time_emb(t)
        # 将输入经过初始卷积处理并与重排后的时间嵌入相加
        out = self.conv_in(inputs) + rearrange_dims(t)
        # 对合并后的结果进行输出卷积处理
        out = self.conv_out(out)
        # 返回卷积结果与残差卷积的和
        return out + self.residual_conv(inputs)
# 定义一个时间卷积层,适用于视频(图像序列)输入,主要代码来源于指定 GitHub 地址
class TemporalConvLayer(nn.Module):
    """
    时间卷积层,用于视频(图像序列)输入。代码主要复制自:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    参数:
        in_dim (`int`): 输入通道数。
        out_dim (`int`): 输出通道数。
        dropout (`float`, *可选*, 默认值为 `0.0`): 使用的丢弃概率。
    """
    # 初始化方法,设置输入输出维度和其他参数
    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 如果没有提供 out_dim,则将其设为 in_dim
        out_dim = out_dim or in_dim
        # 保存输入和输出通道数
        self.in_dim = in_dim
        self.out_dim = out_dim
        # 卷积层构建
        self.conv1 = nn.Sequential(
            # 对输入通道进行分组归一化
            nn.GroupNorm(norm_num_groups, in_dim),
            # 应用 SiLU 激活函数
            nn.SiLU(),
            # 创建 3D 卷积层
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv2 = nn.Sequential(
            # 对输出通道进行分组归一化
            nn.GroupNorm(norm_num_groups, out_dim),
            # 应用 SiLU 激活函数
            nn.SiLU(),
            # 应用丢弃层
            nn.Dropout(dropout),
            # 创建 3D 卷积层
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            # 对输出通道进行分组归一化
            nn.GroupNorm(norm_num_groups, out_dim),
            # 应用 SiLU 激活函数
            nn.SiLU(),
            # 应用丢弃层
            nn.Dropout(dropout),
            # 创建 3D 卷积层
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            # 对输出通道进行分组归一化
            nn.GroupNorm(norm_num_groups, out_dim),
            # 应用 SiLU 激活函数
            nn.SiLU(),
            # 应用丢弃层
            nn.Dropout(dropout),
            # 创建 3D 卷积层
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        # 将最后一层的参数归零,使卷积块成为恒等映射
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)
    # 前向传播方法,定义数据如何通过网络流动
    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        # 重塑输入的隐藏状态以适应卷积层的要求
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )
        # 保存输入的恒等映射
        identity = hidden_states
        # 通过第一个卷积层处理
        hidden_states = self.conv1(hidden_states)
        # 通过第二个卷积层处理
        hidden_states = self.conv2(hidden_states)
        # 通过第三个卷积层处理
        hidden_states = self.conv3(hidden_states)
        # 通过第四个卷积层处理
        hidden_states = self.conv4(hidden_states)
        # 将处理后的隐藏状态与恒等映射相加
        hidden_states = identity + hidden_states
        # 重塑输出以便返回
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        # 返回最终的隐藏状态
        return hidden_states
# 定义一个 Resnet 块
class TemporalResnetBlock(nn.Module):
    r"""
    一个 Resnet 块。
    # 参数文档
    Parameters:
        # 输入的通道数
        in_channels (`int`): The number of channels in the input.
        # 第一层 conv2d 的输出通道数,可选,默认为 None,表示与输入通道相同
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        # 时间步嵌入的通道数,可选,默认为 512
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        # 归一化使用的 epsilon,可选,默认为 1e-6
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """
    # 初始化方法
    def __init__(
        self,
        # 输入的通道数
        in_channels: int,
        # 输出的通道数,可选,默认为 None
        out_channels: Optional[int] = None,
        # 时间步嵌入的通道数,默认为 512
        temb_channels: int = 512,
        # 归一化使用的 epsilon,默认为 1e-6
        eps: float = 1e-6,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置输入通道数
        self.in_channels = in_channels
        # 如果输出通道数为 None,则使用输入通道数
        out_channels = in_channels if out_channels is None else out_channels
        # 设置输出通道数
        self.out_channels = out_channels
        # 定义卷积核大小
        kernel_size = (3, 1, 1)
        # 计算填充大小
        padding = [k // 2 for k in kernel_size]
        # 创建第一层的归一化层
        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
        # 创建第一层的卷积层
        self.conv1 = nn.Conv3d(
            # 输入通道数
            in_channels,
            # 输出通道数
            out_channels,
            # 卷积核大小
            kernel_size=kernel_size,
            # 步幅
            stride=1,
            # 填充大小
            padding=padding,
        )
        # 如果时间步嵌入通道数不为 None,则创建对应的线性层
        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            # 否则设置为 None
            self.time_emb_proj = None
        # 创建第二层的归一化层
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
        # 创建 Dropout 层,比例为 0.0
        self.dropout = torch.nn.Dropout(0.0)
        # 创建第二层的卷积层
        self.conv2 = nn.Conv3d(
            # 输入通道数
            out_channels,
            # 输出通道数
            out_channels,
            # 卷积核大小
            kernel_size=kernel_size,
            # 步幅
            stride=1,
            # 填充大小
            padding=padding,
        )
        # 获取激活函数,这里使用的是 "silu"
        self.nonlinearity = get_activation("silu")
        # 判断是否需要使用输入的 shortcut
        self.use_in_shortcut = self.in_channels != out_channels
        # 初始化 shortcut 卷积层为 None
        self.conv_shortcut = None
        # 如果需要使用输入的 shortcut,则创建对应的卷积层
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv3d(
                # 输入通道数
                in_channels,
                # 输出通道数
                out_channels,
                # 卷积核大小为 1
                kernel_size=1,
                # 步幅
                stride=1,
                # 填充为 0
                padding=0,
            )
    # 前向传播方法
    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # 将输入张量赋值给隐藏状态
        hidden_states = input_tensor
        # 进行第一次归一化
        hidden_states = self.norm1(hidden_states)
        # 应用非线性激活函数
        hidden_states = self.nonlinearity(hidden_states)
        # 进行第一次卷积操作
        hidden_states = self.conv1(hidden_states)
        # 如果时间步嵌入投影层存在
        if self.time_emb_proj is not None:
            # 应用非线性激活函数
            temb = self.nonlinearity(temb)
            # 通过线性层处理时间嵌入
            temb = self.time_emb_proj(temb)[:, :, :, None, None]
            # 调整维度顺序
            temb = temb.permute(0, 2, 1, 3, 4)
            # 将时间嵌入添加到隐藏状态
            hidden_states = hidden_states + temb
        # 进行第二次归一化
        hidden_states = self.norm2(hidden_states)
        # 应用非线性激活函数
        hidden_states = self.nonlinearity(hidden_states)
        # 应用 Dropout
        hidden_states = self.dropout(hidden_states)
        # 进行第二次卷积操作
        hidden_states = self.conv2(hidden_states)
        # 如果 shortcut 卷积层存在
        if self.conv_shortcut is not None:
            # 对输入张量应用 shortcut 卷积
            input_tensor = self.conv_shortcut(input_tensor)
        # 将输入张量与隐藏状态相加,得到输出张量
        output_tensor = input_tensor + hidden_states
        # 返回输出张量
        return output_tensor
# VideoResBlock
# 定义一个时空残差块的类,继承自 nn.Module
class SpatioTemporalResBlock(nn.Module):
    r"""
    一个时空残差网络块。
    参数:
        in_channels (`int`): 输入通道的数量。
        out_channels (`int`, *可选*, 默认为 `None`):
            第一个 conv2d 层的输出通道数量。如果为 None,和 `in_channels` 相同。
        temb_channels (`int`, *可选*, 默认为 `512`): 时间步嵌入的通道数量。
        eps (`float`, *可选*, 默认为 `1e-6`): 用于空间残差网络的 epsilon。
        temporal_eps (`float`, *可选*, 默认为 `eps`): 用于时间残差网络的 epsilon。
        merge_factor (`float`, *可选*, 默认为 `0.5`): 用于时间混合的合并因子。
        merge_strategy (`str`, *可选*, 默认为 `learned_with_images`):
            用于时间混合的合并策略。
        switch_spatial_to_temporal_mix (`bool`, *可选*, 默认为 `False`):
            如果为 `True`,则切换空间和时间混合。
    """
    # 初始化方法,定义类的属性
    def __init__(
        self,
        in_channels: int,  # 输入通道数量
        out_channels: Optional[int] = None,  # 输出通道数量,可选
        temb_channels: int = 512,  # 时间步嵌入通道数量,默认值为512
        eps: float = 1e-6,  # epsilon的默认值
        temporal_eps: Optional[float] = None,  # 时间残差网络的epsilon,默认为None
        merge_factor: float = 0.5,  # 合并因子的默认值
        merge_strategy="learned_with_images",  # 合并策略的默认值
        switch_spatial_to_temporal_mix: bool = False,  # 切换标志,默认为False
    ):
        # 调用父类初始化方法
        super().__init__()
        # 创建一个空间残差块实例
        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,  # 输入通道数量
            out_channels=out_channels,  # 输出通道数量
            temb_channels=temb_channels,  # 时间步嵌入通道数量
            eps=eps,  # epsilon的值
        )
        # 创建一个时间残差块实例
        self.temporal_res_block = TemporalResnetBlock(
            in_channels=out_channels if out_channels is not None else in_channels,  # 输入通道数量,依据输出通道数量决定
            out_channels=out_channels if out_channels is not None else in_channels,  # 输出通道数量
            temb_channels=temb_channels,  # 时间步嵌入通道数量
            eps=temporal_eps if temporal_eps is not None else eps,  # epsilon的值
        )
        # 创建一个时间混合器实例
        self.time_mixer = AlphaBlender(
            alpha=merge_factor,  # 合并因子
            merge_strategy=merge_strategy,  # 合并策略
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,  # 切换标志
        )
    # 前向传播方法,定义如何处理输入数据
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隐藏状态的张量输入
        temb: Optional[torch.Tensor] = None,  # 可选的时间步嵌入张量
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示张量
    ):
        # 获取图像帧的数量,即最后一个维度的大小
        num_frames = image_only_indicator.shape[-1]
        # 通过空间残差块处理隐藏状态
        hidden_states = self.spatial_res_block(hidden_states, temb)
        # 获取当前隐藏状态的批次大小、通道数、高度和宽度
        batch_frames, channels, height, width = hidden_states.shape
        # 计算每个批次的大小,即总帧数除以每个批次的帧数
        batch_size = batch_frames // num_frames
        # 重新调整隐藏状态的形状并进行维度转换,以便于后续处理
        hidden_states_mix = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )
        # 同样的调整隐藏状态的形状并进行维度转换
        hidden_states = (
            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        )
        # 如果时间嵌入不为空,则调整其形状以匹配批次大小和帧数
        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)
        # 通过时间残差块处理隐藏状态
        hidden_states = self.temporal_res_block(hidden_states, temb)
        # 将空间和时间的隐藏状态混合
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )
        # 重新排列维度并调整形状,以恢复到原始的隐藏状态格式
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个名为 AlphaBlender 的类,继承自 nn.Module
class AlphaBlender(nn.Module):
    r"""
    一个模块,用于混合空间和时间特征。
    参数:
        alpha (`float`): 混合因子的初始值。
        merge_strategy (`str`, *可选*, 默认值为 `learned_with_images`):
            用于时间混合的合并策略。
        switch_spatial_to_temporal_mix (`bool`, *可选*, 默认值为 `False`):
            如果为 `True`,则交换空间和时间混合。
    """
    # 定义可用的合并策略列表
    strategies = ["learned", "fixed", "learned_with_images"]
    # 初始化方法,设置参数和合并策略
    def __init__(
        self,
        alpha: float,  # 混合因子的初始值
        merge_strategy: str = "learned_with_images",  # 合并策略的默认值
        switch_spatial_to_temporal_mix: bool = False,  # 是否交换混合方式的标志
    ):
        # 调用父类构造函数
        super().__init__()
        # 保存合并策略
        self.merge_strategy = merge_strategy
        # 保存空间和时间混合的交换标志
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # 用于 TemporalVAE
        # 检查合并策略是否在可用策略中
        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")
        # 如果合并策略为 "fixed",则注册固定混合因子
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))  # 使用缓冲区注册固定值
        # 如果合并策略为 "learned" 或 "learned_with_images",则注册可学习的混合因子
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))  # 使用可学习参数注册
        else:
            # 如果合并策略未知,抛出错误
            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
    # 获取当前的 alpha 值,基于合并策略和输入
    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        # 如果合并策略为 "fixed",直接使用 mix_factor
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor
        # 如果合并策略为 "learned",使用 sigmoid 函数处理 mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)
        # 如果合并策略为 "learned_with_images",根据图像指示器计算 alpha
        elif self.merge_strategy == "learned_with_images":
            # 如果没有提供图像指示器,则抛出错误
            if image_only_indicator is None:
                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
            # 根据 image_only_indicator 的布尔值选择 alpha 的值
            alpha = torch.where(
                image_only_indicator.bool(),  # 使用布尔索引
                torch.ones(1, 1, device=image_only_indicator.device),  # 图像对应的 alpha 为 1
                torch.sigmoid(self.mix_factor)[..., None],  # 其他情况下使用 sigmoid 处理后的值
            )
            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]  # 调整维度以适应 5D 输入
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]  # 重塑为 3D 输入
            else:
                # 如果维度不符合预期,抛出错误
                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
        else:
            # 如果合并策略未实现,抛出错误
            raise NotImplementedError
        # 返回计算得到的 alpha 值
        return alpha
    # 前向传播方法,用于处理输入数据
    def forward(
        self,
        x_spatial: torch.Tensor,  # 空间特征输入
        x_temporal: torch.Tensor,  # 时间特征输入
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示器
    # 定义一个函数的返回类型为 torch.Tensor
        ) -> torch.Tensor:
        # 获取 alpha 值,依据图像指示器和空间维度
            alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        # 将 alpha 转换为与 x_spatial 相同的数据类型
            alpha = alpha.to(x_spatial.dtype)
    
        # 如果开启空间到时间混合的切换
            if self.switch_spatial_to_temporal_mix:
        # 将 alpha 值取反
                alpha = 1.0 - alpha
    
        # 根据 alpha 值进行空间和时间数据的加权组合
            x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        # 返回合成后的数据
            return x
.\diffusers\models\resnet_flax.py
# 版权所有 2024 The HuggingFace Team。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下位置获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件
# 是按“原样”基础分发的,不提供任何明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和
# 许可证限制的具体语言。
import flax.linen as nn  # 导入 flax.linen 模块以构建神经网络
import jax  # 导入 jax,用于高效数值计算
import jax.numpy as jnp  # 导入 jax 的 numpy 作为 jnp 以进行数组操作
class FlaxUpsample2D(nn.Module):  # 定义一个用于上采样的 2D 模块
    out_channels: int  # 输出通道数的类型注解
    dtype: jnp.dtype = jnp.float32  # 数据类型默认为 float32
    def setup(self):  # 设置模块参数
        self.conv = nn.Conv(  # 创建卷积层
            self.out_channels,  # 设置输出通道数
            kernel_size=(3, 3),  # 卷积核大小为 3x3
            strides=(1, 1),  # 步幅为 1
            padding=((1, 1), (1, 1)),  # 在每个边缘填充 1 像素
            dtype=self.dtype,  # 设置数据类型
        )
    def __call__(self, hidden_states):  # 定义模块的前向传播
        batch, height, width, channels = hidden_states.shape  # 获取输入张量的维度
        hidden_states = jax.image.resize(  # 使用 nearest 方法调整图像大小
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),  # 输出形状为高和宽各翻倍
            method="nearest",  # 使用最近邻插值法
        )
        hidden_states = self.conv(hidden_states)  # 对调整后的张量应用卷积层
        return hidden_states  # 返回卷积后的结果
class FlaxDownsample2D(nn.Module):  # 定义一个用于下采样的 2D 模块
    out_channels: int  # 输出通道数的类型注解
    dtype: jnp.dtype = jnp.float32  # 数据类型默认为 float32
    def setup(self):  # 设置模块参数
        self.conv = nn.Conv(  # 创建卷积层
            self.out_channels,  # 设置输出通道数
            kernel_size=(3, 3),  # 卷积核大小为 3x3
            strides=(2, 2),  # 步幅为 2
            padding=((1, 1), (1, 1)),  # 在每个边缘填充 1 像素
            dtype=self.dtype,  # 设置数据类型
        )
    def __call__(self, hidden_states):  # 定义模块的前向传播
        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # 为高和宽维度填充
        # hidden_states = jnp.pad(hidden_states, pad_width=pad)  # 使用填充来调整张量大小
        hidden_states = self.conv(hidden_states)  # 对输入张量应用卷积层
        return hidden_states  # 返回卷积后的结果
class FlaxResnetBlock2D(nn.Module):  # 定义一个用于 2D ResNet 块的模块
    in_channels: int  # 输入通道数的类型注解
    out_channels: int = None  # 输出通道数,默认为 None
    dropout_prob: float = 0.0  # dropout 概率,默认为 0
    use_nin_shortcut: bool = None  # 是否使用 NIN 短路,默认为 None
    dtype: jnp.dtype = jnp.float32  # 数据类型默认为 float32
    # 设置模型的各个层和参数
        def setup(self):
            # 确定输出通道数,若未指定,则使用输入通道数
            out_channels = self.in_channels if self.out_channels is None else self.out_channels
    
            # 初始化第一个归一化层,使用32个组和小于1e-5的epsilon
            self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
            # 初始化第一个卷积层,设置卷积参数
            self.conv1 = nn.Conv(
                out_channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
    
            # 初始化时间嵌入投影层,输出通道数与dtype
            self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
    
            # 初始化第二个归一化层
            self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
            # 初始化丢弃层,设置丢弃概率
            self.dropout = nn.Dropout(self.dropout_prob)
            # 初始化第二个卷积层,设置卷积参数
            self.conv2 = nn.Conv(
                out_channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
    
            # 确定是否使用1x1卷积快捷连接,依据输入和输出通道数
            use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
    
            # 初始化快捷连接卷积层为None
            self.conv_shortcut = None
            # 如果需要,初始化1x1卷积快捷连接层
            if use_nin_shortcut:
                self.conv_shortcut = nn.Conv(
                    out_channels,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    padding="VALID",
                    dtype=self.dtype,
                )
    
        # 定义前向传播方法
        def __call__(self, hidden_states, temb, deterministic=True):
            # 保存输入作为残差
            residual = hidden_states
            # 对输入进行归一化处理
            hidden_states = self.norm1(hidden_states)
            # 应用Swish激活函数
            hidden_states = nn.swish(hidden_states)
            # 通过第一个卷积层处理
            hidden_states = self.conv1(hidden_states)
    
            # 对时间嵌入进行Swish激活处理
            temb = self.time_emb_proj(nn.swish(temb))
            # 扩展时间嵌入的维度
            temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
            # 将时间嵌入加到隐藏状态中
            hidden_states = hidden_states + temb
    
            # 对隐藏状态进行第二次归一化
            hidden_states = self.norm2(hidden_states)
            # 应用Swish激活函数
            hidden_states = nn.swish(hidden_states)
            # 应用丢弃层
            hidden_states = self.dropout(hidden_states, deterministic)
            # 通过第二个卷积层处理
            hidden_states = self.conv2(hidden_states)
    
            # 如果存在快捷连接卷积层,则对残差进行处理
            if self.conv_shortcut is not None:
                residual = self.conv_shortcut(residual)
    
            # 返回隐藏状态和残差的和
            return hidden_states + residual
# 版权声明,指明该文件的作者和许可证信息
# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证,版本 2.0(“许可证”)进行授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意另有规定,按照许可证分发的软件
# 是以“原样”基础分发,不提供任何形式的担保或条件,
# 明示或暗示。
# 请参阅许可证以获取有关权限和
# 限制的具体语言。
# 从 typing 模块导入 Any、Dict 和 Union 类型
from typing import Any, Dict, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
import torch.nn.functional as F
# 从配置和工具模块导入所需类和函数
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    AuraFlowAttnProcessor2_0,
    FusedAuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormZero, FP32LayerNorm
# 创建一个日志记录器,便于记录信息和错误
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 定义一个函数,用于找到 n 的下一个可被 k 整除的数
def find_multiple(n: int, k: int) -> int:
    # 如果 n 可以被 k 整除,直接返回 n
    if n % k == 0:
        return n
    # 否则返回下一个可被 k 整除的数
    return n + k - (n % k)
# 定义 AuraFlowPatchEmbed 类,表示一个嵌入模块
# 不使用卷积来进行投影,同时使用学习到的位置嵌入
class AuraFlowPatchEmbed(nn.Module):
    # 初始化函数,设置嵌入模块的参数
    def __init__(
        self,
        height=224,  # 输入图像高度
        width=224,   # 输入图像宽度
        patch_size=16,  # 每个补丁的大小
        in_channels=3,   # 输入通道数(例如,RGB图像)
        embed_dim=768,   # 嵌入维度
        pos_embed_max_size=None,  # 最大位置嵌入大小
    ):
        super().__init__()
        # 计算补丁数量
        self.num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed_max_size = pos_embed_max_size
        # 定义线性层,将补丁投影到嵌入空间
        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        # 定义位置嵌入参数,随机初始化
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
        # 保存补丁大小和图像的补丁高度和宽度
        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        # 保存基础大小
        self.base_size = height // patch_size
    # 根据输入的高度和宽度选择基于维度的嵌入索引
    def pe_selection_index_based_on_dim(self, h, w):
        # 计算基于补丁大小的高度和宽度
        h_p, w_p = h // self.patch_size, w // self.patch_size
        # 生成原始位置嵌入的索引
        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
        # 计算最大高度和宽度
        h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
        # 将索引视图调整为二维网格
        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
        # 计算起始行和结束行
        starth = h_max // 2 - h_p // 2
        endh = starth + h_p
        # 计算起始列和结束列
        startw = w_max // 2 - w_p // 2
        endw = startw + w_p
        # 选择指定范围的原始位置嵌入索引
        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
        # 返回展平的索引
        return original_pe_indexes.flatten()
    
    # 前向传播函数
    def forward(self, latent):
        # 获取输入的批量大小、通道数、高度和宽度
        batch_size, num_channels, height, width = latent.size()
        # 调整潜在张量的形状以适应补丁结构
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        # 重新排列维度并展平
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        # 应用投影层
        latent = self.proj(latent)
        # 获取嵌入索引
        pe_index = self.pe_selection_index_based_on_dim(height, width)
        # 返回潜在张量与位置嵌入的和
        return latent + self.pos_embed[:, pe_index]
# 取自原始的 Aura 流推理代码。
# 我们的前馈网络只使用 GELU,而 Aura 使用 SiLU。
class AuraFlowFeedForward(nn.Module):
    # 初始化方法,接收输入维度和隐藏层维度(如果未提供则设为 4 倍输入维度)
    def __init__(self, dim, hidden_dim=None) -> None:
        # 调用父类构造函数
        super().__init__()
        # 如果没有提供隐藏层维度,则计算为输入维度的 4 倍
        if hidden_dim is None:
            hidden_dim = 4 * dim
        # 计算最终隐藏层维度,取隐藏层维度的 2/3
        final_hidden_dim = int(2 * hidden_dim / 3)
        # 将最终隐藏层维度调整为 256 的倍数
        final_hidden_dim = find_multiple(final_hidden_dim, 256)
        # 创建第一个线性层,不使用偏置
        self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
        # 创建第二个线性层,不使用偏置
        self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
        # 创建输出投影层,不使用偏置
        self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
    # 前向传播方法
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 经过第一个线性层并使用 SiLU 激活函数,然后与第二个线性层的输出相乘
        x = F.silu(self.linear_1(x)) * self.linear_2(x)
        # 经过输出投影层
        x = self.out_projection(x)
        # 返回处理后的张量
        return x
class AuraFlowPreFinalBlock(nn.Module):
    # 初始化方法,接收嵌入维度和条件嵌入维度
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
        # 调用父类构造函数
        super().__init__()
        # 定义 SiLU 激活函数
        self.silu = nn.SiLU()
        # 创建线性层,输出维度为嵌入维度的两倍,不使用偏置
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
    # 前向传播方法
    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # 对条件嵌入应用 SiLU 激活并转换为与 x 相同的数据类型,然后通过线性层
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        # 将嵌入分成两个部分:缩放和偏移
        scale, shift = torch.chunk(emb, 2, dim=1)
        # 更新 x,使用缩放和偏移进行调整
        x = x * (1 + scale)[:, None, :] + shift[:, None, :]
        # 返回调整后的张量
        return x
@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
    """类似于 `AuraFlowJointTransformerBlock`,但只使用一个 DiT 而不是 MMDiT。"""
    # 初始化方法,接收输入维度、注意力头数量和每个头的维度
    def __init__(self, dim, num_attention_heads, attention_head_dim):
        # 调用父类构造函数
        super().__init__()
        # 创建层归一化对象,设置维度和不使用偏置,归一化类型为 "fp32_layer_norm"
        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
        # 创建注意力处理器
        processor = AuraFlowAttnProcessor2_0()
        # 创建注意力机制对象,设置参数
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="fp32_layer_norm",
            out_dim=dim,
            bias=False,
            out_bias=False,
            processor=processor,
        )
        # 创建第二层归一化对象,设置维度和不使用偏置
        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        # 创建前馈网络对象,隐藏层维度为输入维度的 4 倍
        self.ff = AuraFlowFeedForward(dim, dim * 4)
    # 前向传播方法,接收隐藏状态和条件嵌入
    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
        # 保存输入的残差
        residual = hidden_states
        # 进行归一化和投影
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        # 经过注意力机制处理
        attn_output = self.attn(hidden_states=norm_hidden_states)
        # 将注意力输出与残差相结合,并进行第二次归一化
        hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
        # 更新 hidden_states,使用缩放和偏移
        hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        # 经过前馈网络处理
        ff_output = self.ff(hidden_states)
        # 更新 hidden_states,使用门控机制
        hidden_states = gate_mlp.unsqueeze(1) * ff_output
        # 将残差与更新后的 hidden_states 相加
        hidden_states = residual + hidden_states
        # 返回最终的隐藏状态
        return hidden_states
@maybe_allow_in_graph
# 定义 AuraFlow 的 Transformer 块类,继承自 nn.Module
class AuraFlowJointTransformerBlock(nn.Module):
    r"""
    Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
        * QK Norm in the attention blocks
        * No bias in the attention blocks
        * Most LayerNorms are in FP32
    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        is_last (`bool`): Boolean to determine if this is the last block in the model.
    """
    # 初始化方法,接受输入维度、注意力头数和每个头的维度
    def __init__(self, dim, num_attention_heads, attention_head_dim):
        # 调用父类构造函数
        super().__init__()
        # 创建第一个层归一化对象,不使用偏置,采用 FP32 类型
        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
        # 创建上下文的层归一化对象,同样不使用偏置
        self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
        # 实例化注意力处理器
        processor = AuraFlowAttnProcessor2_0()
        # 创建注意力机制对象,配置查询维度、头数等参数
        self.attn = Attention(
            query_dim=dim,                       # 查询向量的维度
            cross_attention_dim=None,            # 交叉注意力的维度,未使用
            added_kv_proj_dim=dim,               # 添加的键值投影维度
            added_proj_bias=False,                # 不使用添加的偏置
            dim_head=attention_head_dim,         # 每个头的维度
            heads=num_attention_heads,            # 注意力头的数量
            qk_norm="fp32_layer_norm",           # QK 的归一化类型
            out_dim=dim,                         # 输出维度
            bias=False,                           # 不使用偏置
            out_bias=False,                       # 不使用输出偏置
            processor=processor,                  # 传入的处理器
            context_pre_only=False,               # 不仅仅使用上下文
        )
        # 创建第二个层归一化对象,不使用元素级的仿射变换和偏置
        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        # 创建前馈神经网络对象,输出维度是输入维度的四倍
        self.ff = AuraFlowFeedForward(dim, dim * 4)
        # 创建上下文的第二个层归一化对象
        self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        # 创建上下文的前馈神经网络对象
        self.ff_context = AuraFlowFeedForward(dim, dim * 4)
    # 定义前向传播方法,接受隐藏状态、编码器的隐藏状态和时间嵌入
    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        # 初始化残差为当前的隐藏状态
        residual = hidden_states
        # 初始化残差上下文为编码器的隐藏状态
        residual_context = encoder_hidden_states
        # 归一化和投影操作
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        # 对编码器隐藏状态进行归一化和投影
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        # 注意力机制计算
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )
        # 处理注意力输出以更新 `hidden_states`
        hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
        # 对隐藏状态进行缩放和偏移
        hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        # 使用前馈网络处理隐藏状态
        hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
        # 将更新后的隐藏状态与残差相加
        hidden_states = residual + hidden_states
        # 处理注意力输出以更新 `encoder_hidden_states`
        encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
        # 对编码器隐藏状态进行缩放和偏移
        encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        # 使用前馈网络处理编码器隐藏状态
        encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
        # 将更新后的编码器隐藏状态与残差上下文相加
        encoder_hidden_states = residual_context + encoder_hidden_states
        # 返回编码器隐藏状态和更新后的隐藏状态
        return encoder_hidden_states, hidden_states
# 定义一个2D Transformer模型类,继承自ModelMixin和ConfigMixin
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    介绍AuraFlow中提出的2D Transformer模型(https://blog.fal.ai/auraflow/)。
    参数:
        sample_size (`int`): 潜在图像的宽度。由于用于学习位置嵌入,因此在训练过程中是固定的。
        patch_size (`int`): 将输入数据转换为小块的大小。
        in_channels (`int`, *optional*, defaults to 16): 输入通道的数量。
        num_mmdit_layers (`int`, *optional*, defaults to 4): 要使用的MMDiT Transformer块的层数。
        num_single_dit_layers (`int`, *optional*, defaults to 4):
            要使用的Transformer块的层数。这些块使用连接的图像和文本表示。
        attention_head_dim (`int`, *optional*, defaults to 64): 每个头的通道数。
        num_attention_heads (`int`, *optional*, defaults to 18): 用于多头注意力的头数。
        joint_attention_dim (`int`, *optional*): 要使用的`encoder_hidden_states`维度数量。
        caption_projection_dim (`int`): 投影`encoder_hidden_states`时使用的维度数量。
        out_channels (`int`, defaults to 16): 输出通道的数量。
        pos_embed_max_size (`int`, defaults to 4096): 从图像潜在值中嵌入的最大位置数量。
    """
    # 支持梯度检查点
    _supports_gradient_checkpointing = True
    # 将该方法注册到配置中
    @register_to_config
    def __init__(
        # 潜在图像的宽度,默认为64
        sample_size: int = 64,
        # 输入数据的小块大小,默认为2
        patch_size: int = 2,
        # 输入通道的数量,默认为4
        in_channels: int = 4,
        # MMDiT Transformer块的层数,默认为4
        num_mmdit_layers: int = 4,
        # 单一Transformer块的层数,默认为32
        num_single_dit_layers: int = 32,
        # 每个头的通道数,默认为256
        attention_head_dim: int = 256,
        # 多头注意力的头数,默认为12
        num_attention_heads: int = 12,
        # `encoder_hidden_states`的维度数量,默认为2048
        joint_attention_dim: int = 2048,
        # 投影时使用的维度数量,默认为3072
        caption_projection_dim: int = 3072,
        # 输出通道的数量,默认为4
        out_channels: int = 4,
        # 从图像潜在值中嵌入的最大位置数量,默认为1024
        pos_embed_max_size: int = 1024,
    ):
        # 初始化父类
        super().__init__()
        # 设置默认输出通道为输入通道数
        default_out_channels = in_channels
        # 如果提供了输出通道数,则使用该值,否则使用默认值
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        # 计算内部维度为注意力头数与每个注意力头维度的乘积
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        # 创建位置嵌入对象,使用配置中的样本大小和补丁大小
        self.pos_embed = AuraFlowPatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,
        )
        # 创建线性层用于上下文嵌入,不使用偏置
        self.context_embedder = nn.Linear(
            self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
        )
        # 创建时间步嵌入对象,配置频道数和频率下采样
        self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
        # 创建时间步投影层,输入频道数为256,嵌入维度为内部维度
        self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        # 创建联合变换器模块列表,根据配置中的层数
        self.joint_transformer_blocks = nn.ModuleList(
            [
                AuraFlowJointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_mmdit_layers)
            ]
        )
        # 创建单一变换器模块列表,根据配置中的层数
        self.single_transformer_blocks = nn.ModuleList(
            [
                AuraFlowSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_dit_layers)
            ]
        )
        # 创建最终块的归一化层,维度为内部维度
        self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
        # 创建线性投影层,将内部维度映射到补丁大小平方与输出通道数的乘积,不使用偏置
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
        # https://arxiv.org/abs/2309.16588
        # 防止注意力图中的伪影
        self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
        # 设置梯度检查点为 False
        self.gradient_checkpointing = False
    @property
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
    # 定义一个返回注意力处理器的函数,返回类型为字典
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 初始化一个空字典用于存储处理器
        processors = {}
    
        # 定义递归函数用于添加处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否具有获取处理器的方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为处理器名称
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用以添加子模块的处理器
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            # 返回处理器字典
            return processors
    
        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 调用递归函数以添加所有子模块的处理器
            fn_recursive_add_processors(name, module, processors)
    
        # 返回包含所有处理器的字典
        return processors
    
    # 定义设置注意力处理器的函数
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.
    
        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.
    
                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 检查传入的处理器字典长度是否与注意力层数量一致
        if isinstance(processor, dict) and len(processor) != count:
            # 如果不一致,抛出错误
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
    
        # 定义递归函数用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否具有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中获取并设置对应的处理器
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历子模块
            for sub_name, child in module.named_children():
                # 递归调用以设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 调用递归函数以设置所有子模块的处理器
            fn_recursive_attn_processor(name, module, processor)
    
    # 该函数用于融合注意力层中的 QKV 投影
    # 定义一个方法以启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
        <Tip warning={true}>
        This API is 🧪 experimental.
        </Tip>
        """
        # 初始化原始的注意力处理器为 None
        self.original_attn_processors = None
        # 遍历所有的注意力处理器
        for _, attn_processor in self.attn_processors.items():
            # 如果注意力处理器类名中包含 "Added",则抛出错误
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
        # 保存当前的注意力处理器以便后续恢复
        self.original_attn_processors = self.attn_processors
        # 遍历所有模块以查找注意力模块
        for module in self.modules():
            # 检查模块是否为 Attention 类型
            if isinstance(module, Attention):
                # 对注意力模块进行融合投影处理
                module.fuse_projections(fuse=True)
        # 设置新的注意力处理器为融合的处理器
        self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
    # 从 UNet2DConditionModel 类复制的方法,用于取消融合的 QKV 投影
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
        <Tip warning={true}>
        This API is 🧪 experimental.
        </Tip>
        """
        # 如果原始的注意力处理器不为 None,则恢复为原始处理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
    # 定义一个方法以设置模块的梯度检查点
    def _set_gradient_checkpointing(self, module, value=False):
        # 检查模块是否具有梯度检查点属性
        if hasattr(module, "gradient_checkpointing"):
            # 将梯度检查点属性设置为给定值
            module.gradient_checkpointing = value
    # 定义前向传播方法
    def forward(
        # 接收隐藏状态的浮点张量
        hidden_states: torch.FloatTensor,
        # 可选的编码器隐藏状态的浮点张量
        encoder_hidden_states: torch.FloatTensor = None,
        # 可选的时间步长的长整型张量
        timestep: torch.LongTensor = None,
        # 是否返回字典格式的标志,默认为 True
        return_dict: bool = True,
# 版权声明,说明代码的版权所有者和使用许可
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# 所有权利保留。
#
# 根据 Apache 许可证,第 2.0 版("许可证")进行授权;
# 除非遵循许可证,否则您不能使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,
# 根据许可证分发的软件是按“原样”提供的,不附带任何明示或暗示的担保或条件。
# 有关许可证下特定权限和限制的信息,请参阅许可证。
# 从 typing 模块导入所需的类型
from typing import Any, Dict, Optional, Tuple, Union
# 导入 PyTorch 库
import torch
# 从 PyTorch 导入神经网络模块
from torch import nn
# 导入其他模块中的工具和类
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
# 创建日志记录器,以便在模块中记录信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 使用装饰器,允许在图计算中可能的功能
@maybe_allow_in_graph
# 定义一个名为 CogVideoXBlock 的类,继承自 nn.Module
class CogVideoXBlock(nn.Module):
    r"""
    在 [CogVideoX](https://github.com/THUDM/CogVideo) 模型中使用的 Transformer 块。
    # 定义函数参数的文档字符串,描述各个参数的用途
    Parameters:
        dim (`int`):  # 输入和输出的通道数
            The number of channels in the input and output.
        num_attention_heads (`int`):  # 多头注意力中使用的头数
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):  # 每个头的通道数
            The number of channels in each head.
        time_embed_dim (`int`):  # 时间步嵌入的通道数
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):  # 使用的丢弃概率
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):  # 前馈网络中使用的激活函数
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):  # 是否在注意力投影层使用偏置
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):  # 是否在注意力中查询和键的投影后使用归一化
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):  # 是否使用可学习的逐元素仿射参数进行归一化
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):  # 归一化层的 epsilon 值
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `False`):  # 是否在最后的前馈层后应用最终的丢弃
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):  # 前馈层的自定义隐藏维度
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):  # 是否在前馈层中使用偏置
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):  # 是否在注意力输出投影层中使用偏置
            Whether or not to use bias in Attention output projection layer.
    """  # 结束文档字符串
    def __init__(  # 定义构造函数
        self,  # 实例自身
        dim: int,  # 输入和输出通道数
        num_attention_heads: int,  # 多头注意力中头数
        attention_head_dim: int,  # 每个头的通道数
        time_embed_dim: int,  # 时间步嵌入通道数
        dropout: float = 0.0,  # 默认丢弃概率
        activation_fn: str = "gelu-approximate",  # 默认激活函数
        attention_bias: bool = False,  # 默认不使用注意力偏置
        qk_norm: bool = True,  # 默认使用查询和键的归一化
        norm_elementwise_affine: bool = True,  # 默认使用逐元素仿射参数
        norm_eps: float = 1e-5,  # 默认归一化的 epsilon 值
        final_dropout: bool = True,  # 默认使用最终丢弃
        ff_inner_dim: Optional[int] = None,  # 前馈层的可选隐藏维度
        ff_bias: bool = True,  # 默认使用前馈层的偏置
        attention_out_bias: bool = True,  # 默认使用注意力输出层的偏置
    ):
        # 调用父类初始化方法
        super().__init__()
        # 1. Self Attention
        # 创建归一化层,处理时间嵌入维度和特征维度
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        # 创建自注意力机制,配置查询维度和头数等参数
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )
        # 2. Feed Forward
        # 创建另一个归一化层,用于后续的前馈网络
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        # 创建前馈网络,配置隐藏层维度及其他超参数
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # 获取编码器隐藏状态的序列长度
        text_seq_length = encoder_hidden_states.size(1)
        # norm & modulate
        # 对输入的隐藏状态和编码器状态进行归一化和调制
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
        # attention
        # 执行自注意力机制,计算新的隐藏状态
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        # 更新隐藏状态和编码器隐藏状态,结合注意力输出
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
        # norm & modulate
        # 再次进行归一化和调制
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )
        # feed-forward
        # 将归一化后的隐藏状态和编码器状态连接,输入前馈网络
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)
        # 更新隐藏状态和编码器状态,结合前馈网络输出
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
        # 返回更新后的隐藏状态和编码器状态
        return hidden_states, encoder_hidden_states
# 定义一个用于视频数据的 Transformer 模型,继承自 ModelMixin 和 ConfigMixin
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
    """
    # 设置支持梯度检查点
    _supports_gradient_checkpointing = True
    # 注册到配置中的初始化方法,定义多个超参数
    @register_to_config
    def __init__(
        # 注意力头的数量,默认值为 30
        num_attention_heads: int = 30,
        # 每个注意力头的维度,默认值为 64
        attention_head_dim: int = 64,
        # 输入通道的数量,默认值为 16
        in_channels: int = 16,
        # 输出通道的数量,可选,默认值为 16
        out_channels: Optional[int] = 16,
        # 是否翻转正弦到余弦,默认值为 True
        flip_sin_to_cos: bool = True,
        # 频率偏移量,默认值为 0
        freq_shift: int = 0,
        # 时间嵌入维度,默认值为 512
        time_embed_dim: int = 512,
        # 文本嵌入维度,默认值为 4096
        text_embed_dim: int = 4096,
        # 层的数量,默认值为 30
        num_layers: int = 30,
        # dropout 概率,默认值为 0.0
        dropout: float = 0.0,
        # 是否使用注意力偏置,默认值为 True
        attention_bias: bool = True,
        # 采样宽度,默认值为 90
        sample_width: int = 90,
        # 采样高度,默认值为 60
        sample_height: int = 60,
        # 采样帧数,默认值为 49
        sample_frames: int = 49,
        # 补丁大小,默认值为 2
        patch_size: int = 2,
        # 时间压缩比例,默认值为 4
        temporal_compression_ratio: int = 4,
        # 最大文本序列长度,默认值为 226
        max_text_seq_length: int = 226,
        # 激活函数类型,默认值为 "gelu-approximate"
        activation_fn: str = "gelu-approximate",
        # 时间步激活函数类型,默认值为 "silu"
        timestep_activation_fn: str = "silu",
        # 是否使用元素逐个仿射的归一化,默认值为 True
        norm_elementwise_affine: bool = True,
        # 归一化的 epsilon 值,默认值为 1e-5
        norm_eps: float = 1e-5,
        # 空间插值缩放因子,默认值为 1.875
        spatial_interpolation_scale: float = 1.875,
        # 时间插值缩放因子,默认值为 1.0
        temporal_interpolation_scale: float = 1.0,
        # 是否使用旋转位置嵌入,默认值为 False
        use_rotary_positional_embeddings: bool = False,
        # 是否使用学习的位置嵌入,默认值为 False
        use_learned_positional_embeddings: bool = False,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 计算内部维度,等于注意力头数与每个头的维度乘积
        inner_dim = num_attention_heads * attention_head_dim
        # 检查位置嵌入的使用情况,如果不支持则抛出错误
        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )
        # 1. 创建补丁嵌入层
        self.patch_embed = CogVideoXPatchEmbed(
            # 设置补丁大小
            patch_size=patch_size,
            # 输入通道数
            in_channels=in_channels,
            # 嵌入维度
            embed_dim=inner_dim,
            # 文本嵌入维度
            text_embed_dim=text_embed_dim,
            # 是否使用偏置
            bias=True,
            # 样本宽度
            sample_width=sample_width,
            # 样本高度
            sample_height=sample_height,
            # 样本帧数
            sample_frames=sample_frames,
            # 时间压缩比
            temporal_compression_ratio=temporal_compression_ratio,
            # 最大文本序列长度
            max_text_seq_length=max_text_seq_length,
            # 空间插值缩放
            spatial_interpolation_scale=spatial_interpolation_scale,
            # 时间插值缩放
            temporal_interpolation_scale=temporal_interpolation_scale,
            # 使用位置嵌入
            use_positional_embeddings=not use_rotary_positional_embeddings,
            # 使用学习的位置嵌入
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        # 创建嵌入丢弃层
        self.embedding_dropout = nn.Dropout(dropout)
        # 2. 创建时间嵌入
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        # 创建时间步嵌入层
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
        # 3. 定义时空变换器块
        self.transformer_blocks = nn.ModuleList(
            [
                # 创建多个变换器块
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                # 根据层数重复创建变换器块
                for _ in range(num_layers)
            ]
        )
        # 创建最终的层归一化
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
        # 4. 输出块的定义
        self.norm_out = AdaLayerNorm(
            # 嵌入维度
            embedding_dim=time_embed_dim,
            # 输出维度
            output_dim=2 * inner_dim,
            # 是否使用元素级别的归一化
            norm_elementwise_affine=norm_elementwise_affine,
            # 归一化的epsilon值
            norm_eps=norm_eps,
            # 块的维度
            chunk_dim=1,
        )
        # 创建输出的线性层
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False
    # 设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value=False):
        # 更新梯度检查点标志
        self.gradient_checkpointing = value
    @property
    # 从 diffusers.models.unets.unet_2d_condition 中复制的属性
    # 定义一个方法,返回注意力处理器的字典
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 的注意力处理器: 一个字典,包含模型中所有使用的注意力处理器,以权重名称索引。
        """
        # 初始化一个空字典用于存储处理器
        processors = {}
        # 定义一个递归函数,用于添加注意力处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否具有获取处理器的方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为处理器的名称
                processors[f"{name}.processor"] = module.get_processor()
            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            # 返回处理器字典
            return processors
        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,添加处理器
            fn_recursive_add_processors(name, module, processors)
        # 返回收集到的处理器字典
        return processors
    # 从 UNet2DConditionModel 中复制的方法,用于设置注意力处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
        参数:
            processor (`dict` 的 `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或处理器类的字典,将作为所有 `Attention` 层的处理器设置。
                如果 `processor` 是一个字典,键需要定义相应交叉注意力处理器的路径。
                在设置可训练的注意力处理器时,强烈建议这样做。
        """
        # 计算当前注意力处理器的数量
        count = len(self.attn_processors.keys())
        # 如果传入的处理器是字典,且数量与当前处理器不匹配,则抛出异常
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
        # 定义一个递归函数,用于设置注意力处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否具有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,则直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,处理子模块
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,设置处理器
            fn_recursive_attn_processor(name, module, processor)
    # 从 UNet2DConditionModel 中复制的方法,涉及融合 QKV 投影
    # 定义融合 QKV 投影的方法
        def fuse_qkv_projections(self):
            """
            启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。
            对于交叉注意力模块,键和值投影矩阵被融合。
    
            <Tip warning={true}>
    
            此 API 是 🧪 实验性的。
    
            </Tip>
            """
            # 初始化原始注意力处理器为 None
            self.original_attn_processors = None
    
            # 遍历所有注意力处理器
            for _, attn_processor in self.attn_processors.items():
                # 如果注意力处理器的类名包含 "Added"
                if "Added" in str(attn_processor.__class__.__name__):
                    # 抛出异常,表示不支持有额外 KV 投影的模型
                    raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
    
            # 保存原始注意力处理器
            self.original_attn_processors = self.attn_processors
    
            # 遍历所有模块
            for module in self.modules():
                # 如果模块是 Attention 类型
                if isinstance(module, Attention):
                    # 融合投影
                    module.fuse_projections(fuse=True)
    
            # 设置注意力处理器为融合的处理器
            self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
    
        # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 拷贝而来
        def unfuse_qkv_projections(self):
            """禁用融合的 QKV 投影(如果已启用)。
    
            <Tip warning={true}>
    
            此 API 是 🧪 实验性的。
    
            </Tip>
    
            """
            # 如果原始注意力处理器不为 None
            if self.original_attn_processors is not None:
                # 设置注意力处理器为原始处理器
                self.set_attn_processor(self.original_attn_processors)
    
        # 定义前向传播方法
        def forward(
            # 隐藏状态输入的张量
            hidden_states: torch.Tensor,
            # 编码器隐藏状态的张量
            encoder_hidden_states: torch.Tensor,
            # 时间步的整数或浮点数
            timestep: Union[int, float, torch.LongTensor],
            # 可选的时间步条件张量
            timestep_cond: Optional[torch.Tensor] = None,
            # 可选的图像旋转嵌入
            image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            # 返回字典的布尔值,默认为 True
            return_dict: bool = True,
# 版权声明,表示本代码的版权归 HuggingFace 团队所有,保留所有权利
# 
# 根据 Apache 许可证 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非根据适用法律或书面协议另有约定,软件
# 按照“现状”分发,不提供任何形式的保证或条件,
# 明示或暗示。
# 查看许可证以获取有关许可和
# 限制的具体信息。
from typing import Any, Dict, Optional  # 导入类型提示相关的模块
import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的函数式 API
from torch import nn  # 导入 PyTorch 的神经网络模块
from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具中导入混入类和注册功能
from ...utils import is_torch_version, logging  # 从工具中导入版本检查和日志记录功能
from ..attention import BasicTransformerBlock  # 从注意力模块导入基本变换块
from ..embeddings import PatchEmbed  # 从嵌入模块导入补丁嵌入类
from ..modeling_outputs import Transformer2DModelOutput  # 从建模输出模块导入 2D 变换器模型输出类
from ..modeling_utils import ModelMixin  # 从建模工具中导入模型混入类
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,使用 pylint 禁用无效名称警告
class DiTTransformer2DModel(ModelMixin, ConfigMixin):  # 定义一个 2D 变换器模型类,继承自模型混入和配置混入类
    r"""  # 开始文档字符串,描述模型的功能和来源
    A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).  # 描述模型为 DiT 中引入的 2D 变换器模型
    # 定义参数的文档字符串,说明每个参数的作用及默认值
    Parameters:
        # 使用的多头注意力的头数,默认为 16
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        # 每个头的通道数,默认为 72
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        # 输入的通道数,默认为 4
        in_channels (int, defaults to 4): The number of channels in the input.
        # 输出的通道数,如果与输入的通道数不同,需要指定该参数
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        # Transformer 块的层数,默认为 28
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        # Transformer 块内使用的 dropout 概率,默认为 0.0
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        # Transformer 块内组归一化的组数,默认为 32
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        # 配置 Transformer 块的注意力是否包含偏置参数,默认为 True
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        # 潜在图像的宽度,训练期间该参数是固定的,默认为 32
        sample_size (int, defaults to 32):
            The width of the latent images. This parameter is fixed during training.
        # 模型处理的补丁大小,与处理非序列数据的架构相关,默认为 2
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        # 在 Transformer 块内前馈网络中使用的激活函数,默认为 "gelu-approximate"
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        # AdaLayerNorm 的嵌入数量,训练期间固定,影响推理时的最大去噪步骤,默认为 1000
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        # 如果为真,提升注意力机制维度以潜在改善性能,默认为 False
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        # 指定使用的归一化类型,可以是 'ada_norm_zero',默认为 "ada_norm_zero"
        norm_type (str, optional, defaults to "ada_norm_zero"):
            Specifies the type of normalization used, can be 'ada_norm_zero'.
        # 如果为真,启用归一化层中的逐元素仿射参数,默认为 False
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        # 在归一化层中添加的一个小常数,以防止除以零,默认为 1e-5
        norm_eps (float, optional, defaults to 1e-5):
            A small constant added to the denominator in normalization layers to prevent division by zero.
    """
    # 支持梯度检查点,以减少内存使用
    _supports_gradient_checkpointing = True
    # 用于注册配置的装饰器
    @register_to_config
    def __init__(
        # 初始化时使用的多头注意力的头数,默认为 16
        num_attention_heads: int = 16,
        # 初始化时每个头的通道数,默认为 72
        attention_head_dim: int = 72,
        # 初始化时输入的通道数,默认为 4
        in_channels: int = 4,
        # 初始化时输出的通道数,默认为 None(可选)
        out_channels: Optional[int] = None,
        # 初始化时 Transformer 块的层数,默认为 28
        num_layers: int = 28,
        # 初始化时使用的 dropout 概率,默认为 0.0
        dropout: float = 0.0,
        # 初始化时组归一化的组数,默认为 32
        norm_num_groups: int = 32,
        # 初始化时注意力的偏置参数,默认为 True
        attention_bias: bool = True,
        # 初始化时潜在图像的宽度,默认为 32
        sample_size: int = 32,
        # 初始化时模型处理的补丁大小,默认为 2
        patch_size: int = 2,
        # 初始化时使用的激活函数,默认为 "gelu-approximate"
        activation_fn: str = "gelu-approximate",
        # 初始化时 AdaLayerNorm 的嵌入数量,默认为 1000(可选)
        num_embeds_ada_norm: Optional[int] = 1000,
        # 初始化时提升注意力机制维度,默认为 False
        upcast_attention: bool = False,
        # 初始化时使用的归一化类型,默认为 "ada_norm_zero"
        norm_type: str = "ada_norm_zero",
        # 初始化时启用归一化层的逐元素仿射参数,默认为 False
        norm_elementwise_affine: bool = False,
        # 初始化时用于归一化层的小常数,默认为 1e-5
        norm_eps: float = 1e-5,
    # 初始化父类
        ):
            super().__init__()
    
            # 验证输入参数是否有效
            if norm_type != "ada_norm_zero":
                # 如果规范类型不正确,抛出未实现错误
                raise NotImplementedError(
                    f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
                )
            elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
                # 当规范类型为 'ada_norm_zero' 且嵌入数为 None 时,抛出值错误
                raise ValueError(
                    f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                )
    
            # 设置通用变量
            self.attention_head_dim = attention_head_dim
            # 计算内部维度为注意力头数量乘以注意力头维度
            self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
            # 设置输出通道数,如果未指定,则使用输入通道数
            self.out_channels = in_channels if out_channels is None else out_channels
            # 初始化梯度检查点为 False
            self.gradient_checkpointing = False
    
            # 2. 初始化位置嵌入和变换器块
            self.height = self.config.sample_size
            self.width = self.config.sample_size
    
            # 获取补丁大小
            self.patch_size = self.config.patch_size
            # 初始化补丁嵌入对象
            self.pos_embed = PatchEmbed(
                height=self.config.sample_size,
                width=self.config.sample_size,
                patch_size=self.config.patch_size,
                in_channels=self.config.in_channels,
                embed_dim=self.inner_dim,
            )
    
            # 创建变换器块的模块列表
            self.transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        self.inner_dim,
                        self.config.num_attention_heads,
                        self.config.attention_head_dim,
                        dropout=self.config.dropout,
                        activation_fn=self.config.activation_fn,
                        num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                        attention_bias=self.config.attention_bias,
                        upcast_attention=self.config.upcast_attention,
                        norm_type=norm_type,
                        norm_elementwise_affine=self.config.norm_elementwise_affine,
                        norm_eps=self.config.norm_eps,
                    )
                    # 根据层数创建多个变换器块
                    for _ in range(self.config.num_layers)
                ]
            )
    
            # 3. 输出层
            # 初始化层归一化
            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
            # 第一层线性变换,将维度从 inner_dim 扩展到 2 * inner_dim
            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
            # 第二层线性变换,输出维度为补丁大小的平方乘以输出通道数
            self.proj_out_2 = nn.Linear(
                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
            )
    
        # 设置梯度检查点的功能
        def _set_gradient_checkpointing(self, module, value=False):
            # 如果模块有梯度检查点属性,则设置其值
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
    
        # 前向传播函数定义
        def forward(
            self,
            hidden_states: torch.Tensor,
            timestep: Optional[torch.LongTensor] = None,
            class_labels: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            return_dict: bool = True,
# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件按“原样”分发,
# 不提供任何形式的保证或条件,无论是明示还是暗示。
# 有关许可的特定权限和限制,请参见许可证。
from typing import Optional  # 从 typing 模块导入 Optional 类型,用于指示可选参数类型
from torch import nn  # 从 torch 模块导入 nn 子模块,提供神经网络的构建块
from ..modeling_outputs import Transformer2DModelOutput  # 从上级模块导入 Transformer2DModelOutput,用于模型输出格式
from .transformer_2d import Transformer2DModel  # 从当前模块导入 Transformer2DModel,用于构建 Transformer 模型
class DualTransformer2DModel(nn.Module):  # 定义 DualTransformer2DModel 类,继承自 nn.Module
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
    
    这个类是一个双重变换器的封装器,结合了两个 `Transformer2DModel` 用于混合推理。
    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """
    # 初始化方法,用于设置模型参数
        def __init__(
            # 注意力头的数量,默认值为16
            num_attention_heads: int = 16,
            # 每个注意力头的维度,默认值为88
            attention_head_dim: int = 88,
            # 输入通道数,可选参数
            in_channels: Optional[int] = None,
            # 模型层数,默认值为1
            num_layers: int = 1,
            # dropout比率,默认值为0.0
            dropout: float = 0.0,
            # 归一化的组数,默认值为32
            norm_num_groups: int = 32,
            # 交叉注意力维度,可选参数
            cross_attention_dim: Optional[int] = None,
            # 是否使用注意力偏差,默认值为False
            attention_bias: bool = False,
            # 样本大小,可选参数
            sample_size: Optional[int] = None,
            # 向量嵌入的数量,可选参数
            num_vector_embeds: Optional[int] = None,
            # 激活函数,默认值为"geglu"
            activation_fn: str = "geglu",
            # 自适应归一化的嵌入数量,可选参数
            num_embeds_ada_norm: Optional[int] = None,
        ):
            # 调用父类初始化方法
            super().__init__()
            # 创建一个包含两个Transformer2DModel的模块列表
            self.transformers = nn.ModuleList(
                [
                    # 实例化Transformer2DModel,使用传入的参数
                    Transformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        in_channels=in_channels,
                        num_layers=num_layers,
                        dropout=dropout,
                        norm_num_groups=norm_num_groups,
                        cross_attention_dim=cross_attention_dim,
                        attention_bias=attention_bias,
                        sample_size=sample_size,
                        num_vector_embeds=num_vector_embeds,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                    )
                    # 创建两个Transformer实例
                    for _ in range(2)
                ]
            )
    
            # 可通过管道设置的变量:
    
            # 推理时组合transformer1和transformer2输出状态的比率
            self.mix_ratio = 0.5
    
            # `encoder_hidden_states`的形状预期为
            # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
            self.condition_lengths = [77, 257]
    
            # 指定编码条件时使用哪个transformer。
            # 例如,`(1, 0)`表示使用`transformers[1](conditions[0])`和`transformers[0](conditions[1])`
            self.transformer_index_for_condition = [1, 0]
    
        # 前向传播方法,用于模型的推理过程
        def forward(
            # 隐藏状态输入
            hidden_states,
            # 编码器的隐藏状态
            encoder_hidden_states,
            # 时间步,可选参数
            timestep=None,
            # 注意力掩码,可选参数
            attention_mask=None,
            # 交叉注意力的额外参数,可选参数
            cross_attention_kwargs=None,
            # 是否返回字典格式的输出,默认值为True
            return_dict: bool = True,
    ):
        """
        参数:
            hidden_states ( 当为离散时,`torch.LongTensor` 形状为 `(batch size, num latent pixels)`。
                当为连续时,`torch.Tensor` 形状为 `(batch size, channel, height, width)`): 输入的 hidden_states。
            encoder_hidden_states ( `torch.LongTensor` 形状为 `(batch size, encoder_hidden_states dim)`,*可选*):
                用于交叉注意力层的条件嵌入。如果未提供,交叉注意力将默认为
                自注意力。
            timestep ( `torch.long`,*可选*):
                可选的时间步长,将作为 AdaLayerNorm 中的嵌入使用。用于指示去噪步骤。
            attention_mask (`torch.Tensor`,*可选*):
                可选的注意力掩码,应用于注意力。
            cross_attention_kwargs (`dict`,*可选*):
                如果指定,将传递给 `AttentionProcessor` 的关键字参数字典,如
                在 `self.processor` 中定义的
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)。
            return_dict (`bool`,*可选*,默认为 `True`):
                是否返回 [`models.unets.unet_2d_condition.UNet2DConditionOutput`] 而不是简单的
                元组。
        返回:
            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] 或 `tuple`:
            如果 `return_dict` 为 True,则返回 [`~models.transformers.transformer_2d.Transformer2DModelOutput`],否则返回
            `tuple`。当返回元组时,第一个元素是样本张量。
        """
        # 将输入的 hidden_states 赋值给 input_states
        input_states = hidden_states
        # 初始化一个空列表用于存储编码后的状态
        encoded_states = []
        # 初始化 token 的起始位置为 0
        tokens_start = 0
        # attention_mask 目前尚未使用
        for i in range(2):
            # 对于两个变换器中的每一个,传递相应的条件标记
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            # 根据条件标记的索引获取对应的变换器
            transformer_index = self.transformer_index_for_condition[i]
            # 调用变换器处理输入状态和条件状态,并获取输出
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]  # 只获取输出的第一个元素
            # 将编码后的状态与输入状态相减,存入列表
            encoded_states.append(encoded_state - input_states)
            # 更新 token 的起始位置
            tokens_start += self.condition_lengths[i]
        # 结合两个编码后的状态,计算输出状态
        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        # 将计算后的输出状态与输入状态相加
        output_states = output_states + input_states
        # 如果不返回字典格式
        if not return_dict:
            # 返回输出状态的元组
            return (output_states,)
        # 返回包含样本输出的 Transformer2DModelOutput 对象
        return Transformer2DModelOutput(sample=output_states)