diffusers 源码解析(五十五)
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_common.py
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 从指定路径导入 Attention 处理模块
from ...models.attention_processor import Attention
# 定义自定义的层归一化类,继承自 nn.LayerNorm
class WuerstchenLayerNorm(nn.LayerNorm):
    # 初始化方法,接收可变参数
    def __init__(self, *args, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)
    # 前向传播方法
    def forward(self, x):
        # 调整输入张量的维度顺序
        x = x.permute(0, 2, 3, 1)
        # 调用父类的前向传播方法进行归一化
        x = super().forward(x)
        # 恢复输入张量的维度顺序并返回
        return x.permute(0, 3, 1, 2)
# 定义时间步块类,继承自 nn.Module
class TimestepBlock(nn.Module):
    # 初始化方法,接收通道数和时间步数
    def __init__(self, c, c_timestep):
        # 调用父类的初始化方法
        super().__init__()
        # 定义线性映射层,将时间步数映射到两倍的通道数
        self.mapper = nn.Linear(c_timestep, c * 2)
    # 前向传播方法
    def forward(self, x, t):
        # 使用映射层处理时间步,并将结果分割为两个部分
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        # 根据公式更新输入张量并返回
        return x * (1 + a) + b
# 定义残差块类,继承自 nn.Module
class ResBlock(nn.Module):
    # 初始化方法,接收通道数、跳过连接的通道数、卷积核大小和丢弃率
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        # 调用父类的初始化方法
        super().__init__()
        # 定义深度可分离卷积层
        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        # 定义自定义层归一化
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # 定义通道处理的顺序模块
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
        )
    # 前向传播方法
    def forward(self, x, x_skip=None):
        # 保存输入张量以便后续残差连接
        x_res = x
        # 如果有跳过连接的张量,则进行拼接
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        # 对输入张量进行深度卷积和归一化
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        # 通过通道处理模块
        x = self.channelwise(x).permute(0, 3, 1, 2)
        # 返回残差连接后的结果
        return x + x_res
# 从外部库导入的全局响应归一化类
class GlobalResponseNorm(nn.Module):
    # 初始化方法,接收特征维度
    def __init__(self, dim):
        # 调用父类的初始化方法
        super().__init__()
        # 定义可学习参数 gamma 和 beta
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
    # 前向传播方法
    def forward(self, x):
        # 计算输入张量的聚合范数
        agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # 计算标准化范数
        stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
        # 返回经过归一化后的结果
        return self.gamma * (x * stand_div_norm) + self.beta + x
# 定义注意力块类,继承自 nn.Module
class AttnBlock(nn.Module):
    # 初始化方法,接收通道数、条件通道数、头数、是否自注意力及丢弃率
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        # 调用父类的初始化方法
        super().__init__()
        # 设置是否使用自注意力
        self.self_attn = self_attn
        # 定义归一化层
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # 定义注意力机制
        self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
        # 定义键值映射层
        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
    # 前向传播方法
    def forward(self, x, kv):
        # 使用键值映射层处理 kv
        kv = self.kv_mapper(kv)
        # 对输入张量进行归一化
        norm_x = self.norm(x)
        # 如果使用自注意力,则拼接归一化后的 x 和 kv
        if self.self_attn:
            batch_size, channel, _, _ = x.shape
            kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
        # 将注意力机制的输出与原输入相加
        x = x + self.attention(norm_x, encoder_hidden_states=kv)
        # 返回处理后的张量
        return x
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_diffnext.py
# 版权信息,标明该代码的版权所有者及许可证
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 在 Apache 许可证 2.0("许可证")下获得许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件按"原样"提供,
# 不提供任何形式的保证或条件,无论是明示还是暗示。
# 请参见许可证以获取特定于许可证的权限和限制。
# 导入数学库
import math
# 导入 NumPy 库以进行数组处理
import numpy as np
# 导入 PyTorch 库及其神经网络模块
import torch
import torch.nn as nn
# 从配置工具模块导入 ConfigMixin 和注册配置的方法
from ...configuration_utils import ConfigMixin, register_to_config
# 从模型工具模块导入 ModelMixin
from ...models.modeling_utils import ModelMixin
# 从本地模块导入模型组件
from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm
# 定义 WuerstchenDiffNeXt 类,继承自 ModelMixin 和 ConfigMixin
class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
    # 注册初始化方法到配置
    @register_to_config
    def __init__(
        self,
        c_in=4,  # 输入通道数,默认为 4
        c_out=4,  # 输出通道数,默认为 4
        c_r=64,  # 嵌入维度,默认为 64
        patch_size=2,  # 补丁大小,默认为 2
        c_cond=1024,  # 条件通道数,默认为 1024
        c_hidden=[320, 640, 1280, 1280],  # 隐藏层通道数配置
        nhead=[-1, 10, 20, 20],  # 注意力头数配置
        blocks=[4, 4, 14, 4],  # 各级块数配置
        level_config=["CT", "CTA", "CTA", "CTA"],  # 各级配置
        inject_effnet=[False, True, True, True],  # 是否注入 EfficientNet
        effnet_embd=16,  # EfficientNet 嵌入维度
        clip_embd=1024,  # CLIP 嵌入维度
        kernel_size=3,  # 卷积核大小
        dropout=0.1,  # dropout 比率
    ):
        # 初始化权重的方法
        def _init_weights(self, m):
            # 对卷积层和线性层进行通用初始化
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)  # 使用 Xavier 均匀分布初始化权重
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)  # 偏置初始化为 0
            # 对 EfficientNet 映射器进行初始化
            for mapper in self.effnet_mappers:
                if mapper is not None:
                    nn.init.normal_(mapper.weight, std=0.02)  # 条件初始化为正态分布
            nn.init.normal_(self.clip_mapper.weight, std=0.02)  # CLIP 映射器初始化
            nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # 输入嵌入初始化
            nn.init.constant_(self.clf[1].weight, 0)  # 输出分类器初始化为 0
            # 初始化块中的权重
            for level_block in self.down_blocks + self.up_blocks:
                for block in level_block:
                    if isinstance(block, ResBlockStageB):
                        block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks))  # 权重缩放
                    elif isinstance(block, TimestepBlock):
                        nn.init.constant_(block.mapper.weight, 0)  # 将时间步映射器的权重初始化为 0
    # 生成位置嵌入的方法
    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions  # 将位置 r 乘以最大位置
        half_dim = self.c_r // 2  # 计算半维度
        emb = math.log(max_positions) / (half_dim - 1)  # 计算嵌入尺度
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()  # 生成嵌入
        emb = r[:, None] * emb[None, :]  # 扩展 r 的维度并进行乘法
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)  # 计算正弦和余弦嵌入并拼接
        if self.c_r % 2 == 1:  # 如果 c_r 为奇数,则进行零填充
            emb = nn.functional.pad(emb, (0, 1), mode="constant")  # 用常数进行填充
        return emb.to(dtype=r.dtype)  # 返回与 r 数据类型相同的嵌入
    # 生成 CLIP 嵌入
        def gen_c_embeddings(self, clip):
            # 将输入 clip 通过映射转换
            clip = self.clip_mapper(clip)
            # 对 clip 进行序列归一化处理
            clip = self.seq_norm(clip)
            # 返回处理后的 clip
            return clip
    
        # 下采样编码过程
        def _down_encode(self, x, r_embed, effnet, clip=None):
            # 初始化层级输出列表
            level_outputs = []
            # 遍历每一个下采样块
            for i, down_block in enumerate(self.down_blocks):
                effnet_c = None  # 初始化有效网络通道为 None
                # 遍历每个下采样块中的组件
                for block in down_block:
                    # 如果是残差块阶段 B
                    if isinstance(block, ResBlockStageB):
                        # 检查有效网络通道是否为 None
                        if effnet_c is None and self.effnet_mappers[i] is not None:
                            dtype = effnet.dtype  # 获取 effnet 的数据类型
                            # 进行双线性插值并创建有效网络通道
                            effnet_c = self.effnet_mappers[i](
                                nn.functional.interpolate(
                                    effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
                                ).to(dtype)
                            )
                        # 设置跳跃连接为有效网络通道
                        skip = effnet_c if self.effnet_mappers[i] is not None else None
                        # 通过当前块处理输入 x 和跳跃连接
                        x = block(x, skip)
                    # 如果是注意力块
                    elif isinstance(block, AttnBlock):
                        # 通过当前块处理输入 x 和 clip
                        x = block(x, clip)
                    # 如果是时间步块
                    elif isinstance(block, TimestepBlock):
                        # 通过当前块处理输入 x 和 r_embed
                        x = block(x, r_embed)
                    else:
                        # 通过当前块处理输入 x
                        x = block(x)
                # 将当前层输出插入到层级输出列表的开头
                level_outputs.insert(0, x)
            # 返回所有层级输出
            return level_outputs
    
        # 上采样解码过程
        def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
            # 使用层级输出的第一个元素初始化 x
            x = level_outputs[0]
            # 遍历每一个上采样块
            for i, up_block in enumerate(self.up_blocks):
                effnet_c = None  # 初始化有效网络通道为 None
                # 遍历每个上采样块中的组件
                for j, block in enumerate(up_block):
                    # 如果是残差块阶段 B
                    if isinstance(block, ResBlockStageB):
                        # 检查有效网络通道是否为 None
                        if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None:
                            dtype = effnet.dtype  # 获取 effnet 的数据类型
                            # 进行双线性插值并创建有效网络通道
                            effnet_c = self.effnet_mappers[len(self.down_blocks) + i](
                                nn.functional.interpolate(
                                    effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
                                ).to(dtype)
                            )
                        # 设置跳跃连接为当前层级输出的第 i 个元素
                        skip = level_outputs[i] if j == 0 and i > 0 else None
                        # 如果有效网络通道不为 None
                        if effnet_c is not None:
                            # 如果跳跃连接不为 None,将其与有效网络通道拼接
                            if skip is not None:
                                skip = torch.cat([skip, effnet_c], dim=1)
                            else:
                                # 否则直接设置为有效网络通道
                                skip = effnet_c
                        # 通过当前块处理输入 x 和跳跃连接
                        x = block(x, skip)
                    # 如果是注意力块
                    elif isinstance(block, AttnBlock):
                        # 通过当前块处理输入 x 和 clip
                        x = block(x, clip)
                    # 如果是时间步块
                    elif isinstance(block, TimestepBlock):
                        # 通过当前块处理输入 x 和 r_embed
                        x = block(x, r_embed)
                    else:
                        # 通过当前块处理输入 x
                        x = block(x)
            # 返回最终处理后的 x
            return x
    # 定义前向传播函数,接受多个输入参数
        def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
            # 如果 x_cat 不为 None,将 x 和 x_cat 沿着维度 1 拼接
            if x_cat is not None:
                x = torch.cat([x, x_cat], dim=1)
            # 处理条件嵌入
            r_embed = self.gen_r_embedding(r)
            # 如果 clip 不为 None,生成条件嵌入
            if clip is not None:
                clip = self.gen_c_embeddings(clip)
    
            # 模型块
            x_in = x  # 保存输入 x 以备后用
            x = self.embedding(x)  # 将输入 x 转换为嵌入表示
            # 下采样编码
            level_outputs = self._down_encode(x, r_embed, effnet, clip)
            # 上采样解码
            x = self._up_decode(level_outputs, r_embed, effnet, clip)
            # 将输出分成两个部分 a 和 b
            a, b = self.clf(x).chunk(2, dim=1)
            # 对 b 进行 sigmoid 激活,并进行缩放
            b = b.sigmoid() * (1 - eps * 2) + eps
            # 如果返回噪声,计算并返回
            if return_noise:
                return (x_in - a) / b
            else:
                return a, b  # 否则返回 a 和 b
# 定义一个残差块阶段 B,继承自 nn.Module
class ResBlockStageB(nn.Module):
    # 初始化函数,设置输入通道、跳跃连接通道、卷积核大小和丢弃率
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        # 调用父类的初始化方法
        super().__init__()
        # 创建深度卷积层,使用指定的卷积核大小和填充
        self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        # 创建层归一化层,设置元素可学习性为 False 和小的 epsilon 值
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # 创建一个顺序容器,包含线性层、GELU 激活、全局响应归一化、丢弃层和另一线性层
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c),
        )
    # 定义前向传播函数
    def forward(self, x, x_skip=None):
        # 保存输入以进行残差连接
        x_res = x
        # 先进行深度卷积和层归一化
        x = self.norm(self.depthwise(x))
        # 如果有跳跃连接,则将其与当前输出连接
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        # 变换输入维度并通过通道层,最后恢复维度
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # 返回残差输出
        return x + x_res
.\diffusers\pipelines\wuerstchen\modeling_wuerstchen_prior.py
# 版权声明,说明文件的版权所有者和许可证信息
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 许可证许可使用本文件
# 仅在遵循许可证的情况下使用此文件
# 可以在此获取许可证的副本
# 
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件在许可证下按 "现状" 基础提供
# 不提供任何形式的担保或条件
# 查看许可证以获取特定语言的权限和限制
# 导入数学库
import math
# 从 typing 模块导入字典和联合类型
from typing import Dict, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 导入配置工具和适配器相关的类
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
# 导入注意力处理器相关的类
from ...models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
# 导入模型相关的基类
from ...models.modeling_utils import ModelMixin
# 导入工具函数以检查 PyTorch 版本
from ...utils import is_torch_version
# 导入模型组件
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
# 定义 WuerstchenPrior 类,继承自多个基类
class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    # 设置 UNet 名称为 "prior"
    unet_name = "prior"
    # 启用梯度检查点功能
    _supports_gradient_checkpointing = True
    # 注册初始化方法,定义类的构造函数
    @register_to_config
    def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
        # 调用父类的构造函数
        super().__init__()
        # 设置压缩通道数
        self.c_r = c_r
        # 定义一个卷积层用于输入到中间通道的映射
        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
        # 定义条件映射层,由两个线性层和一个激活函数组成
        self.cond_mapper = nn.Sequential(
            nn.Linear(c_cond, c),  # 将条件输入映射到中间通道
            nn.LeakyReLU(0.2),      # 应用 Leaky ReLU 激活函数
            nn.Linear(c, c),        # 再次映射到中间通道
        )
        # 创建一个模块列表用于存储多个块
        self.blocks = nn.ModuleList()
        # 根据深度参数添加多个残差块、时间步块和注意力块
        for _ in range(depth):
            self.blocks.append(ResBlock(c, dropout=dropout))  # 添加残差块
            self.blocks.append(TimestepBlock(c, c_r))         # 添加时间步块
            self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))  # 添加注意力块
        # 定义输出层,由归一化层和卷积层组成
        self.out = nn.Sequential(
            WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),  # 归一化
            nn.Conv2d(c, c_in * 2, kernel_size=1),  # 输出卷积层
        )
        # 默认禁用梯度检查点
        self.gradient_checkpointing = False
        # 设置默认的注意力处理器
        self.set_default_attn_processor()
    # 定义一个只读属性
    @property
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
    # 定义一个返回注意力处理器字典的方法
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 的注意力处理器: 一个字典,包含模型中使用的所有注意力处理器,并以其权重名称索引。
        """
        # 初始化一个空字典以存储处理器
        processors = {}
    
        # 定义递归添加处理器的内部函数
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模块具有获取处理器的方法,则将其添加到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用以添加子模块的处理器
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            # 返回更新后的处理器字典
            return processors
    
        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用内部函数以添加所有子模块的处理器
            fn_recursive_add_processors(name, module, processors)
    
        # 返回所有处理器的字典
        return processors
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制的方法
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的注意力处理器。
    
        参数:
            processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或将作为所有 `Attention` 层的处理器设置的处理器类字典。
    
                如果 `processor` 是一个字典,则键需要定义对应的交叉注意力处理器的路径。
                在设置可训练的注意力处理器时,强烈建议这样做。
        """
        # 计算当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 检查传入的处理器字典的大小是否与注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器的字典,但处理器的数量 {len(processor)} 与注意力层的数量: {count} 不匹配。"
                f" 请确保传入 {count} 个处理器类。"
            )
    
        # 定义递归设置处理器的内部函数
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块具有设置处理器的方法,则设置处理器
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,则直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历子模块
            for sub_name, child in module.named_children():
                # 递归调用以设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用内部函数以设置所有子模块的处理器
            fn_recursive_attn_processor(name, module, processor)
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制的方法
    # 定义一个方法,用于设置默认的注意力处理器
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器,并设置默认的注意力实现。
        """
        # 检查所有注意力处理器是否属于新增的键值注意力处理器
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,使用新增的键值注意力处理器
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,使用标准的注意力处理器
            processor = AttnProcessor()
        else:
            # 否则,抛出一个值错误,说明无法设置默认注意力处理器
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
        # 调用设置方法,将选择的处理器应用于当前对象
        self.set_attn_processor(processor)
    # 定义一个私有方法,用于设置梯度检查点
    def _set_gradient_checkpointing(self, module, value=False):
        # 将梯度检查点的值设置为传入的布尔值
        self.gradient_checkpointing = value
    # 定义生成位置嵌入的方法
    def gen_r_embedding(self, r, max_positions=10000):
        # 将输入的 r 乘以最大位置数
        r = r * max_positions
        # 计算嵌入的半维度
        half_dim = self.c_r // 2
        # 计算嵌入的缩放因子
        emb = math.log(max_positions) / (half_dim - 1)
        # 创建一个张量,并根据半维度生成指数嵌入
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        # 根据 r 生成最终的嵌入
        emb = r[:, None] * emb[None, :]
        # 将正弦和余弦嵌入拼接在一起
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        # 如果 c_r 是奇数,则进行零填充
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        # 返回嵌入,确保数据类型与 r 一致
        return emb.to(dtype=r.dtype)
    # 定义前向传播函数,接收输入张量 x、条件 r 和 c
        def forward(self, x, r, c):
            # 保存输入张量的原始值
            x_in = x
            # 对输入张量进行投影处理
            x = self.projection(x)
            # 将条件 c 转换为嵌入表示
            c_embed = self.cond_mapper(c)
            # 生成条件 r 的嵌入表示
            r_embed = self.gen_r_embedding(r)
    
            # 如果处于训练模式并且开启梯度检查点
            if self.training and self.gradient_checkpointing:
    
                # 创建自定义前向传播函数的辅助函数
                def create_custom_forward(module):
                    # 定义接受任意输入的自定义前向函数
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 检查 PyTorch 版本是否大于等于 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 遍历所有块进行处理
                    for block in self.blocks:
                        # 如果块是注意力块
                        if isinstance(block, AttnBlock):
                            # 使用检查点来保存内存
                            x = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block), x, c_embed, use_reentrant=False
                            )
                        # 如果块是时间步块
                        elif isinstance(block, TimestepBlock):
                            # 使用检查点来保存内存
                            x = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block), x, r_embed, use_reentrant=False
                            )
                        else:
                            # 处理其他类型的块
                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
                else:
                    # 对于旧版本的 PyTorch
                    for block in self.blocks:
                        # 如果块是注意力块
                        if isinstance(block, AttnBlock):
                            # 使用检查点来保存内存
                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
                        # 如果块是时间步块
                        elif isinstance(block, TimestepBlock):
                            # 使用检查点来保存内存
                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
                        else:
                            # 处理其他类型的块
                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
            else:
                # 如果不在训练模式下
                for block in self.blocks:
                    # 如果块是注意力块
                    if isinstance(block, AttnBlock):
                        # 直接进行前向传播
                        x = block(x, c_embed)
                    # 如果块是时间步块
                    elif isinstance(block, TimestepBlock):
                        # 直接进行前向传播
                        x = block(x, r_embed)
                    else:
                        # 处理其他类型的块
                        x = block(x)
            # 将输出分割为两个部分 a 和 b
            a, b = self.out(x).chunk(2, dim=1)
            # 返回经过归一化处理的结果
            return (x_in - a) / ((1 - b).abs() + 1e-5)
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen.py
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,按许可证分发的软件
# 是在“按原样”基础上提供的,没有任何形式的保证或条件,
# 明示或暗示。有关许可的特定权限和
# 限制,请参阅许可证。
from typing import Callable, Dict, List, Optional, Union  # 从 typing 模块导入类型注解工具
import numpy as np  # 导入 NumPy 库,常用于数值计算
import torch  # 导入 PyTorch 库,支持深度学习
from transformers import CLIPTextModel, CLIPTokenizer  # 从 transformers 库导入 CLIP 模型和分词器
from ...schedulers import DDPMWuerstchenScheduler  # 从调度器模块导入 DDPMWuerstchenScheduler
from ...utils import deprecate, logging, replace_example_docstring  # 从 utils 模块导入实用工具
from ...utils.torch_utils import randn_tensor  # 从 PyTorch 工具模块导入 randn_tensor 函数
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput  # 从管道工具模块导入 DiffusionPipeline 和 ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel  # 从 Paella VQ 模型模块导入 PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt  # 从 Wuerstchen DiffNeXt 模型模块导入 WuerstchenDiffNeXt
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁用 pylint 对无效名称的警告
EXAMPLE_DOC_STRING = """  # 示例文档字符串,提供用法示例
    Examples:
        ```py
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline  # 导入 Wuerstchen 管道
        >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(  # 从预训练模型创建 WuerstchenPriorPipeline 实例
        ...     "warp-ai/wuerstchen-prior", torch_dtype=torch.float16  # 指定模型名称和数据类型
        ... ).to("cuda")  # 将管道移动到 CUDA 设备
        >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to(  # 从预训练模型创建 WuerstchenDecoderPipeline 实例
        ...     "cuda"  # 将生成管道移动到 CUDA 设备
        ... )
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"  # 定义生成图像的提示
        >>> prior_output = pipe(prompt)  # 使用提示生成先前输出
        >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)  # 使用生成管道从图像嵌入生成图像
        ```py
"""
class WuerstchenDecoderPipeline(DiffusionPipeline):  # 定义 WuerstchenDecoderPipeline 类,继承自 DiffusionPipeline
    """
    Pipeline for generating images from the Wuerstchen model.  # 类文档字符串,说明该管道的功能
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)  # 说明该模型继承自 DiffusionPipeline,并提醒用户查看父类文档以获取通用方法
    # 参数说明
    Args:
        tokenizer (`CLIPTokenizer`):  # CLIP 模型使用的分词器
            The CLIP tokenizer.
        text_encoder (`CLIPTextModel`):  # CLIP 模型使用的文本编码器
            The CLIP text encoder.
        decoder ([`WuerstchenDiffNeXt`]):  # WuerstchenDiffNeXt 解码器
            The WuerstchenDiffNeXt unet decoder.
        vqgan ([`PaellaVQModel`]):  # VQGAN 模型,用于图像生成
            The VQGAN model.
        scheduler ([`DDPMWuerstchenScheduler`]):  # 调度器,用于图像嵌入生成
            A scheduler to be used in combination with `prior` to generate image embedding.
        latent_dim_scale (float, `optional`, defaults to 10.67):  # 用于确定 VQ 潜在空间大小的乘数
            Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
            height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
            width=int(24*10.67)=256 in order to match the training conditions.
    """
    # 定义模型的 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->decoder->vqgan"
    # 定义需要回调的张量输入列表
    _callback_tensor_inputs = [
        "latents",  # 潜在变量
        "text_encoder_hidden_states",  # 文本编码器的隐藏状态
        "negative_prompt_embeds",  # 负面提示的嵌入
        "image_embeddings",  # 图像嵌入
    ]
    # 构造函数
    def __init__(
        self,
        tokenizer: CLIPTokenizer,  # 初始化时传入的分词器
        text_encoder: CLIPTextModel,  # 初始化时传入的文本编码器
        decoder: WuerstchenDiffNeXt,  # 初始化时传入的解码器
        scheduler: DDPMWuerstchenScheduler,  # 初始化时传入的调度器
        vqgan: PaellaVQModel,  # 初始化时传入的 VQGAN 模型
        latent_dim_scale: float = 10.67,  # 可选参数,默认值为 10.67
    ) -> None:
        super().__init__()  # 调用父类构造函数
        # 注册模型的各个模块
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            decoder=decoder,
            scheduler=scheduler,
            vqgan=vqgan,
        )
        # 将潜在维度缩放因子注册到配置中
        self.register_to_config(latent_dim_scale=latent_dim_scale)
    # 从 diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline 复制的方法,准备潜在变量
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        # 如果潜在变量为 None,则生成随机张量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 检查传入的潜在变量形状是否与预期形状匹配
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            # 将潜在变量移动到指定设备
            latents = latents.to(device)
        # 用调度器的初始噪声标准差调整潜在变量
        latents = latents * scheduler.init_noise_sigma
        # 返回调整后的潜在变量
        return latents
    # 编码提示的方法
    def encode_prompt(
        self,
        prompt,  # 输入的提示
        device,  # 目标设备
        num_images_per_prompt,  # 每个提示生成的图像数量
        do_classifier_free_guidance,  # 是否进行无分类器引导
        negative_prompt=None,  # 负面提示(可选)
    @property
    # 获取引导缩放比例的属性
    def guidance_scale(self):
        return self._guidance_scale
    @property
    # 判断是否使用无分类器引导
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1
    @property
    # 获取时间步数的属性
    def num_timesteps(self):
        return self._num_timesteps
    @torch.no_grad()  # 不计算梯度
    @replace_example_docstring(EXAMPLE_DOC_STRING)  # 替换示例文档字符串
    # 定义可调用对象的 __call__ 方法,允许实例像函数一样被调用
    def __call__(
            self,
            # 输入图像的嵌入,支持单个张量或张量列表
            image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
            # 提示文本,可以是单个字符串或字符串列表
            prompt: Union[str, List[str]] = None,
            # 推理步骤的数量,默认值为 12
            num_inference_steps: int = 12,
            # 指定时间步的列表,默认为 None
            timesteps: Optional[List[float]] = None,
            # 指导比例,控制生成的多样性,默认值为 0.0
            guidance_scale: float = 0.0,
            # 负提示文本,可以是单个字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认值为 1
            num_images_per_prompt: int = 1,
            # 随机数生成器,可选,支持单个或多个生成器
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,可选,默认为 None
            latents: Optional[torch.Tensor] = None,
            # 输出类型,默认值为 "pil"
            output_type: Optional[str] = "pil",
            # 返回字典标志,默认为 True
            return_dict: bool = True,
            # 结束步骤回调函数,可选,接收步骤信息
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 结束步骤回调函数使用的张量输入列表,默认为 ["latents"]
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 其他可选参数,以关键字参数形式传递
            **kwargs,
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen_combined.py
# 版权信息,表明版权所有者和许可信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证版本 2.0 进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 本文件只能在遵循许可证的情况下使用
# you may not use this file except in compliance with the License.
# 可以在以下地址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“原样”基础分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,无论是明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参见许可证以了解管理权限和限制的具体语言
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入所需的类型提示
from typing import Callable, Dict, List, Optional, Union
# 导入 PyTorch 库
import torch
# 从 transformers 库导入 CLIP 文本模型和分词器
from transformers import CLIPTextModel, CLIPTokenizer
# 从自定义调度器导入 DDPMWuerstchenScheduler
from ...schedulers import DDPMWuerstchenScheduler
# 从自定义工具导入去除过时函数和替换示例文档字符串的函数
from ...utils import deprecate, replace_example_docstring
# 从管道工具导入 DiffusionPipeline 基类
from ..pipeline_utils import DiffusionPipeline
# 从模型模块导入 PaellaVQModel
from .modeling_paella_vq_model import PaellaVQModel
# 从模型模块导入 WuerstchenDiffNeXt
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
# 从模型模块导入 WuerstchenPrior
from .modeling_wuerstchen_prior import WuerstchenPrior
# 从管道模块导入 WuerstchenDecoderPipeline
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
# 从管道模块导入 WuerstchenPriorPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
# 文档字符串示例,用于展示如何使用文本转图像的管道
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusions import WuerstchenCombinedPipeline
        >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
        ...     "cuda"
        ... )
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> images = pipe(prompt=prompt)
        ```
"""
# 定义一个结合文本到图像生成的管道类
class WuerstchenCombinedPipeline(DiffusionPipeline):
    """
    使用 Wuerstchen 进行文本到图像生成的组合管道
    该模型继承自 [`DiffusionPipeline`]。查看父类文档以了解库为所有管道实现的通用方法
    (如下载或保存,运行在特定设备等)。
    参数:
        tokenizer (`CLIPTokenizer`):
            用于文本输入的解码器分词器。
        text_encoder (`CLIPTextModel`):
            用于文本输入的解码器文本编码器。
        decoder (`WuerstchenDiffNeXt`):
            用于图像生成管道的解码器模型。
        scheduler (`DDPMWuerstchenScheduler`):
            用于图像生成管道的调度器。
        vqgan (`PaellaVQModel`):
            用于图像生成管道的 VQGAN 模型。
        prior_tokenizer (`CLIPTokenizer`):
            用于文本输入的先前分词器。
        prior_text_encoder (`CLIPTextModel`):
            用于文本输入的先前文本编码器。
        prior_prior (`WuerstchenPrior`):
            用于先前管道的先前模型。
        prior_scheduler (`DDPMWuerstchenScheduler`):
            用于先前管道的调度器。
    """
    # 标志,表示是否加载连接的管道
    _load_connected_pipes = True
    # 初始化类的构造函数,接收多个模型和调度器作为参数
        def __init__(
            self,
            tokenizer: CLIPTokenizer,  # 词汇处理器
            text_encoder: CLIPTextModel,  # 文本编码器
            decoder: WuerstchenDiffNeXt,  # 解码器模型
            scheduler: DDPMWuerstchenScheduler,  # 调度器
            vqgan: PaellaVQModel,  # VQGAN模型
            prior_tokenizer: CLIPTokenizer,  # 先验词汇处理器
            prior_text_encoder: CLIPTextModel,  # 先验文本编码器
            prior_prior: WuerstchenPrior,  # 先验模型
            prior_scheduler: DDPMWuerstchenScheduler,  # 先验调度器
        ):
            super().__init__()  # 调用父类的构造函数
    
            # 注册各个模型和调度器到当前实例
            self.register_modules(
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                decoder=decoder,
                scheduler=scheduler,
                vqgan=vqgan,
                prior_prior=prior_prior,
                prior_text_encoder=prior_text_encoder,
                prior_tokenizer=prior_tokenizer,
                prior_scheduler=prior_scheduler,
            )
            # 初始化先验管道,用于处理先验相关操作
            self.prior_pipe = WuerstchenPriorPipeline(
                prior=prior_prior,  # 先验模型
                text_encoder=prior_text_encoder,  # 先验文本编码器
                tokenizer=prior_tokenizer,  # 先验词汇处理器
                scheduler=prior_scheduler,  # 先验调度器
            )
            # 初始化解码器管道,用于处理解码相关操作
            self.decoder_pipe = WuerstchenDecoderPipeline(
                text_encoder=text_encoder,  # 文本编码器
                tokenizer=tokenizer,  # 词汇处理器
                decoder=decoder,  # 解码器
                scheduler=scheduler,  # 调度器
                vqgan=vqgan,  # VQGAN模型
            )
    
        # 启用节省内存的高效注意力机制
        def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
            # 在解码器管道中启用高效注意力机制
            self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
    
        # 启用模型的CPU卸载,减少内存使用
        def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
            r"""
            使用accelerate将所有模型卸载到CPU,减少内存使用且对性能影响较小。
            此方法在调用模型的`forward`方法时将整个模型移到GPU,模型将在下一个模型运行之前保持在GPU上。
            相比于`enable_sequential_cpu_offload`,内存节省较少,但性能更佳。
            """
            # 在先验管道中启用模型的CPU卸载
            self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
            # 在解码器管道中启用模型的CPU卸载
            self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
    
        # 启用顺序CPU卸载,显著减少内存使用
        def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
            r"""
            使用🤗Accelerate将所有模型卸载到CPU,显著减少内存使用。
            模型被移动到`torch.device('meta')`,并仅在调用特定子模块的`forward`方法时加载到GPU。
            卸载是基于子模块进行的,内存节省比使用`enable_model_cpu_offload`高,但性能较低。
            """
            # 在先验管道中启用顺序CPU卸载
            self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
            # 在解码器管道中启用顺序CPU卸载
            self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
    # 定义进度条方法,接受可迭代对象和总计数作为参数
    def progress_bar(self, iterable=None, total=None):
        # 在 prior_pipe 上更新进度条,传入可迭代对象和总计数
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        # 在 decoder_pipe 上更新进度条,传入可迭代对象和总计数
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
    # 定义设置进度条配置的方法,接收任意关键字参数
    def set_progress_bar_config(self, **kwargs):
        # 在 prior_pipe 上设置进度条配置,传入关键字参数
        self.prior_pipe.set_progress_bar_config(**kwargs)
        # 在 decoder_pipe 上设置进度条配置,传入关键字参数
        self.decoder_pipe.set_progress_bar_config(**kwargs)
    # 使用 torch.no_grad() 装饰器,表示在此上下文中不计算梯度
    @torch.no_grad()
    # 替换示例文档字符串的装饰器
    @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
    # 定义调用方法,处理文本到图像的转换
    def __call__(
        # 接受提示文本,支持字符串或字符串列表,默认为 None
        prompt: Optional[Union[str, List[str]]] = None,
        # 图像高度,默认为 512
        height: int = 512,
        # 图像宽度,默认为 512
        width: int = 512,
        # prior 阶段推理步骤数,默认为 60
        prior_num_inference_steps: int = 60,
        # prior 阶段时间步,默认为 None
        prior_timesteps: Optional[List[float]] = None,
        # prior 阶段引导比例,默认为 4.0
        prior_guidance_scale: float = 4.0,
        # decoder 阶段推理步骤数,默认为 12
        num_inference_steps: int = 12,
        # decoder 阶段时间步,默认为 None
        decoder_timesteps: Optional[List[float]] = None,
        # decoder 阶段引导比例,默认为 0.0
        decoder_guidance_scale: float = 0.0,
        # 负提示文本,支持字符串或字符串列表,默认为 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 提示嵌入,默认为 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负提示嵌入,默认为 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 每个提示生成的图像数量,默认为 1
        num_images_per_prompt: int = 1,
        # 随机数生成器,默认为 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潜在表示,默认为 None
        latents: Optional[torch.Tensor] = None,
        # 输出类型,默认为 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典格式,默认为 True
        return_dict: bool = True,
        # prior 阶段的回调函数,默认为 None
        prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # prior 阶段回调函数输入的张量名称列表,默认为 ["latents"]
        prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # decoder 阶段的回调函数,默认为 None
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # decoder 阶段回调函数输入的张量名称列表,默认为 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 接受其他任意关键字参数
        **kwargs,
.\diffusers\pipelines\wuerstchen\pipeline_wuerstchen_prior.py
# 版权信息,声明此代码归 HuggingFace 团队所有,保留所有权利
# 许可证声明,使用此文件需遵守 Apache 许可证 2.0
# 提供许可证的获取地址
# 许可证说明,未按适用法律或书面协议另行约定的情况下,软件在“按现状”基础上分发
# 提供许可证详细信息的地址
from dataclasses import dataclass  # 导入数据类装饰器,用于简化类的定义
from math import ceil  # 导入向上取整函数
from typing import Callable, Dict, List, Optional, Union  # 导入类型注解
import numpy as np  # 导入 NumPy 库,用于数值计算
import torch  # 导入 PyTorch 库,用于深度学习
from transformers import CLIPTextModel, CLIPTokenizer  # 导入 CLIP 模型和分词器
from ...loaders import StableDiffusionLoraLoaderMixin  # 导入加载 LoRA 权重的混合类
from ...schedulers import DDPMWuerstchenScheduler  # 导入调度器
from ...utils import BaseOutput, deprecate, logging, replace_example_docstring  # 导入工具类和函数
from ...utils.torch_utils import randn_tensor  # 导入生成随机张量的工具函数
from ..pipeline_utils import DiffusionPipeline  # 导入扩散管道基类
from .modeling_wuerstchen_prior import WuerstchenPrior  # 导入 Wuerstchen 先验模型
logger = logging.get_logger(__name__)  # 创建日志记录器
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]  # 设置默认的时间步,分段线性生成
EXAMPLE_DOC_STRING = """  # 示例文档字符串,提供使用示例
    Examples:
        ```py  # Python 代码块开始
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import WuerstchenPriorPipeline  # 导入 WuerstchenPriorPipeline 类
        >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(  # 从预训练模型加载管道
        ...     "warp-ai/wuerstchen-prior", torch_dtype=torch.float16  # 指定模型路径和数据类型
        ... ).to("cuda")  # 将管道移动到 GPU
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"  # 定义生成图像的提示
        >>> prior_output = pipe(prompt)  # 生成图像并返回结果
        ```py  # Python 代码块结束
"""
@dataclass  # 使用数据类装饰器定义输出类
class WuerstchenPriorPipelineOutput(BaseOutput):  # 定义 WuerstchenPriorPipeline 的输出类
    """
    输出类用于 WuerstchenPriorPipeline。
    Args:
        image_embeddings (`torch.Tensor` or `np.ndarray`)  # 图像嵌入数据的类型说明
            Prior image embeddings for text prompt  # 为文本提示生成的图像嵌入
    """
    image_embeddings: Union[torch.Tensor, np.ndarray]  # 定义图像嵌入属性
class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):  # 定义 WuerstchenPriorPipeline 类,继承自扩散管道和加载器
    """
    用于生成 Wuerstchen 图像先验的管道。
    此模型继承自 [`DiffusionPipeline`]。查看超类文档以获取库实现的所有管道的通用方法(例如下载、保存、在特定设备上运行等)
    该管道还继承以下加载方法:
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
```py  # 文档结束
```  # 文档结束
    # 文档字符串,说明构造函数参数及其作用
    Args:
        prior ([`Prior`]):
            # 指定用于从文本嵌入近似图像嵌入的标准 unCLIP 先验
        text_encoder ([`CLIPTextModelWithProjection`]):
            # 冻结的文本编码器
        tokenizer (`CLIPTokenizer`):
            # 用于文本处理的标记器,详细信息见 CLIPTokenizer 文档
        scheduler ([`DDPMWuerstchenScheduler`]):
            # 与 `prior` 结合使用的调度器,用于生成图像嵌入
        latent_mean ('float', *optional*, defaults to 42.0):
            # 潜在扩散器的均值
        latent_std ('float', *optional*, defaults to 1.0):
            # 潜在扩散器的标准差
        resolution_multiple ('float', *optional*, defaults to 42.67):
            # 生成多个图像时的默认分辨率
    """
    # 定义 unet 的名称为 "prior"
    unet_name = "prior"
    # 定义文本编码器的名称
    text_encoder_name = "text_encoder"
    # 定义模型的 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->prior"
    # 定义回调张量输入的列表
    _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
    # 定义可加载的 LoRA 模块
    _lora_loadable_modules = ["prior", "text_encoder"]
    # 初始化函数,设置类的属性
    def __init__(
        self,
        # 初始化所需的标记器
        tokenizer: CLIPTokenizer,
        # 初始化所需的文本编码器
        text_encoder: CLIPTextModel,
        # 初始化所需的 unCLIP 先验
        prior: WuerstchenPrior,
        # 初始化所需的调度器
        scheduler: DDPMWuerstchenScheduler,
        # 设置潜在均值,默认值为 42.0
        latent_mean: float = 42.0,
        # 设置潜在标准差,默认值为 1.0
        latent_std: float = 1.0,
        # 设置生成图像的默认分辨率倍数,默认值为 42.67
        resolution_multiple: float = 42.67,
    ) -> None:
        # 调用父类的初始化方法
        super().__init__()
        # 注册所需的模块
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prior=prior,
            scheduler=scheduler,
        )
        # 将配置注册到类中
        self.register_to_config(
            latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
        )
    # 从指定的管道准备潜在张量
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        # 如果未提供潜在张量,则生成随机张量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 检查潜在张量的形状是否匹配
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            # 将潜在张量移动到指定设备
            latents = latents.to(device)
        # 将潜在张量乘以调度器的初始噪声标准差
        latents = latents * scheduler.init_noise_sigma
        # 返回准备好的潜在张量
        return latents
    # 编码提示信息,处理正向和负向提示
    def encode_prompt(
        self,
        # 指定设备
        device,
        # 每个提示生成的图像数量
        num_images_per_prompt,
        # 是否进行无分类器自由引导
        do_classifier_free_guidance,
        # 正向提示文本
        prompt=None,
        # 负向提示文本
        negative_prompt=None,
        # 提示的嵌入张量,若有则提供
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负向提示的嵌入张量,若有则提供
        negative_prompt_embeds: Optional[torch.Tensor] = None,
    # 检查输入的有效性
    def check_inputs(
        self,
        # 正向提示文本
        prompt,
        # 负向提示文本
        negative_prompt,
        # 推理步骤的数量
        num_inference_steps,
        # 是否进行无分类器自由引导
        do_classifier_free_guidance,
        # 提示的嵌入张量,若有则提供
        prompt_embeds=None,
        # 负向提示的嵌入张量,若有则提供
        negative_prompt_embeds=None,
    # 检查 prompt 和 prompt_embeds 是否同时存在
        ):
            if prompt is not None and prompt_embeds is not None:
                # 抛出异常,提示不能同时提供 prompt 和 prompt_embeds
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查 prompt 和 prompt_embeds 是否都未定义
            elif prompt is None and prompt_embeds is None:
                # 抛出异常,提示必须提供 prompt 或 prompt_embeds
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查 prompt 是否为有效类型
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 抛出异常,提示 prompt 类型不正确
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查 negative_prompt 和 negative_prompt_embeds 是否同时存在
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 抛出异常,提示不能同时提供 negative_prompt 和 negative_prompt_embeds
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查 prompt_embeds 和 negative_prompt_embeds 是否同时存在
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                # 验证这两个张量的形状是否一致
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 抛出异常,提示形状不匹配
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
            # 检查 num_inference_steps 是否为整数
            if not isinstance(num_inference_steps, int):
                # 抛出异常,提示 num_inference_steps 类型不正确
                raise TypeError(
                    f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
                               In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
                )
    
        # 定义属性 guidance_scale,返回该类的 _guidance_scale 值
        @property
        def guidance_scale(self):
            return self._guidance_scale
    
        # 定义属性 do_classifier_free_guidance,判断是否执行无分类器引导
        @property
        def do_classifier_free_guidance(self):
            return self._guidance_scale > 1
    
        # 定义属性 num_timesteps,返回该类的 _num_timesteps 值
        @property
        def num_timesteps(self):
            return self._num_timesteps
    
        # 定义可调用方法,执行主要功能
        @torch.no_grad()
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        def __call__(
            # 定义方法的参数,包括 prompt 和其他配置
            self,
            prompt: Optional[Union[str, List[str]]] = None,
            height: int = 1024,
            width: int = 1024,
            num_inference_steps: int = 60,
            timesteps: List[float] = None,
            guidance_scale: float = 8.0,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            num_images_per_prompt: Optional[int] = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,
            output_type: Optional[str] = "pt",
            return_dict: bool = True,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            **kwargs,
.\diffusers\pipelines\wuerstchen\__init__.py
# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING
# 从 utils 模块导入各种工具和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 指示是否进行慢速导入
    OptionalDependencyNotAvailable,  # 处理可选依赖项不可用的异常
    _LazyModule,  # 用于延迟加载模块
    get_objects_from_module,  # 从模块中获取对象的函数
    is_torch_available,  # 检查 PyTorch 是否可用的函数
    is_transformers_available,  # 检查 Transformers 是否可用的函数
)
# 初始化一个空字典,用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典,用于存储导入结构
_import_structure = {}
# 尝试检查依赖项的可用性
try:
    # 如果 Transformers 和 Torch 不可用,则抛出异常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖项不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 模块导入虚拟对象
    from ...utils import dummy_torch_and_transformers_objects
    # 更新虚拟对象字典
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖项可用,更新导入结构
else:
    _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"]  # 添加 PaellaVQModel
    _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"]  # 添加 WuerstchenDiffNeXt
    _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]  # 添加 WuerstchenPrior
    _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]  # 添加 WuerstchenDecoderPipeline
    _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]  # 添加 WuerstchenCombinedPipeline
    _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"]  # 添加相关管道
# 根据类型检查或慢速导入的标志进行条件判断
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 尝试检查依赖项的可用性
    try:
        if not (is_transformers_available() and is_torch_available()):  # 同样检查可用性
            raise OptionalDependencyNotAvailable()  # 抛出异常
    # 捕获可选依赖项不可用的异常
    except OptionalDependencyNotAvailable:
        # 从虚拟对象模块导入所有内容
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        # 从各个模块导入必要的类和函数
        from .modeling_paella_vq_model import PaellaVQModel
        from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
        from .modeling_wuerstchen_prior import WuerstchenPrior
        from .pipeline_wuerstchen import WuerstchenDecoderPipeline
        from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
        from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline
else:
    # 如果不是类型检查或慢速导入,导入 sys 模块
    import sys
    # 将当前模块替换为一个延迟加载模块
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 当前文件
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规格
    )
    # 将虚拟对象添加到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 设置属性
.\diffusers\pipelines\__init__.py
# 导入类型检查的模块
from typing import TYPE_CHECKING
# 从父级目录的 utils 模块中导入多个对象和函数
from ..utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入一个慢加载的功能
    OptionalDependencyNotAvailable,  # 导入可选依赖不可用的异常
    _LazyModule,  # 导入懒加载模块的工具
    get_objects_from_module,  # 导入从模块中获取对象的函数
    is_flax_available,  # 导入检查 Flax 库是否可用的函数
    is_k_diffusion_available,  # 导入检查 K-Diffusion 库是否可用的函数
    is_librosa_available,  # 导入检查 Librosa 库是否可用的函数
    is_note_seq_available,  # 导入检查 NoteSeq 库是否可用的函数
    is_onnx_available,  # 导入检查 ONNX 库是否可用的函数
    is_sentencepiece_available,  # 导入检查 SentencePiece 库是否可用的函数
    is_torch_available,  # 导入检查 PyTorch 库是否可用的函数
    is_torch_npu_available,  # 导入检查 NPU 版 PyTorch 是否可用的函数
    is_transformers_available,  # 导入检查 Transformers 库是否可用的函数
)
# 初始化一个空字典以存储假对象
_dummy_objects = {}
# 定义一个字典以组织导入的模块结构
_import_structure = {
    "controlnet": [],  # 控制网模块
    "controlnet_hunyuandit": [],  # 控制网相关模块
    "controlnet_sd3": [],  # 控制网 SD3 模块
    "controlnet_xs": [],  # 控制网 XS 模块
    "deprecated": [],  # 存放弃用模块
    "latent_diffusion": [],  # 潜在扩散模块
    "ledits_pp": [],  # LEDITS PP 模块
    "marigold": [],  # 万寿菊模块
    "pag": [],  # PAG 模块
    "stable_diffusion": [],  # 稳定扩散模块
    "stable_diffusion_xl": [],  # 稳定扩散 XL 模块
}
try:
    # 检查 PyTorch 是否可用
    if not is_torch_available():
        # 如果不可用,抛出可选依赖不可用的异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果捕获到异常,从 utils 模块导入假对象(PyTorch 相关)
    from ..utils import dummy_pt_objects  # noqa F403
    # 将获取的假对象更新到字典中
    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    # 如果 PyTorch 可用,更新导入结构,添加自动管道类
    _import_structure["auto_pipeline"] = [
        "AutoPipelineForImage2Image",  # 图像到图像的自动管道
        "AutoPipelineForInpainting",  # 图像修复的自动管道
        "AutoPipelineForText2Image",  # 文本到图像的自动管道
    ]
    # 添加一致性模型管道
    _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
    # 添加舞蹈扩散管道
    _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
    # 添加 DDIM 管道
    _import_structure["ddim"] = ["DDIMPipeline"]
    # 添加 DDPM 管道
    _import_structure["ddpm"] = ["DDPMPipeline"]
    # 添加 DiT 管道
    _import_structure["dit"] = ["DiTPipeline"]
    # 扩展潜在扩散模块,添加超分辨率管道
    _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
    # 添加管道工具的输出类型
    _import_structure["pipeline_utils"] = [
        "AudioPipelineOutput",  # 音频管道输出
        "DiffusionPipeline",  # 扩散管道
        "StableDiffusionMixin",  # 稳定扩散混合类
        "ImagePipelineOutput",  # 图像管道输出
    ]
    # 扩展弃用模块,添加弃用的管道
    _import_structure["deprecated"].extend(
        [
            "PNDMPipeline",  # PNDM 管道
            "LDMPipeline",  # LDM 管道
            "RePaintPipeline",  # 重绘管道
            "ScoreSdeVePipeline",  # Score SDE VE 管道
            "KarrasVePipeline",  # Karras VE 管道
        ]
    )
try:
    # 检查 PyTorch 和 Librosa 是否都可用
    if not (is_torch_available() and is_librosa_available()):
        # 如果其中一个不可用,抛出可选依赖不可用的异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 捕获异常,从 utils 模块导入假对象(PyTorch 和 Librosa 相关)
    from ..utils import dummy_torch_and_librosa_objects  # noqa F403
    # 更新假对象字典
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
    # 如果两个库都可用,扩展弃用模块,添加音频扩散管道和 Mel 类
    _import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"])
try:
    # 检查 Transformers、PyTorch 和 NoteSeq 是否都可用
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        # 如果其中一个不可用,抛出可选依赖不可用的异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 捕获异常,从 utils 模块导入假对象(Transformers、PyTorch 和 NoteSeq 相关)
    from ..utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403
    # 更新假对象字典
    _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
    # 如果三个库都可用,扩展弃用模块,添加 MIDI 处理器和谱图扩散管道
    _import_structure["deprecated"].extend(
        [
            "MidiProcessor",  # MIDI 处理器
            "SpectrogramDiffusionPipeline",  # 谱图扩散管道
        ]
    )
try:
    # 检查 PyTorch 和 Transformers 库是否可用
        if not (is_torch_available() and is_transformers_available()):
            # 如果任一库不可用,抛出异常表示可选依赖不可用
            raise OptionalDependencyNotAvailable()
# 捕获未满足可选依赖项的异常
except OptionalDependencyNotAvailable:
    # 从上层模块导入虚拟的 Torch 和 Transformers 对象
    from ..utils import dummy_torch_and_transformers_objects  # noqa F403
    # 更新虚拟对象的字典,以获取导入的虚拟对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 将过时的导入结构中添加一组管道名称
    _import_structure["deprecated"].extend(
        [
            "VQDiffusionPipeline",
            "AltDiffusionPipeline",
            "AltDiffusionImg2ImgPipeline",
            "CycleDiffusionPipeline",
            "StableDiffusionInpaintPipelineLegacy",
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionParadigmsPipeline",
            "StableDiffusionModelEditingPipeline",
            "VersatileDiffusionDualGuidedPipeline",
            "VersatileDiffusionImageVariationPipeline",
            "VersatileDiffusionPipeline",
            "VersatileDiffusionTextToImagePipeline",
        ]
    )
    # 为“amused”添加相关管道名称
    _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
    # 为“animatediff”添加相关管道名称
    _import_structure["animatediff"] = [
        "AnimateDiffPipeline",
        "AnimateDiffControlNetPipeline",
        "AnimateDiffSDXLPipeline",
        "AnimateDiffSparseControlNetPipeline",
        "AnimateDiffVideoToVideoPipeline",
    ]
    # 为“flux”添加相关管道名称
    _import_structure["flux"] = ["FluxPipeline"]
    # 为“audioldm”添加相关管道名称
    _import_structure["audioldm"] = ["AudioLDMPipeline"]
    # 为“audioldm2”添加相关管道名称
    _import_structure["audioldm2"] = [
        "AudioLDM2Pipeline",
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
    ]
    # 为“blip_diffusion”添加相关管道名称
    _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
    # 为“cogvideo”添加相关管道名称
    _import_structure["cogvideo"] = [
        "CogVideoXPipeline",
        "CogVideoXImageToVideoPipeline",
        "CogVideoXVideoToVideoPipeline",
    ]
    # 为“controlnet”扩展相关管道名称
    _import_structure["controlnet"].extend(
        [
            "BlipDiffusionControlNetPipeline",
            "StableDiffusionControlNetImg2ImgPipeline",
            "StableDiffusionControlNetInpaintPipeline",
            "StableDiffusionControlNetPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
            "StableDiffusionXLControlNetPipeline",
        ]
    )
    # 为“pag”扩展相关管道名称
    _import_structure["pag"].extend(
        [
            "AnimateDiffPAGPipeline",
            "KolorsPAGPipeline",
            "HunyuanDiTPAGPipeline",
            "StableDiffusion3PAGPipeline",
            "StableDiffusionPAGPipeline",
            "StableDiffusionControlNetPAGPipeline",
            "StableDiffusionXLPAGPipeline",
            "StableDiffusionXLPAGInpaintPipeline",
            "StableDiffusionXLControlNetPAGPipeline",
            "StableDiffusionXLPAGImg2ImgPipeline",
            "PixArtSigmaPAGPipeline",
        ]
    )
    # 为“controlnet_xs”扩展相关管道名称
    _import_structure["controlnet_xs"].extend(
        [
            "StableDiffusionControlNetXSPipeline",
            "StableDiffusionXLControlNetXSPipeline",
        ]
    )
    # 为“controlnet_hunyuandit”扩展相关管道名称
    _import_structure["controlnet_hunyuandit"].extend(
        [
            "HunyuanDiTControlNetPipeline",
        ]
    )
    # 将 "StableDiffusion3ControlNetPipeline" 添加到 "controlnet_sd3" 的导入结构中
    _import_structure["controlnet_sd3"].extend(
        [
            "StableDiffusion3ControlNetPipeline",
        ]
    )
    # 定义 "deepfloyd_if" 的导入结构,包含多个管道
    _import_structure["deepfloyd_if"] = [
        "IFImg2ImgPipeline",
        "IFImg2ImgSuperResolutionPipeline",
        "IFInpaintingPipeline",
        "IFInpaintingSuperResolutionPipeline",
        "IFPipeline",
        "IFSuperResolutionPipeline",
    ]
    # 设置 "hunyuandit" 的导入结构,仅包含一个管道
    _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
    # 定义 "kandinsky" 的导入结构,包含多个相关管道
    _import_structure["kandinsky"] = [
        "KandinskyCombinedPipeline",
        "KandinskyImg2ImgCombinedPipeline",
        "KandinskyImg2ImgPipeline",
        "KandinskyInpaintCombinedPipeline",
        "KandinskyInpaintPipeline",
        "KandinskyPipeline",
        "KandinskyPriorPipeline",
    ]
    # 定义 "kandinsky2_2" 的导入结构,包含多个管道
    _import_structure["kandinsky2_2"] = [
        "KandinskyV22CombinedPipeline",
        "KandinskyV22ControlnetImg2ImgPipeline",
        "KandinskyV22ControlnetPipeline",
        "KandinskyV22Img2ImgCombinedPipeline",
        "KandinskyV22Img2ImgPipeline",
        "KandinskyV22InpaintCombinedPipeline",
        "KandinskyV22InpaintPipeline",
        "KandinskyV22Pipeline",
        "KandinskyV22PriorEmb2EmbPipeline",
        "KandinskyV22PriorPipeline",
    ]
    # 定义 "kandinsky3" 的导入结构,包含两个管道
    _import_structure["kandinsky3"] = [
        "Kandinsky3Img2ImgPipeline",
        "Kandinsky3Pipeline",
    ]
    # 定义 "latent_consistency_models" 的导入结构,包含两个管道
    _import_structure["latent_consistency_models"] = [
        "LatentConsistencyModelImg2ImgPipeline",
        "LatentConsistencyModelPipeline",
    ]
    # 将 "LDMTextToImagePipeline" 添加到 "latent_diffusion" 的导入结构中
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    # 将稳定扩散相关的管道添加到 "ledits_pp" 的导入结构中
    _import_structure["ledits_pp"].extend(
        [
            "LEditsPPPipelineStableDiffusion",
            "LEditsPPPipelineStableDiffusionXL",
        ]
    )
    # 设置 "latte" 的导入结构,仅包含一个管道
    _import_structure["latte"] = ["LattePipeline"]
    # 设置 "lumina" 的导入结构,仅包含一个管道
    _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
    # 将 "MarigoldDepthPipeline" 和 "MarigoldNormalsPipeline" 添加到 "marigold" 的导入结构中
    _import_structure["marigold"].extend(
        [
            "MarigoldDepthPipeline",
            "MarigoldNormalsPipeline",
        ]
    )
    # 设置 "musicldm" 的导入结构,仅包含一个管道
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    # 设置 "paint_by_example" 的导入结构,仅包含一个管道
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    # 设置 "pia" 的导入结构,仅包含一个管道
    _import_structure["pia"] = ["PIAPipeline"]
    # 设置 "pixart_alpha" 的导入结构,包含两个管道
    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
    # 设置 "semantic_stable_diffusion" 的导入结构,仅包含一个管道
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    # 设置 "shap_e" 的导入结构,包含两个管道
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    # 定义 "stable_audio" 的导入结构,包含两个管道
    _import_structure["stable_audio"] = [
        "StableAudioProjectionModel",
        "StableAudioPipeline",
    ]
    # 定义 "stable_cascade" 的导入结构,包含多个管道
    _import_structure["stable_cascade"] = [
        "StableCascadeCombinedPipeline",
        "StableCascadeDecoderPipeline",
        "StableCascadePriorPipeline",
    ]
    # 向 stable_diffusion 的导入结构中添加多个相关的管道名称
    _import_structure["stable_diffusion"].extend(
        [
            # 添加 CLIP 图像投影管道
            "CLIPImageProjection",
            # 添加稳定扩散深度到图像管道
            "StableDiffusionDepth2ImgPipeline",
            # 添加稳定扩散图像变体管道
            "StableDiffusionImageVariationPipeline",
            # 添加稳定扩散图像到图像管道
            "StableDiffusionImg2ImgPipeline",
            # 添加稳定扩散图像修复管道
            "StableDiffusionInpaintPipeline",
            # 添加稳定扩散指令图像到图像管道
            "StableDiffusionInstructPix2PixPipeline",
            # 添加稳定扩散潜在上采样管道
            "StableDiffusionLatentUpscalePipeline",
            # 添加稳定扩散主管道
            "StableDiffusionPipeline",
            # 添加稳定扩散上采样管道
            "StableDiffusionUpscalePipeline",
            # 添加稳定 UnCLIP 图像到图像管道
            "StableUnCLIPImg2ImgPipeline",
            # 添加稳定 UnCLIP 管道
            "StableUnCLIPPipeline",
            # 添加稳定扩散 LDM 3D 管道
            "StableDiffusionLDM3DPipeline",
        ]
    )
    # 为 aura_flow 设置导入结构,包括其管道
    _import_structure["aura_flow"] = ["AuraFlowPipeline"]
    # 为 stable_diffusion_3 设置相关管道
    _import_structure["stable_diffusion_3"] = [
        # 添加稳定扩散 3 管道
        "StableDiffusion3Pipeline",
        # 添加稳定扩散 3 图像到图像管道
        "StableDiffusion3Img2ImgPipeline",
        # 添加稳定扩散 3 图像修复管道
        "StableDiffusion3InpaintPipeline",
    ]
    # 为 stable_diffusion_attend_and_excite 设置导入结构
    _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
    # 为 stable_diffusion_safe 设置安全管道
    _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
    # 为 stable_diffusion_sag 设置导入结构
    _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
    # 为 stable_diffusion_gligen 设置导入结构
    _import_structure["stable_diffusion_gligen"] = [
        # 添加稳定扩散 GLIGEN 管道
        "StableDiffusionGLIGENPipeline",
        # 添加稳定扩散 GLIGEN 文本图像管道
        "StableDiffusionGLIGENTextImagePipeline",
    ]
    # 为 stable_video_diffusion 设置导入结构
    _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
    # 向 stable_diffusion_xl 的导入结构中添加多个管道
    _import_structure["stable_diffusion_xl"].extend(
        [
            # 添加稳定扩散 XL 图像到图像管道
            "StableDiffusionXLImg2ImgPipeline",
            # 添加稳定扩散 XL 图像修复管道
            "StableDiffusionXLInpaintPipeline",
            # 添加稳定扩散 XL 指令图像到图像管道
            "StableDiffusionXLInstructPix2PixPipeline",
            # 添加稳定扩散 XL 主管道
            "StableDiffusionXLPipeline",
        ]
    )
    # 为 stable_diffusion_diffedit 设置导入结构
    _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
    # 为 stable_diffusion_ldm3d 设置导入结构
    _import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
    # 为 stable_diffusion_panorama 设置导入结构
    _import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
    # 为 t2i_adapter 设置导入结构,包括适配管道
    _import_structure["t2i_adapter"] = [
        # 添加稳定扩散适配器管道
        "StableDiffusionAdapterPipeline",
        # 添加稳定扩散 XL 适配器管道
        "StableDiffusionXLAdapterPipeline",
    ]
    # 为 text_to_video_synthesis 设置多个视频合成相关管道
    _import_structure["text_to_video_synthesis"] = [
        # 添加文本到视频稳定扩散管道
        "TextToVideoSDPipeline",
        # 添加文本到视频零管道
        "TextToVideoZeroPipeline",
        # 添加文本到视频零稳定扩散 XL 管道
        "TextToVideoZeroSDXLPipeline",
        # 添加视频到视频稳定扩散管道
        "VideoToVideoSDPipeline",
    ]
    # 为 i2vgen_xl 设置导入结构
    _import_structure["i2vgen_xl"] = ["I2VGenXLPipeline"]
    # 为 unclip 设置相关管道
    _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
    # 为 unidiffuser 设置多个管道
    _import_structure["unidiffuser"] = [
        # 添加图像文本管道输出
        "ImageTextPipelineOutput",
        # 添加 UniDiffuser 模型
        "UniDiffuserModel",
        # 添加 UniDiffuser 管道
        "UniDiffuserPipeline",
        # 添加 UniDiffuser 文本解码器
        "UniDiffuserTextDecoder",
    ]
    # 为 wuerstchen 设置多个管道
    _import_structure["wuerstchen"] = [
        # 添加 Wuerstchen 组合管道
        "WuerstchenCombinedPipeline",
        # 添加 Wuerstchen 解码器管道
        "WuerstchenDecoderPipeline",
        # 添加 Wuerstchen 先验管道
        "WuerstchenPriorPipeline",
    ]
# 尝试检查 ONNX 是否可用
try:
    # 如果 ONNX 不可用,抛出异常
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 ONNX 对象,防止导入错误
    from ..utils import dummy_onnx_objects  # noqa F403
    # 更新虚拟对象字典,添加假 ONNX 对象
    _dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
# 如果 ONNX 可用,更新导入结构
else:
    _import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
# 尝试检查 PyTorch、Transformers 和 ONNX 是否都可用
try:
    # 如果任何一个不可用,抛出异常
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 PyTorch、Transformers 和 ONNX 对象
    from ..utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403
    # 更新虚拟对象字典,添加假对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
# 如果都可用,扩展导入结构
else:
    _import_structure["stable_diffusion"].extend(
        [
            "OnnxStableDiffusionImg2ImgPipeline",
            "OnnxStableDiffusionInpaintPipeline",
            "OnnxStableDiffusionPipeline",
            "OnnxStableDiffusionUpscalePipeline",
            "StableDiffusionOnnxPipeline",
        ]
    )
# 尝试检查 PyTorch、Transformers 和 K-Diffusion 是否都可用
try:
    # 如果任何一个不可用,抛出异常
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 PyTorch、Transformers 和 K-Diffusion 对象
    from ..utils import (
        dummy_torch_and_transformers_and_k_diffusion_objects,
    )
    # 更新虚拟对象字典,添加假对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
# 如果都可用,更新导入结构
else:
    _import_structure["stable_diffusion_k_diffusion"] = [
        "StableDiffusionKDiffusionPipeline",
        "StableDiffusionXLKDiffusionPipeline",
    ]
# 尝试检查 PyTorch、Transformers 和 SentencePiece 是否都可用
try:
    # 如果任何一个不可用,抛出异常
    if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 PyTorch、Transformers 和 SentencePiece 对象
    from ..utils import (
        dummy_torch_and_transformers_and_sentencepiece_objects,
    )
    # 更新虚拟对象字典,添加假对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
# 如果都可用,更新导入结构
else:
    _import_structure["kolors"] = [
        "KolorsPipeline",
        "KolorsImg2ImgPipeline",
    ]
# 尝试检查 Flax 是否可用
try:
    # 如果 Flax 不可用,抛出异常
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 Flax 对象,防止导入错误
    from ..utils import dummy_flax_objects  # noqa F403
    # 更新虚拟对象字典,添加假 Flax 对象
    _dummy_objects.update(get_objects_from_module(dummy_flax_objects))
# 如果 Flax 可用,更新导入结构
else:
    _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
# 尝试检查 Flax 和 Transformers 是否都可用
try:
    # 如果任何一个不可用,抛出异常
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从工具模块导入假 Flax 和 Transformers 对象
    from ..utils import dummy_flax_and_transformers_objects  # noqa F403
    # 更新虚拟对象字典,添加假对象
    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
# 如果都可用,扩展导入结构
else:
    _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
    # 将稳定扩散模型相关的类名添加到导入结构中
        _import_structure["stable_diffusion"].extend(
            [
                # 添加图像到图像转换管道类名
                "FlaxStableDiffusionImg2ImgPipeline",
                # 添加图像修复管道类名
                "FlaxStableDiffusionInpaintPipeline",
                # 添加基础稳定扩散管道类名
                "FlaxStableDiffusionPipeline",
            ]
        )
    # 将稳定扩散 XL 模型相关的类名添加到导入结构中
        _import_structure["stable_diffusion_xl"].extend(
            [
                # 添加稳定扩散 XL 管道类名
                "FlaxStableDiffusionXLPipeline",
            ]
        )
# 检查是否为类型检查或慢导入条件
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 检查是否可用 PyTorch
        if not is_torch_available():
            # 如果不可用,则引发可选依赖项不可用异常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 导入占位符对象以避免运行时错误
        from ..utils.dummy_pt_objects import *  # noqa F403
    else:
        # 导入自动图像到图像管道相关类
        from .auto_pipeline import (
            AutoPipelineForImage2Image,
            AutoPipelineForInpainting,
            AutoPipelineForText2Image,
        )
        # 导入一致性模型管道
        from .consistency_models import ConsistencyModelPipeline
        # 导入舞蹈扩散管道
        from .dance_diffusion import DanceDiffusionPipeline
        # 导入 DDIM 管道
        from .ddim import DDIMPipeline
        # 导入 DDPM 管道
        from .ddpm import DDPMPipeline
        # 导入已弃用的管道
        from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
        # 导入 DIT 管道
        from .dit import DiTPipeline
        # 导入潜在扩散超分辨率管道
        from .latent_diffusion import LDMSuperResolutionPipeline
        # 导入管道工具类
        from .pipeline_utils import (
            AudioPipelineOutput,
            DiffusionPipeline,
            ImagePipelineOutput,
            StableDiffusionMixin,
        )
    try:
        # 检查是否可用 PyTorch 和 librosa
        if not (is_torch_available() and is_librosa_available()):
            # 如果不可用,则引发可选依赖项不可用异常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 导入占位符对象以避免运行时错误
        from ..utils.dummy_torch_and_librosa_objects import *
    else:
        # 导入已弃用的音频扩散管道和 Mel 类
        from .deprecated import AudioDiffusionPipeline, Mel
    try:
        # 检查是否可用 PyTorch 和 transformers
        if not (is_torch_available() and is_transformers_available()):
            # 如果不可用,则引发可选依赖项不可用异常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 导入占位符对象以避免运行时错误
        from ..utils.dummy_torch_and_transformers_objects import *
else:
    # 导入 sys 模块
    import sys
    # 创建懒加载模块实例
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # 设置占位符对象到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
# 版权所有 2024 NVIDIA 和 The HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0("许可证")许可;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在许可证下分发是以“按原样”基础进行的,
# 不提供任何形式的明示或暗示的担保或条件。
# 有关许可证下权限和限制的具体语言,请参见许可证。
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union
# 导入 NumPy 库,通常用于数值计算
import numpy as np
# 导入 PyTorch 库,通常用于深度学习
import torch
# 从配置工具中导入 ConfigMixin 和 register_to_config
from ...configuration_utils import ConfigMixin, register_to_config
# 从 utils 模块导入 BaseOutput 基类
from ...utils import BaseOutput
# 从 utils.torch_utils 导入生成随机张量的函数
from ...utils.torch_utils import randn_tensor
# 从调度工具中导入 SchedulerMixin
from ..scheduling_utils import SchedulerMixin
# 定义 KarrasVeOutput 类,继承自 BaseOutput
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    调度器步骤函数输出的输出类。
    参数:
        prev_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
            先前时间步的计算样本 (x_{t-1})。`prev_sample` 应作为下一个模型输入使用
            在去噪循环中。
        derivative (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
            预测的原始图像样本的导数 (x_0)。
        pred_original_sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`,用于图像):
            基于当前时间步模型输出的预测去噪样本 (x_{0})。
            `pred_original_sample` 可用于预览进度或进行引导。
    """
    # 先前样本,类型为 torch.Tensor
    prev_sample: torch.Tensor
    # 导数,类型为 torch.Tensor
    derivative: torch.Tensor
    # 可选的预测原始样本,类型为 torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
# 定义 KarrasVeScheduler 类,继承自 SchedulerMixin 和 ConfigMixin
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    针对方差扩展模型的随机调度器。
    该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。有关库为所有调度器实现的通用
    方法的详细信息,请查看超类文档,例如加载和保存。
    <Tip>
    有关参数的更多详细信息,请参见 [附录 E](https://arxiv.org/abs/2206.00364)。用于查找特定模型的
    最优 `{s_noise, s_churn, s_min, s_max}` 的网格搜索值在论文的表 5 中进行了描述。
    </Tip>
    # 参数说明部分,描述每个参数的含义和默认值
        Args:
            sigma_min (`float`, defaults to 0.02):
                # 最小噪声幅度
                The minimum noise magnitude.
            sigma_max (`float`, defaults to 100):
                # 最大噪声幅度
                The maximum noise magnitude.
            s_noise (`float`, defaults to 1.007):
                # 额外噪声量,抵消采样时的细节损失,合理范围为 [1.000, 1.011]
                The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
                1.011].
            s_churn (`float`, defaults to 80):
                # 控制整体随机性程度的参数,合理范围为 [0, 100]
                The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
            s_min (`float`, defaults to 0.05):
                # 添加噪声的起始 sigma 范围值,合理范围为 [0, 10]
                The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
            s_max (`float`, defaults to 50):
                # 添加噪声的结束 sigma 范围值,合理范围为 [0.2, 80]
                The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
        """
    
        # 定义阶数为 2
        order = 2
    
        # 初始化方法,注册到配置
        @register_to_config
        def __init__(
            self,
            # 最小噪声幅度,默认值为 0.02
            sigma_min: float = 0.02,
            # 最大噪声幅度,默认值为 100
            sigma_max: float = 100,
            # 额外噪声量,默认值为 1.007
            s_noise: float = 1.007,
            # 随机性控制参数,默认值为 80
            s_churn: float = 80,
            # sigma 范围起始值,默认值为 0.05
            s_min: float = 0.05,
            # sigma 范围结束值,默认值为 50
            s_max: float = 50,
        ):
            # 设置初始噪声分布的标准差
            self.init_noise_sigma = sigma_max
    
            # 可设置值
            # 推理步骤的数量,初始为 None
            self.num_inference_steps: int = None
            # 时间步的张量,初始为 None
            self.timesteps: np.IntTensor = None
            # sigma(t_i) 的张量,初始为 None
            self.schedule: torch.Tensor = None  # sigma(t_i)
    
        # 处理模型输入以确保与调度器的互换性
        def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
            """
            确保与需要根据当前时间步缩放去噪模型输入的调度器的互换性。
    
            Args:
                sample (`torch.Tensor`):
                    # 输入样本
                    The input sample.
                timestep (`int`, *optional*):
                    # 当前扩散链中的时间步
                    The current timestep in the diffusion chain.
    
            Returns:
                `torch.Tensor`:
                    # 返回缩放后的输入样本
                    A scaled input sample.
            """
            # 返回未改变的样本
            return sample
    
        # 设置扩散链使用的离散时间步
        def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
            """
            设置用于扩散链的离散时间步(在推理之前运行)。
    
            Args:
                num_inference_steps (`int`):
                    # 生成样本时使用的扩散步骤数量
                    The number of diffusion steps used when generating samples with a pre-trained model.
                device (`str` or `torch.device`, *optional*):
                    # 将时间步移动到的设备,如果为 None,则不移动
                    The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            """
            # 设置推理步骤数量
            self.num_inference_steps = num_inference_steps
            # 创建时间步数组并反转
            timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
            # 将时间步转换为张量并移动到指定设备
            self.timesteps = torch.from_numpy(timesteps).to(device)
            # 计算调度的 sigma 值
            schedule = [
                (
                    self.config.sigma_max**2
                    * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
                )
                for i in self.timesteps
            ]
            # 将调度值转换为张量
            self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
    # 定义添加噪声到输入样本的函数
    def add_noise_to_input(
            self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
        ) -> Tuple[torch.Tensor, float]:
            """
            显式的 Langevin 类似的“搅动”步骤,根据 `gamma_i ≥ 0` 添加噪声,以达到更高的噪声水平 `sigma_hat = sigma_i + gamma_i*sigma_i`。
    
            参数:
                sample (`torch.Tensor`):
                    输入样本。
                sigma (`float`):
                generator (`torch.Generator`, *可选*):
                    随机数生成器。
            """
            # 检查 sigma 是否在配置的最小值和最大值之间
            if self.config.s_min <= sigma <= self.config.s_max:
                # 计算 gamma,确保不会超过最大值
                gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
            else:
                # 如果不在范围内,gamma 为 0
                gamma = 0
    
            # 从标准正态分布中采样噪声 eps
            eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
            # 计算新的噪声水平
            sigma_hat = sigma + gamma * sigma
            # 更新样本,添加噪声
            sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
    
            # 返回更新后的样本和新的噪声水平
            return sample_hat, sigma_hat
    
    # 定义从上一个时间步预测样本的步骤函数
    def step(
            self,
            model_output: torch.Tensor,
            sigma_hat: float,
            sigma_prev: float,
            sample_hat: torch.Tensor,
            return_dict: bool = True,
        ) -> Union[KarrasVeOutput, Tuple]:
            """
            通过反转 SDE 从学习的模型输出中传播扩散过程(通常是预测的噪声)。
    
            参数:
                model_output (`torch.Tensor`):
                    学习扩散模型的直接输出。
                sigma_hat (`float`):
                sigma_prev (`float`):
                sample_hat (`torch.Tensor`):
                return_dict (`bool`, *可选*, 默认为 `True`):
                    是否返回一个 [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] 或 `tuple`。
    
            返回:
                [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] 或 `tuple`:
                    如果 return_dict 为 `True`,返回 [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`],
                    否则返回一个元组,第一个元素是样本张量。
            """
    
            # 根据模型输出和 sigma_hat 计算预测的原始样本
            pred_original_sample = sample_hat + sigma_hat * model_output
            # 计算样本的导数
            derivative = (sample_hat - pred_original_sample) / sigma_hat
            # 计算上一个时间步的样本
            sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
    
            # 如果不返回字典,返回样本和导数
            if not return_dict:
                return (sample_prev, derivative)
    
            # 返回包含样本、导数和预测原始样本的 KarrasVeOutput 对象
            return KarrasVeOutput(
                prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
            )
    
    # 定义带有修正步骤的函数
    def step_correct(
            self,
            model_output: torch.Tensor,
            sigma_hat: float,
            sigma_prev: float,
            sample_hat: torch.Tensor,
            sample_prev: torch.Tensor,
            derivative: torch.Tensor,
            return_dict: bool = True,
    # 处理网络的模型输出,纠正预测样本
    ) -> Union[KarrasVeOutput, Tuple]:
            """
            # 根据网络的模型输出修正预测样本
    
            Args:
                model_output (`torch.Tensor`):
                    # 从学习的扩散模型直接输出的张量
                sigma_hat (`float`): TODO
                sigma_prev (`float`): TODO
                sample_hat (`torch.Tensor`): TODO
                sample_prev (`torch.Tensor`): TODO
                derivative (`torch.Tensor`): TODO
                return_dict (`bool`, *optional*, defaults to `True`):
                    # 是否返回 [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] 或 `tuple`
    
            Returns:
                prev_sample (TODO): # 在扩散链中的更新样本。 derivative (TODO): TODO
    
            """
            # 通过前一个样本和模型输出计算预测的原始样本
            pred_original_sample = sample_prev + sigma_prev * model_output
            # 计算修正后的导数
            derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
            # 更新前一个样本,根据当前和预测的导数
            sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
    
            # 如果不返回字典,则返回更新的样本和导数
            if not return_dict:
                return (sample_prev, derivative)
    
            # 返回 KarrasVeOutput 对象,包含更新的样本和导数
            return KarrasVeOutput(
                prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
            )
    
        # 声明未实现的方法,用于添加噪声
        def add_noise(self, original_samples, noise, timesteps):
            # 引发未实现错误
            raise NotImplementedError()