diffusers-源码解析-四十一-

diffusers 源码解析(四十一)

.\diffusers\pipelines\pag\__init__.py

# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING

# 从相对路径的 utils 模块中导入多个工具函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 用于慢速导入的常量
    OptionalDependencyNotAvailable,  # 可选依赖不可用异常
    _LazyModule,  # 用于懒加载模块的类
    get_objects_from_module,  # 从模块中获取对象的函数
    is_flax_available,  # 检查 flax 库是否可用的函数
    is_torch_available,  # 检查 torch 库是否可用的函数
    is_transformers_available,  # 检查 transformers 库是否可用的函数
)

# 定义一个空字典用于存放虚拟对象
_dummy_objects = {}
# 定义一个空字典用于存放导入结构
_import_structure = {}

try:
    # 检查 transformers 和 torch 库是否可用
    if not (is_transformers_available() and is_torch_available()):
        # 如果不可用,抛出可选依赖不可用异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果捕获到可选依赖不可用异常,从 utils 中导入虚拟对象
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新 _dummy_objects 字典,获取虚拟对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果库可用,更新 _import_structure 字典,添加各个管道的导入信息
    _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
    _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
    _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
    _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
    _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
    _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
    _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
    _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
    _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
    _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
    _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]

# 如果在类型检查或慢速导入模式下
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 检查 transformers 和 torch 库是否可用
        if not (is_transformers_available() and is_torch_available()):
            # 如果不可用,抛出可选依赖不可用异常
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        # 捕获异常时,从 utils 中导入虚拟对象
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 如果库可用,导入具体的管道实现
        from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
        from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
        from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
        from .pipeline_pag_kolors import KolorsPAGPipeline
        from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
        from .pipeline_pag_sd import StableDiffusionPAGPipeline
        from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
        from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
        from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
        from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
        from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline

else:
    # 如果不在类型检查或慢速导入模式下,导入 sys 模块
    import sys

    # 使用 _LazyModule 创建一个懒加载的模块
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],  # 当前模块的文件路径
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规范
    )
    # 将 _dummy_objects 中的虚拟对象添加到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\paint_by_example\image_encoder.py

# 版权声明,说明代码的版权所有者及使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 该文件根据 Apache License, Version 2.0("许可证")授权; 
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件在“按现状”基础上分发,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参见许可证以了解有关权限和限制的具体语言。
import torch  # 导入 PyTorch 库
from torch import nn  # 从 PyTorch 导入神经网络模块
from transformers import CLIPPreTrainedModel, CLIPVisionModel  # 导入 CLIP 相关模型

from ...models.attention import BasicTransformerBlock  # 导入基本变换器块
from ...utils import logging  # 导入日志工具


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁用 Pylint 对无效名称的警告


class PaintByExampleImageEncoder(CLIPPreTrainedModel):  # 定义图像编码器类,继承自预训练的 CLIP 模型
    def __init__(self, config, proj_size=None):  # 初始化方法,接收配置和可选的投影大小
        super().__init__(config)  # 调用父类初始化方法
        self.proj_size = proj_size or getattr(config, "projection_dim", 768)  # 设置投影大小,默认为 768

        self.model = CLIPVisionModel(config)  # 创建 CLIP 视觉模型实例
        self.mapper = PaintByExampleMapper(config)  # 创建映射器实例
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)  # 创建层归一化实例
        self.proj_out = nn.Linear(config.hidden_size, self.proj_size)  # 创建线性变换,用于输出投影

        # 用于缩放的无条件向量
        self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))  # 初始化无条件向量为随机值

    def forward(self, pixel_values, return_uncond_vector=False):  # 前向传播方法,接收像素值和是否返回无条件向量的标志
        clip_output = self.model(pixel_values=pixel_values)  # 通过模型处理像素值
        latent_states = clip_output.pooler_output  # 获取池化输出作为潜在状态
        latent_states = self.mapper(latent_states[:, None])  # 使用映射器处理潜在状态
        latent_states = self.final_layer_norm(latent_states)  # 对潜在状态进行层归一化
        latent_states = self.proj_out(latent_states)  # 通过线性层进行投影
        if return_uncond_vector:  # 如果需要返回无条件向量
            return latent_states, self.uncond_vector  # 返回潜在状态和无条件向量

        return latent_states  # 否则仅返回潜在状态


class PaintByExampleMapper(nn.Module):  # 定义映射器类,继承自 PyTorch 的 nn.Module
    def __init__(self, config):  # 初始化方法,接收配置
        super().__init__()  # 调用父类初始化方法
        num_layers = (config.num_hidden_layers + 1) // 5  # 计算层数,确保至少为 1
        hid_size = config.hidden_size  # 获取隐藏层大小
        num_heads = 1  # 设置注意力头的数量为 1
        self.blocks = nn.ModuleList(  # 创建模块列表,包含多个变换器块
            [
                BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)  # 添加变换器块
                for _ in range(num_layers)  # 根据层数创建多个块
            ]
        )

    def forward(self, hidden_states):  # 前向传播方法,接收隐藏状态
        for block in self.blocks:  # 遍历所有变换器块
            hidden_states = block(hidden_states)  # 依次通过每个块处理隐藏状态

        return hidden_states  # 返回最终的隐藏状态

.\diffusers\pipelines\paint_by_example\pipeline_paint_by_example.py

# 版权声明,表明该文件的所有权和使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证,使用该文件的条件
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 可以在此获取许可证的副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则软件按“原样”分发,不提供任何形式的保证
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect  # 导入inspect模块,用于获取对象的信息
from typing import Callable, List, Optional, Union  # 导入类型提示功能

import numpy as np  # 导入numpy库,用于数值计算
import PIL.Image  # 导入PIL库,用于图像处理
import torch  # 导入PyTorch库,用于深度学习
from transformers import CLIPImageProcessor  # 导入CLIP图像处理器

from ...image_processor import VaeImageProcessor  # 从上级模块导入VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel  # 导入模型相关类
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler  # 导入调度器
from ...utils import deprecate, logging  # 导入工具函数和日志记录
from ...utils.torch_utils import randn_tensor  # 从工具模块导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道和混合类
from ..stable_diffusion import StableDiffusionPipelineOutput  # 导入稳定扩散管道的输出类
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker  # 导入安全检查器
from .image_encoder import PaintByExampleImageEncoder  # 从当前模块导入图像编码器

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,便于调试
# pylint: disable=invalid-name  # 禁用pylint关于名称的无效警告

# 从diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img复制的函数
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # 检查encoder_output是否具有latent_dist属性且采样模式为'sample'
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        # 返回latent分布的样本
        return encoder_output.latent_dist.sample(generator)
    # 检查encoder_output是否具有latent_dist属性且采样模式为'argmax'
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        # 返回latent分布的众数
        return encoder_output.latent_dist.mode()
    # 检查encoder_output是否具有latents属性
    elif hasattr(encoder_output, "latents"):
        # 返回latents属性
        return encoder_output.latents
    # 如果都不满足,抛出异常
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

# 准备图像和掩码以供“按示例绘制”管道使用
def prepare_mask_and_masked_image(image, mask):
    """
    准备一对 (image, mask),使其可以被 Paint by Example 管道使用。
    这意味着这些输入将转换为``torch.Tensor``,形状为``batch x channels x height x width``,
    其中``channels``为``3``(对于``image``)和``1``(对于``mask``)。

    ``image`` 将转换为 ``torch.float32`` 并归一化为 ``[-1, 1]``。
    ``mask`` 将被二值化(``mask > 0.5``)并同样转换为 ``torch.float32``。
    ```
    # 函数参数说明
    Args:
        # 输入图像,类型可以是 np.array、PIL.Image 或 torch.Tensor
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            # 描述图像的不同可能格式,包括 PIL.Image、np.array 或 torch.Tensor
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        # 掩码,用于指定需要修复的区域
        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
            # 描述掩码的不同可能格式,类似于图像
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.

    # 异常说明
    Raises:
        # 触发条件为 torch.Tensor 格式图像或掩码的数值范围不正确
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
        # 类型错误,当图像和掩码类型不匹配时抛出
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
            (ot the other way around).

    # 返回值说明
    Returns:
        # 返回一个包含掩码和修复图像的元组,均为 torch.Tensor 格式,具有 4 个维度
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """
    # 检查输入图像是否为 torch.Tensor 类型
    if isinstance(image, torch.Tensor):
        # 如果掩码不是 torch.Tensor,抛出类型错误
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # 如果图像为单个图像,将其转换为批处理格式
        # Batch single image
        if image.ndim == 3:
            # 确保单个图像的形状为 (3, H, W)
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            # 在第一个维度添加批处理维度
            image = image.unsqueeze(0)

        # 如果掩码为二维,添加批处理和通道维度
        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            # 在前面添加两个维度
            mask = mask.unsqueeze(0).unsqueeze(0)

        # 如果掩码为三维,检查其与图像的批次匹配
        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Batched mask
            if mask.shape[0] == image.shape[0]:
                # 如果掩码的批次与图像相同,添加通道维度
                mask = mask.unsqueeze(1)
            else:
                # 否则,在前面添加批处理维度
                mask = mask.unsqueeze(0)

        # 确保图像和掩码都是四维
        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        # 确保图像和掩码的空间维度相同
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        # 确保图像和掩码的批处理大小相同
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
        # 确保掩码只有一个通道
        assert mask.shape[1] == 1, "Mask image must have a single channel"

        # 检查图像的数值范围是否在 [-1, 1] 之间
        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # 检查掩码的数值范围是否在 [0, 1] 之间
        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # 对掩码进行反转,以便于修复
        # paint-by-example inverses the mask
        mask = 1 - mask

        # 二值化掩码
        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # 将图像转换为 float32 类型
        # Image as float32
        image = image.to(dtype=torch.float32)
    # 如果掩码是 torch.Tensor 类型,但图像不是,则抛出类型错误
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # 如果输入的 image 是 PIL 图像对象,则将其转换为列表
        if isinstance(image, PIL.Image.Image):
            image = [image]

        # 将每个图像转换为 RGB 格式,并拼接成一个数组,增加维度以适应后续处理
        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        # 将图像数组的维度顺序调整为 (批量, 通道, 高, 宽)
        image = image.transpose(0, 3, 1, 2)
        # 将 NumPy 数组转换为 PyTorch 张量并归一化到 [-1, 1] 范围
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # 处理 mask
        # 如果输入的 mask 是 PIL 图像对象,则将其转换为列表
        if isinstance(mask, PIL.Image.Image):
            mask = [mask]

        # 将每个掩膜图像转换为灰度格式,并拼接成一个数组,增加维度以适应后续处理
        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
        # 将掩膜数组转换为 float32 类型并归一化到 [0, 1] 范围
        mask = mask.astype(np.float32) / 255.0

        # paint-by-example 方法反转掩膜
        mask = 1 - mask

        # 将掩膜中低于 0.5 的值设置为 0,高于或等于 0.5 的值设置为 1
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        # 将 NumPy 数组转换为 PyTorch 张量
        mask = torch.from_numpy(mask)

    # 将图像与掩膜相乘,得到被掩膜处理的图像
    masked_image = image * mask

    # 返回掩膜和被掩膜处理的图像
    return mask, masked_image
# 定义一个名为 PaintByExamplePipeline 的类,继承自 DiffusionPipeline 和 StableDiffusionMixin
class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
    r""" 
    # 警告提示,表示这是一个实验性特性
    <Tip warning={true}>
    🧪 This is an experimental feature!
    </Tip>

    # 使用 Stable Diffusion 进行图像引导的图像修补的管道。

    # 该模型从 [`DiffusionPipeline`] 继承。检查超类文档以获取所有管道的通用方法
    # (下载、保存、在特定设备上运行等)。

    # 参数说明:
        vae ([`AutoencoderKL`]):
            用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
        image_encoder ([`PaintByExampleImageEncoder`]):
            编码示例输入图像。`unet` 是基于示例图像而非文本提示进行条件处理。
        tokenizer ([`~transformers.CLIPTokenizer`]):
            用于文本分词的 `CLIPTokenizer`。
        unet ([`UNet2DConditionModel`]):
            用于去噪编码图像潜在的 `UNet2DConditionModel`。
        scheduler ([`SchedulerMixin`]):
            与 `unet` 结合使用以去噪编码图像潜在的调度器,可以是
            [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]。
        safety_checker ([`StableDiffusionSafetyChecker`]):
            估计生成图像是否可能被视为冒犯或有害的分类模块。
            请参考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以获取有关模型潜在危害的更多细节。
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            用于从生成图像中提取特征的 `CLIPImageProcessor`;用作 `safety_checker` 的输入。
    """

    # TODO: 如果管道没有 feature_extractor,则需要在初始图像(如果为 PIL 格式)编码时给出描述性消息。

    # 定义模型在 CPU 上卸载的顺序,指定 'unet' 在前,'vae' 在后
    model_cpu_offload_seq = "unet->vae"
    # 定义在 CPU 卸载时排除的组件,指定 'image_encoder' 不参与卸载
    _exclude_from_cpu_offload = ["image_encoder"]
    # 定义可选组件,指定 'safety_checker' 为可选
    _optional_components = ["safety_checker"]

    # 初始化方法,设置管道的主要组件
    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: PaintByExampleImageEncoder,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = False,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 注册各个模块,设置管道的组成部分
        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # 计算 VAE 的缩放因子,基于 VAE 配置中的块输出通道数
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 初始化 VaeImageProcessor,使用计算出的缩放因子
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # 将是否需要安全检查器的信息注册到配置中
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制而来
        def run_safety_checker(self, image, device, dtype):
            # 如果安全检查器不存在,将有害概念标记为 None
            if self.safety_checker is None:
                has_nsfw_concept = None
            else:
                # 如果输入图像是张量,使用图像处理器后处理为 PIL 格式
                if torch.is_tensor(image):
                    feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
                else:
                    # 如果输入图像不是张量,将其转换为 PIL 格式
                    feature_extractor_input = self.image_processor.numpy_to_pil(image)
                # 使用特征提取器处理图像并转移到指定设备
                safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
                # 运行安全检查器,返回处理后的图像和有害概念的存在情况
                image, has_nsfw_concept = self.safety_checker(
                    images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
                )
            # 返回处理后的图像和有害概念
            return image, has_nsfw_concept
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制而来
        def prepare_extra_step_kwargs(self, generator, eta):
            # 准备调度器步骤的额外参数,因为并非所有调度器都有相同的签名
            # eta(η)仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
            # eta 对应于 DDIM 论文中的 η:https://arxiv.org/abs/2010.02502
            # 应在 [0, 1] 之间
    
            # 检查调度器是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            extra_step_kwargs = {}
            if accepts_eta:
                # 如果接受 eta,将其添加到额外参数中
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            if accepts_generator:
                # 如果接受 generator,将其添加到额外参数中
                extra_step_kwargs["generator"] = generator
            # 返回准备好的额外参数
            return extra_step_kwargs
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 复制而来
        def decode_latents(self, latents):
            # 警告信息,提示 decode_latents 方法已弃用,将在 1.0.0 中移除
            deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
            # 记录弃用警告
            deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
    
            # 根据 VAE 的缩放因子调整潜在向量
            latents = 1 / self.vae.config.scaling_factor * latents
            # 解码潜在向量,返回图像
            image = self.vae.decode(latents, return_dict=False)[0]
            # 将图像缩放到 [0, 1] 范围内
            image = (image / 2 + 0.5).clamp(0, 1)
            # 将图像转换为 float32 格式以兼容 bfloat16
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
            # 返回处理后的图像
            return image
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs 复制而来
    # 检查输入参数的有效性
    def check_inputs(self, image, height, width, callback_steps):
        # 检查 `image` 是否为有效类型,必须是 `torch.Tensor`、`PIL.Image.Image` 或列表
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            # 如果 `image` 类型不符合,抛出错误
            raise ValueError(
                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        # 检查 `height` 和 `width` 是否能被 8 整除
        if height % 8 != 0 or width % 8 != 0:
            # 如果不能整除,抛出错误
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # 检查 `callback_steps` 是否有效,必须是正整数或 None
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            # 如果无效,抛出错误
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # 从 StableDiffusionPipeline 复制的准备潜在变量的方法
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 定义潜在变量的形状
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 检查生成器列表的长度是否与批量大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 如果不匹配,抛出错误
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 如果没有提供潜在变量,则生成随机潜在变量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果提供了潜在变量,将其转移到指定设备
            latents = latents.to(device)

        # 将初始噪声缩放到调度器所需的标准差
        latents = latents * self.scheduler.init_noise_sigma
        # 返回准备好的潜在变量
        return latents

    # 从 StableDiffusionInpaintPipeline 复制的准备掩膜潜在变量的方法
    def prepare_mask_latents(
        # 定义方法的输入参数,包括掩膜和其他信息
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        # 将掩码调整为与潜变量形状相同,以便将掩码与潜变量拼接
        # 这样做可以避免在使用 cpu_offload 和半精度时出现问题
        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )  # 通过插值调整掩码的大小
        mask = mask.to(device=device, dtype=dtype)  # 将掩码移动到指定设备并转换数据类型

        masked_image = masked_image.to(device=device, dtype=dtype)  # 将掩码图像移动到指定设备并转换数据类型

        if masked_image.shape[1] == 4:  # 检查掩码图像是否为四通道
            masked_image_latents = masked_image  # 如果是,直接将其赋值给潜变量
        else:
            masked_image_latents = self._encode_vae_image(masked_image, generator=generator)  # 否则,使用 VAE 编码图像

        # 针对每个提示重复掩码和掩码图像潜变量,使用适合 MPS 的方法
        if mask.shape[0] < batch_size:  # 检查掩码的数量是否少于批处理大小
            if not batch_size % mask.shape[0] == 0:  # 检查掩码数量是否可整除批处理大小
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )  # 如果不匹配,抛出值错误
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)  # 重复掩码以匹配批处理大小
        if masked_image_latents.shape[0] < batch_size:  # 检查潜变量数量是否少于批处理大小
            if not batch_size % masked_image_latents.shape[0] == 0:  # 检查潜变量数量是否可整除批处理大小
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )  # 如果不匹配,抛出值错误
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)  # 重复潜变量以匹配批处理大小

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask  # 根据是否使用无分类器引导选择掩码
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )  # 根据是否使用无分类器引导选择潜变量

        # 调整设备以防止与潜变量模型输入拼接时出现设备错误
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)  # 将潜变量移动到指定设备并转换数据类型
        return mask, masked_image_latents  # 返回处理后的掩码和潜变量

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image 复制
    # 定义一个编码变分自编码器图像的私有方法
        def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
            # 检查生成器是否为列表类型
            if isinstance(generator, list):
                # 对每个图像编码并获取潜在表示,使用对应的生成器
                image_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(image.shape[0])
                ]
                # 将所有潜在表示在第0维上拼接成一个张量
                image_latents = torch.cat(image_latents, dim=0)
            else:
                # 如果生成器不是列表,直接编码图像并获取潜在表示
                image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
    
            # 将潜在表示乘以缩放因子
            image_latents = self.vae.config.scaling_factor * image_latents
    
            # 返回编码后的潜在表示
            return image_latents
    
        # 定义一个编码图像的私有方法
        def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
            # 获取图像编码器参数的数据类型
            dtype = next(self.image_encoder.parameters()).dtype
    
            # 检查输入图像是否为张量,如果不是则提取特征
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
    
            # 将图像移动到指定设备,并转换为正确的数据类型
            image = image.to(device=device, dtype=dtype)
            # 对图像进行编码,获取图像嵌入和负提示嵌入
            image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)
    
            # 复制图像嵌入以适应每个提示的生成数量
            bs_embed, seq_len, _ = image_embeddings.shape
            image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
            # 重塑嵌入张量的形状
            image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
    
            # 检查是否使用无分类器引导
            if do_classifier_free_guidance:
                # 复制负提示嵌入以匹配图像嵌入的数量
                negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
                negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)
    
                # 为无分类器引导执行两个前向传播,通过拼接无条件和文本嵌入来避免两个前向传播
                image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
    
            # 返回编码后的图像嵌入
            return image_embeddings
    
        # 定义一个调用方法,禁用梯度计算以提高效率
        @torch.no_grad()
        def __call__(
            # 接收示例图像和图像的参数,允许不同类型的输入
            example_image: Union[torch.Tensor, PIL.Image.Image],
            image: Union[torch.Tensor, PIL.Image.Image],
            mask_image: Union[torch.Tensor, PIL.Image.Image],
            # 可选参数定义图像的高度和宽度
            height: Optional[int] = None,
            width: Optional[int] = None,
            # 定义推理步骤的数量和引导缩放比例
            num_inference_steps: int = 50,
            guidance_scale: float = 5.0,
            # 负提示的可选输入
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量
            num_images_per_prompt: Optional[int] = 1,
            # 控制采样多样性的参数
            eta: float = 0.0,
            # 生成器的可选输入
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 可选的潜在变量输入
            latents: Optional[torch.Tensor] = None,
            # 输出类型的可选参数
            output_type: Optional[str] = "pil",
            # 是否返回字典格式的结果
            return_dict: bool = True,
            # 可选的回调函数用于处理中间结果
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 每隔多少步调用一次回调
            callback_steps: int = 1,

.\diffusers\pipelines\paint_by_example\__init__.py

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化数据类的定义
from dataclasses import dataclass
# 从 typing 模块导入类型检查所需的类型提示
from typing import TYPE_CHECKING, List, Optional, Union

# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 PIL 库,处理图像
import PIL
# 从 PIL 导入 Image 类,用于图像处理
from PIL import Image

# 从相对路径导入 utils 模块中的多个工具函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 常量,指示是否需要慢速导入
    OptionalDependencyNotAvailable,  # 异常类,表示可选依赖不可用
    _LazyModule,  # 类,延迟加载模块的实现
    get_objects_from_module,  # 函数,从模块中获取对象
    is_torch_available,  # 函数,检查 PyTorch 是否可用
    is_transformers_available,  # 函数,检查 Transformers 是否可用
)

# 初始化一个空字典,存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典,用于存储模块导入结构
_import_structure = {}

# 尝试检查可选依赖项是否可用
try:
    # 如果 Transformers 和 PyTorch 不可用,抛出异常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 模块导入虚拟对象,避免错误使用
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新虚拟对象字典,添加从虚拟模块获取的对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果没有异常,更新导入结构字典
else:
    _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"]  # 指定图像编码器模块
    _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"]  # 指定图像处理管道模块

# 如果正在进行类型检查或需要慢速导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 尝试再次检查可选依赖项是否可用
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    # 捕获可选依赖不可用的异常
    except OptionalDependencyNotAvailable:
        # 从虚拟对象模块导入所有内容
        from ...utils.dummy_torch_and_transformers_objects import *
    # 如果没有异常,从实际模块导入所需的类
    else:
        from .image_encoder import PaintByExampleImageEncoder  # 导入图像编码器
        from .pipeline_paint_by_example import PaintByExamplePipeline  # 导入图像处理管道

# 如果不是类型检查且不需要慢速导入
else:
    import sys  # 导入 sys 模块以访问 Python 运行时

    # 使用 LazyModule 创建一个延迟加载的模块,减少启动时间
    sys.modules[__name__] = _LazyModule(
        __name__,  # 当前模块名
        globals()["__file__"],  # 当前模块的文件路径
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规范
    )

    # 将虚拟对象添加到当前模块中
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 设置属性

.\diffusers\pipelines\pia\pipeline_pia.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件
# 根据许可证分发是在“按现状”基础上提供的,
# 不附带任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和
# 限制的特定语言。

import inspect  # 导入 inspect 模块,用于检查对象的类型和属性
from dataclasses import dataclass  # 从 dataclasses 导入 dataclass 装饰器,用于简化类的定义
from typing import Any, Callable, Dict, List, Optional, Union  # 导入类型注解,用于类型提示

import numpy as np  # 导入 NumPy 库,通常用于数值计算和数组操作
import PIL  # 导入 PIL 库,用于图像处理
import torch  # 导入 PyTorch 库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection  # 从 transformers 导入 CLIP 相关类

from ...image_processor import PipelineImageInput  # 从本地模块导入 PipelineImageInput 类
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  # 从本地模块导入各种混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel  # 从本地模块导入模型类
from ...models.lora import adjust_lora_scale_text_encoder  # 从本地模块导入调整 Lora 缩放的函数
from ...models.unets.unet_motion_model import MotionAdapter  # 从本地模块导入 MotionAdapter 类
from ...schedulers import (  # 从本地模块导入调度器类
    DDIMScheduler,  # 导入 DDIM 调度器
    DPMSolverMultistepScheduler,  # 导入 DPM 多步调度器
    EulerAncestralDiscreteScheduler,  # 导入 Euler 祖先离散调度器
    EulerDiscreteScheduler,  # 导入 Euler 离散调度器
    LMSDiscreteScheduler,  # 导入 LMS 离散调度器
    PNDMScheduler,  # 导入 PNDM 调度器
)
from ...utils import (  # 从本地模块导入各种工具函数和常量
    USE_PEFT_BACKEND,  # 导入标识是否使用 PEFT 后端的常量
    BaseOutput,  # 导入 BaseOutput 类,通常用于输出格式
    logging,  # 导入 logging 模块,用于记录日志
    replace_example_docstring,  # 导入替换示例文档字符串的函数
    scale_lora_layers,  # 导入缩放 Lora 层的函数
    unscale_lora_layers,  # 导入取消缩放 Lora 层的函数
)
from ...utils.torch_utils import randn_tensor  # 从本地模块导入生成随机张量的函数
from ...video_processor import VideoProcessor  # 从本地模块导入视频处理类
from ..free_init_utils import FreeInitMixin  # 从上级模块导入 FreeInitMixin 类
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 从上级模块导入扩散管道和稳定扩散混合类


logger = logging.get_logger(__name__)  # 创建一个记录器实例,用于记录当前模块的日志,禁用 pylint 对名称的警告

EXAMPLE_DOC_STRING = """  # 定义一个示例文档字符串的常量
```  # 该常量的开始部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量的结束部分
```py  # 该常量的结束部分
```  # 该常量
    # 示例代码的使用说明
        Examples:
            # 导入所需的库
            ```py
            >>> import torch  # 导入 PyTorch 库
            >>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline  # 从 diffusers 导入相关类
            >>> from diffusers.utils import export_to_gif, load_image  # 导入工具函数
    
            # 从预训练模型加载运动适配器
            >>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
            # 从预训练模型创建 PIAPipeline 对象,并指定适配器和数据类型
            >>> pipe = PIAPipeline.from_pretrained(
            ...     "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
            ... )
    
            # 设置调度器为 EulerDiscreteScheduler,并使用现有配置
            >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
            # 从指定 URL 加载图像
            >>> image = load_image(
            ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
            ... )
            # 调整图像大小为 512x512 像素
            >>> image = image.resize((512, 512))
            # 定义正向提示内容
            >>> prompt = "cat in a hat"
            # 定义反向提示内容,以减少不良效果
            >>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted"
            # 创建一个随机数生成器并设置种子
            >>> generator = torch.Generator("cpu").manual_seed(0)
            # 生成输出图像,通过管道处理输入图像和提示
            >>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
            # 获取输出结果中的第一帧
            >>> frames = output.frames[0]
            # 将帧导出为 GIF 动画文件
            >>> export_to_gif(frames, "pia-animation.gif")
"""
# 定义一个包含不同运动范围的列表,每个子列表代表不同类型的运动
RANGE_LIST = [
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],  # 0 小运动
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],  # 中等运动
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],  # 大运动
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0],  # 循环
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0],  # 循环
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0],  # 循环
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],  # 风格迁移候选小运动
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],  # 风格迁移中等运动
    [0.5, 0.2],  # 风格迁移大运动
]


# 定义一个函数,根据统计信息准备掩码系数
def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int):
    # 确保视频帧数大于 0
    assert num_frames > 0, "video_length should be greater than 0"

    # 确保视频帧数大于条件帧
    assert num_frames > cond_frame, "video_length should be greater than cond_frame"

    # 将 RANGE_LIST 赋值给 range_list
    range_list = RANGE_LIST

    # 确保运动缩放类型在范围列表中可用
    assert motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented"

    # 根据运动缩放类型获取对应的系数
    coef = range_list[motion_scale]
    # 用最后一个系数填充至 num_frames 长度
    coef = coef + ([coef[-1]] * (num_frames - len(coef)))

    # 计算每帧与条件帧的距离
    order = [abs(i - cond_frame) for i in range(num_frames)]
    # 根据距离重新排列系数
    coef = [coef[order[i]] for i in range(num_frames)]

    # 返回重新排列后的系数
    return coef


@dataclass
# 定义一个数据类,用于 PIAPipeline 的输出
class PIAPipelineOutput(BaseOutput):
    r"""
    PIAPipeline 的输出类。

    参数:
        frames (`torch.Tensor`, `np.ndarray`, 或 List[List[PIL.Image.Image]]):
            长度为 `batch_size` 的嵌套列表,包含每个 `num_frames` 的去噪 PIL 图像序列,形状为
            `(batch_size, num_frames, channels, height, width)` 的 NumPy 数组,或形状为 
            `(batch_size, num_frames, channels, height, width)` 的 Torch 张量。
    """

    # 输出帧,可以是多种数据类型
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


# 定义一个用于文本到视频生成的管道类
class PIAPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
    FreeInitMixin,
):
    r"""
    文本到视频生成的管道。

    此模型继承自 [`DiffusionPipeline`]. 请查看超类文档以了解所有管道实现的通用方法
    (下载、保存、在特定设备上运行等)。

    此管道还继承以下加载方法:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用于加载文本反转嵌入
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] 用于加载 IP 适配器
    # 函数参数说明
    Args:
        vae ([`AutoencoderKL`]):
            # 变分自编码器 (VAE) 模型,用于将图像编码和解码为潜在表示
        text_encoder ([`CLIPTextModel`]):
            # 冻结的文本编码器,使用 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
        tokenizer (`CLIPTokenizer`):
            # [`~transformers.CLIPTokenizer`] 用于对文本进行标记化
        unet ([`UNet2DConditionModel`]):
            # [`UNet2DConditionModel`] 用于创建 UNetMotionModel,以去噪编码后的视频潜在特征
        motion_adapter ([`MotionAdapter`]):
            # [`MotionAdapter`] 与 `unet` 结合使用,以去噪编码后的视频潜在特征
        scheduler ([`SchedulerMixin`]):
            # 与 `unet` 结合使用的调度器,用于去噪编码后的图像潜在特征。可以是
            # [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]
    """

    # 定义模型 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # 定义可选组件列表
    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
    # 定义回调张量输入列表
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        # 初始化方法的参数列表
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Union[UNet2DConditionModel, UNetMotionModel],
        scheduler: Union[
            # 允许的调度器类型
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        motion_adapter: Optional[MotionAdapter] = None,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 如果 unet 是 UNet2DConditionModel 的实例,则转换为 UNetMotionModel
        if isinstance(unet, UNet2DConditionModel):
            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

        # 注册模块,初始化各个组件
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            motion_adapter=motion_adapter,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # 计算 VAE 的缩放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 初始化视频处理器,设置是否调整大小和 VAE 缩放因子
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 中复制,num_images_per_prompt -> num_videos_per_prompt
    # 定义一个编码提示的函数,接受多个参数
        def encode_prompt(
            self,  # 类的实例
            prompt,  # 输入的提示文本
            device,  # 计算设备(如 CPU 或 GPU)
            num_images_per_prompt,  # 每个提示生成的图像数量
            do_classifier_free_guidance,  # 是否使用无分类器引导
            negative_prompt=None,  # 可选的负面提示文本
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入张量
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入张量
            lora_scale: Optional[float] = None,  # 可选的 LoRA 缩放因子
            clip_skip: Optional[int] = None,  # 可选的剪辑跳过参数
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image 复制的函数
        def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):  # 定义图像编码函数
            dtype = next(self.image_encoder.parameters()).dtype  # 获取图像编码器参数的数据类型
    
            if not isinstance(image, torch.Tensor):  # 如果输入的图像不是张量
                image = self.feature_extractor(image, return_tensors="pt").pixel_values  # 使用特征提取器将其转换为张量
    
            image = image.to(device=device, dtype=dtype)  # 将图像张量移动到指定设备并设置数据类型
            if output_hidden_states:  # 如果需要输出隐藏状态
                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]  # 编码图像并获取倒数第二个隐藏状态
                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)  # 根据生成图像数量重复隐藏状态
                uncond_image_enc_hidden_states = self.image_encoder(  # 编码全零图像以获取无条件隐藏状态
                    torch.zeros_like(image), output_hidden_states=True
                ).hidden_states[-2]  # 获取无条件图像的倒数第二个隐藏状态
                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(  # 重复无条件隐藏状态
                    num_images_per_prompt, dim=0
                )
                return image_enc_hidden_states, uncond_image_enc_hidden_states  # 返回编码的图像隐藏状态和无条件隐藏状态
            else:  # 如果不需要输出隐藏状态
                image_embeds = self.image_encoder(image).image_embeds  # 编码图像以获取图像嵌入
                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)  # 根据生成图像数量重复图像嵌入
                uncond_image_embeds = torch.zeros_like(image_embeds)  # 创建与图像嵌入相同形状的全零无条件嵌入
    
                return image_embeds, uncond_image_embeds  # 返回图像嵌入和无条件嵌入
    
        # 从 diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents 复制的函数
        def decode_latents(self, latents):  # 定义解码潜变量的函数
            latents = 1 / self.vae.config.scaling_factor * latents  # 按缩放因子调整潜变量
    
            batch_size, channels, num_frames, height, width = latents.shape  # 获取潜变量的形状信息
            latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)  # 重新排列并重塑潜变量
    
            image = self.vae.decode(latents).sample  # 使用 VAE 解码潜变量以获取图像样本
            video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)  # 处理图像以生成视频格式
            # 我们总是将其转换为 float32,因为这不会造成显著开销,并且与 bfloat16 兼容
            video = video.float()  # 将视频数据转换为 float32 类型
            return video  # 返回解码后的视频数据
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的函数
    # 定义准备额外参数的方法,供调度器步骤使用
    def prepare_extra_step_kwargs(self, generator, eta):
        # 为调度器步骤准备额外的关键字参数,因为不同调度器的参数签名可能不同
        # eta(η)仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
        # eta 对应于 DDIM 论文中的 η:https://arxiv.org/abs/2010.02502
        # 并且应该在 [0, 1] 之间

        # 检查调度器的步骤函数是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化额外步骤参数字典
        extra_step_kwargs = {}
        # 如果接受 eta 参数,则将其添加到额外参数字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 检查调度器的步骤函数是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受 generator 参数,则将其添加到额外参数字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回准备好的额外步骤参数字典
        return extra_step_kwargs

    # 定义检查输入参数的方法
    def check_inputs(
        self,
        prompt,  # 输入提示
        height,  # 输出图像的高度
        width,   # 输出图像的宽度
        negative_prompt=None,  # 可选的负面提示
        prompt_embeds=None,    # 可选的提示嵌入
        negative_prompt_embeds=None,  # 可选的负面提示嵌入
        ip_adapter_image=None,  # 可选的图像适配器输入图像
        ip_adapter_image_embeds=None,  # 可选的图像适配器输入嵌入
        callback_on_step_end_tensor_inputs=None,  # 可选的步骤结束回调输入
    ):
        # 检查高度和宽度是否能被8整除
        if height % 8 != 0 or width % 8 != 0:
            # 抛出值错误,提示高度和宽度不符合要求
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # 检查回调输入是否存在且不全在已注册的回调输入中
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            # 抛出值错误,提示未找到有效的回调输入
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # 检查是否同时传入了 prompt 和 prompt_embeds
        if prompt is not None and prompt_embeds is not None:
            # 抛出值错误,提示不能同时传入两者
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        # 检查是否同时未传入 prompt 和 prompt_embeds
        elif prompt is None and prompt_embeds is None:
            # 抛出值错误,提示至少需要提供一个
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        # 检查 prompt 的类型
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            # 抛出值错误,提示 prompt 的类型不正确
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 检查是否同时传入了 negative_prompt 和 negative_prompt_embeds
        if negative_prompt is not None and negative_prompt_embeds is not None:
            # 抛出值错误,提示不能同时传入两者
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # 检查 prompt_embeds 和 negative_prompt_embeds 是否形状一致
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                # 抛出值错误,提示两者形状不一致
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # 检查是否同时传入了 ip_adapter_image 和 ip_adapter_image_embeds
        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            # 抛出值错误,提示不能同时传入两者
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        # 检查 ip_adapter_image_embeds 是否存在
        if ip_adapter_image_embeds is not None:
            # 检查类型是否为列表
            if not isinstance(ip_adapter_image_embeds, list):
                # 抛出值错误,提示类型不正确
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            # 检查列表中的第一个元素的维度
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                # 抛出值错误,提示维度不符合要求
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 复制的代码
        def prepare_ip_adapter_image_embeds(
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
        ):
            # 初始化图像嵌入列表
            image_embeds = []
            # 如果启用分类器自由引导,初始化负图像嵌入列表
            if do_classifier_free_guidance:
                negative_image_embeds = []
            # 如果没有提供图像嵌入
            if ip_adapter_image_embeds is None:
                # 确保输入的图像是列表格式
                if not isinstance(ip_adapter_image, list):
                    ip_adapter_image = [ip_adapter_image]
                # 检查图像数量是否与适配器数量匹配
                if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                    raise ValueError(
                        f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                    )
                # 遍历每个图像和对应的投影层
                for single_ip_adapter_image, image_proj_layer in zip(
                    ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
                ):
                    # 确定是否输出隐藏状态
                    output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                    # 编码单个图像,获取嵌入
                    single_image_embeds, single_negative_image_embeds = self.encode_image(
                        single_ip_adapter_image, device, 1, output_hidden_state
                    )
                    # 将嵌入添加到列表中
                    image_embeds.append(single_image_embeds[None, :])
                    # 如果启用分类器自由引导,添加负嵌入
                    if do_classifier_free_guidance:
                        negative_image_embeds.append(single_negative_image_embeds[None, :])
            else:
                # 如果提供了图像嵌入,遍历嵌入列表
                for single_image_embeds in ip_adapter_image_embeds:
                    # 处理负图像嵌入(如果启用引导)
                    if do_classifier_free_guidance:
                        single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                        negative_image_embeds.append(single_negative_image_embeds)
                    # 将嵌入添加到列表中
                    image_embeds.append(single_image_embeds)
    
            # 初始化适配器图像嵌入列表
            ip_adapter_image_embeds = []
            # 遍历图像嵌入,为每个图像复制指定数量
            for i, single_image_embeds in enumerate(image_embeds):
                single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
                # 如果启用分类器自由引导,处理负嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
    
                # 将嵌入移动到指定设备
                single_image_embeds = single_image_embeds.to(device=device)
                # 添加到适配器图像嵌入列表
                ip_adapter_image_embeds.append(single_image_embeds)
    
            # 返回适配器图像嵌入
            return ip_adapter_image_embeds
    
        # 从 diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents 复制的代码
        def prepare_latents(
            self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        # 定义输出张量的形状,包括批量大小、通道数、帧数以及缩放后的高度和宽度
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        # 检查生成器是否为列表且其长度与批量大小不匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 抛出错误,提示生成器的长度与请求的有效批量大小不匹配
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 如果潜变量为 None,则生成随机张量作为潜变量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 将给定的潜变量转移到指定设备
            latents = latents.to(device)

        # 将初始噪声按调度器要求的标准差进行缩放
        latents = latents * self.scheduler.init_noise_sigma
        # 返回处理后的潜变量
        return latents

    def prepare_masked_condition(
        self,
        image,
        batch_size,
        num_channels_latents,
        num_frames,
        height,
        width,
        dtype,
        device,
        generator,
        motion_scale=0,
    ):
        # 定义输出张量的形状,包括批量大小、通道数、帧数以及缩放后的高度和宽度
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        # 解包形状信息,获取缩放后的高度和宽度
        _, _, _, scaled_height, scaled_width = shape

        # 对输入图像进行预处理
        image = self.video_processor.preprocess(image)
        # 将预处理后的图像转移到指定设备并转换为指定数据类型
        image = image.to(device, dtype)

        # 如果生成器是列表,则逐个编码每个图像并采样
        if isinstance(generator, list):
            image_latent = [
                self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
            ]
            # 将所有潜变量张量按维度 0 连接起来
            image_latent = torch.cat(image_latent, dim=0)
        else:
            # 否则直接编码图像并采样
            image_latent = self.vae.encode(image).latent_dist.sample(generator)

        # 将潜变量转移到指定设备并转换为指定数据类型
        image_latent = image_latent.to(device=device, dtype=dtype)
        # 对潜变量进行插值调整,改变大小到缩放后的高度和宽度
        image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width])
        # 复制潜变量并按配置的缩放因子进行缩放
        image_latent_padding = image_latent.clone() * self.vae.config.scaling_factor

        # 创建一个全零的掩码张量,大小与批量大小、帧数、缩放后的高度和宽度匹配
        mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype)
        # 根据统计信息准备掩码系数
        mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale)
        # 创建一个全零的张量用于存储被掩盖的图像,形状与 image_latent_padding 匹配
        masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to(
            device=device, dtype=self.unet.dtype
        )
        # 遍历每一帧,更新掩码和被掩盖的图像
        for f in range(num_frames):
            mask[:, :, f, :, :] = mask_coef[f]
            masked_image[:, :, f, :, :] = image_latent_padding.clone()

        # 根据条件决定是否复制掩码以支持分类器自由引导
        mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
        # 根据条件决定是否复制被掩盖的图像以支持分类器自由引导
        masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image

        # 返回掩码和被掩盖的图像
        return mask, masked_image

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps 复制
    # 定义获取时间步长的方法,接受推理步数、强度和设备作为参数
    def get_timesteps(self, num_inference_steps, strength, device):
        # 计算初始时间步,确保不超过总推理步数
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
        # 计算开始时间步,确保不小于零
        t_start = max(num_inference_steps - init_timestep, 0)
        # 从调度器中获取相应的时间步长
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        # 如果调度器具有设置开始索引的方法,则调用该方法
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
        # 返回计算得到的时间步和剩余推理步数
        return timesteps, num_inference_steps - t_start
    
        # 定义一个属性,用于获取引导比例
        @property
        def guidance_scale(self):
            return self._guidance_scale
    
        # 定义一个属性,用于获取剪切跳过参数
        @property
        def clip_skip(self):
            return self._clip_skip
    
        # 定义一个属性,判断是否使用无分类器引导
        @property
        def do_classifier_free_guidance(self):
            return self._guidance_scale > 1
    
        # 定义一个属性,用于获取交叉注意力参数
        @property
        def cross_attention_kwargs(self):
            return self._cross_attention_kwargs
    
        # 定义一个属性,用于获取时间步数
        @property
        def num_timesteps(self):
            return self._num_timesteps
    
        # 禁用梯度计算,装饰器用于优化性能
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        # 定义调用方法,接受多个输入参数
        def __call__(
            self,
            image: PipelineImageInput,
            prompt: Union[str, List[str]] = None,
            strength: float = 1.0,
            num_frames: Optional[int] = 16,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_videos_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,
            prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            ip_adapter_image: Optional[PipelineImageInput] = None,
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
            motion_scale: int = 0,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            clip_skip: Optional[int] = None,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],

.\diffusers\pipelines\pia\__init__.py

# 导入类型检查相关的常量
from typing import TYPE_CHECKING

# 从 utils 模块导入所需的常量和函数
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 慢导入的标志
    OptionalDependencyNotAvailable,  # 可选依赖不可用的异常
    _LazyModule,  # 懒加载模块的类
    get_objects_from_module,  # 从模块中获取对象的函数
    is_torch_available,  # 检查是否可用 PyTorch 的函数
    is_transformers_available,  # 检查是否可用 Transformers 的函数
)

# 初始化一个空字典用于存储假对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}

try:
    # 检查 Transformers 和 PyTorch 是否可用
    if not (is_transformers_available() and is_torch_available()):
        # 如果不可用,则抛出异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果捕获到可选依赖不可用的异常,则导入假对象
    from ...utils import dummy_torch_and_transformers_objects

    # 将假对象更新到 _dummy_objects 字典
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果依赖可用,则更新导入结构
    _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"]

# 检查类型检查或慢导入的标志
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 再次检查 Transformers 和 PyTorch 是否可用
        if not (is_transformers_available() and is_torch_available()):
            # 如果不可用,则抛出异常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果捕获到可选依赖不可用的异常,则导入假对象
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        # 如果依赖可用,则从 pipeline_pia 导入相关类
        from .pipeline_pia import PIAPipeline, PIAPipelineOutput

else:
    # 如果不是类型检查或慢导入,则使用懒加载模块
    import sys

    # 将当前模块替换为懒加载模块实例
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 模块文件路径
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规格
    )
    # 将假对象设置到当前模块中
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\pipeline_flax_utils.py

# 指定编码为 UTF-8
# coding=utf-8
# 版权声明,表明文件归 HuggingFace Inc. 团队所有
# 版权声明,表明文件归 NVIDIA CORPORATION 所有
#
# 根据 Apache License, Version 2.0 许可使用本文件
# 只能在遵循许可的情况下使用本文件
# 可以在以下网址获取许可
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则软件以“现状”分发
# 不提供任何明示或暗示的保证或条件
# 详见许可证中关于权限和限制的具体条款

# 导入 importlib 模块,用于动态导入模块
import importlib
# 导入 inspect 模块,用于获取对象的内部信息
import inspect
# 导入 os 模块,用于与操作系统交互
import os
# 从 typing 模块导入各种类型提示
from typing import Any, Dict, List, Optional, Union

# 导入 flax 框架
import flax
# 导入 numpy 库,用于数值计算
import numpy as np
# 导入 PIL.Image,用于图像处理
import PIL.Image
# 从 flax.core.frozen_dict 导入 FrozenDict,用于不可变字典
from flax.core.frozen_dict import FrozenDict
# 从 huggingface_hub 导入创建仓库和下载快照的函数
from huggingface_hub import create_repo, snapshot_download
# 从 huggingface_hub.utils 导入参数验证函数
from huggingface_hub.utils import validate_hf_hub_args
# 从 PIL 导入 Image 用于图像处理
from PIL import Image
# 从 tqdm.auto 导入进度条显示
from tqdm.auto import tqdm

# 从上层模块导入 ConfigMixin 类
from ..configuration_utils import ConfigMixin
# 从上层模块导入与模型相关的常量和类
from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
# 从上层模块导入与调度器相关的常量和类
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
# 从上层模块导入一些工具函数和常量
from ..utils import (
    CONFIG_NAME,  # 配置文件名常量
    BaseOutput,   # 基础输出类
    PushToHubMixin,  # 用于推送到 Hugging Face Hub 的混合类
    http_user_agent,  # HTTP 用户代理字符串
    is_transformers_available,  # 检查 transformers 库是否可用的函数
    logging,  # 日志模块
)

# 检查 transformers 库是否可用,如果可用则导入 FlaxPreTrainedModel
if is_transformers_available():
    from transformers import FlaxPreTrainedModel

# 定义加载模型的文件名常量
INDEX_FILE = "diffusion_flax_model.bin"

# 创建日志记录器
logger = logging.get_logger(__name__)

# 定义可加载的类及其方法
LOADABLE_CLASSES = {
    "diffusers": {
        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
        "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
}

# 创建一个字典以存储所有可导入的类
ALL_IMPORTABLE_CLASSES = {}
# 遍历可加载的类并将其更新到 ALL_IMPORTABLE_CLASSES 字典中
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])

# 定义导入 Flax 模型的函数
def import_flax_or_no_model(module, class_name):
    try:
        # 1. 首先确保如果存在 Flax 对象,则导入该对象
        class_obj = getattr(module, "Flax" + class_name)
    except AttributeError:
        # 2. 如果失败,则说明没有模型,不附加 "Flax"
        class_obj = getattr(module, class_name)
    except AttributeError:
        # 如果两者都不存在,抛出错误
        raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")

    # 返回找到的类对象
    return class_obj

# 定义 FlaxImagePipelineOutput 类,继承自 BaseOutput
@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
    """
    图像管道的输出类。
    # 定义函数参数的文档字符串
        Args:
            images (`List[PIL.Image.Image]` or `np.ndarray`)
                # 输入参数 images 是去噪后的 PIL 图像列表,长度为 `batch_size` 或形状为 `(batch_size, height, width, num_channels)` 的 NumPy 数组。
        """
    
        # 声明 images 变量的类型,可以是 PIL 图像列表或 NumPy 数组
        images: Union[List[PIL.Image.Image], np.ndarray]
# FlaxDiffusionPipeline类是Flax基础管道的基类
class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
    r"""
    Flax基础管道的基类。

    [`FlaxDiffusionPipeline`]存储扩散管道的所有组件(模型、调度器和处理器),并提供加载、下载和保存模型的方法。它还包括以下方法:

        - 启用/禁用去噪迭代的进度条

    类属性:

        - **config_name** ([`str`]) -- 存储扩散管道组件的类和模块名称的配置文件名。
    """

    # 定义配置文件名
    config_name = "model_index.json"

    # 注册模块的方法,接收任意关键字参数
    def register_modules(self, **kwargs):
        # 为避免循环导入,在此处导入
        from diffusers import pipelines

        # 遍历传入的模块
        for name, module in kwargs.items():
            # 如果模块为None,注册字典为None值
            if module is None:
                register_dict = {name: (None, None)}
            else:
                # 获取模块的库名
                library = module.__module__.split(".")[0]

                # 检查模块是否为管道模块
                pipeline_dir = module.__module__.split(".")[-2]
                path = module.__module__.split(".")
                is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

                # 如果库不在LOADABLE_CLASSES中,或模块是管道模块,则将库名设为管道目录名
                if library not in LOADABLE_CLASSES or is_pipeline_module:
                    library = pipeline_dir

                # 获取类名
                class_name = module.__class__.__name__

                # 注册字典为库名和类名的元组
                register_dict = {name: (library, class_name)}

            # 保存模型索引配置
            self.register_to_config(**register_dict)

            # 将模块设置为当前对象的属性
            setattr(self, name, module)

    # 保存预训练模型的方法,接收目录路径和参数等
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        push_to_hub: bool = False,
        **kwargs,
    @classmethod
    # 类方法装饰器,表示该方法是类级别的方法
    @validate_hf_hub_args
    @classmethod
    # 获取对象初始化方法的签名参数
    def _get_signature_keys(cls, obj):
        # 获取对象初始化方法的参数
        parameters = inspect.signature(obj.__init__).parameters
        # 获取必需参数
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        # 获取可选参数
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
        # 计算期望的模块名
        expected_modules = set(required_parameters.keys()) - {"self"}

        # 返回期望的模块名和可选参数
        return expected_modules, optional_parameters

    # 属性装饰器,表明该方法为属性
    # 定义一个返回管道组件的字典的方法
    def components(self) -> Dict[str, Any]:
        r"""
        `self.components` 属性对于使用相同权重和配置运行不同的管道非常有用,避免重新分配内存。

        示例:

        ```py
        >>> from diffusers import (
        ...     FlaxStableDiffusionPipeline,
        ...     FlaxStableDiffusionImg2ImgPipeline,
        ... )

        >>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
        ... )
        >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
        ```py

        返回:
            包含初始化管道所需所有模块的字典。
        """
        # 获取期望的模块和可选参数
        expected_modules, optional_parameters = self._get_signature_keys(self)
        # 创建一个字典,包含配置中的所有模块,但排除以“_”开头的和可选参数
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        # 检查组件的键是否与期望的模块一致
        if set(components.keys()) != expected_modules:
            # 如果不一致,抛出错误并显示期望和实际定义的模块
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
                f" {expected_modules} to be defined, but {components} are defined."
            )

        # 返回组件字典
        return components

    # 静态方法:将 NumPy 图像或图像批次转换为 PIL 图像
    @staticmethod
    def numpy_to_pil(images):
        """
        将 NumPy 图像或图像批次转换为 PIL 图像。
        """
        # 如果图像的维度为 3,增加一个新的维度
        if images.ndim == 3:
            images = images[None, ...]
        # 将图像值缩放到 0-255 之间并转换为无符号整数类型
        images = (images * 255).round().astype("uint8")
        # 如果图像的最后一个维度是 1,表示灰度图像
        if images.shape[-1] == 1:
            # 特殊情况处理单通道灰度图像
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            # 对于其他类型图像,直接转换
            pil_images = [Image.fromarray(image) for image in images]

        # 返回 PIL 图像列表
        return pil_images

    # TODO: 使其兼容 jax.lax
    # 定义一个进度条的方法,接受可迭代对象
    def progress_bar(self, iterable):
        # 如果没有进度条配置,则初始化为空字典
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        # 如果已有配置,检查其类型是否为字典
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        # 返回带有进度条的可迭代对象
        return tqdm(iterable, **self._progress_bar_config)

    # 设置进度条的配置参数
    def set_progress_bar_config(self, **kwargs):
        # 将配置参数赋值给进度条配置
        self._progress_bar_config = kwargs

.\diffusers\pipelines\pipeline_loading_utils.py

# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,说明版权归 HuggingFace Inc. 团队所有
# Copyright 2024 The HuggingFace Inc. team.
#
# 根据 Apache 许可证第 2.0 版进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 除非遵守许可证,否则不得使用本文件
# you may not use this file except in compliance with the License.
# 可以通过以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则软件在"按原样"基础上分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,包括明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参阅许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入必要的库
import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

# 导入 PyTorch 库
import torch
# 从 huggingface_hub 导入模型信息函数
from huggingface_hub import model_info
# 导入 Hugging Face Hub 的参数验证工具
from huggingface_hub.utils import validate_hf_hub_args
# 导入版本管理工具
from packaging import version

# 导入当前模块的版本信息
from .. import __version__
# 从 utils 模块导入一些工具函数和常量
from ..utils import (
    FLAX_WEIGHTS_NAME,
    ONNX_EXTERNAL_WEIGHTS_NAME,
    ONNX_WEIGHTS_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    get_class_from_dynamic_module,
    is_accelerate_available,
    is_peft_available,
    is_transformers_available,
    logging,
)
# 从 torch_utils 模块导入编译模块检查函数
from ..utils.torch_utils import is_compiled_module

# 检查 transformers 库是否可用
if is_transformers_available():
    # 导入 transformers 库
    import transformers
    # 从 transformers 中导入预训练模型基类
    from transformers import PreTrainedModel
    # 导入 transformers 中的权重名称常量
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
    from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
    from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

# 检查 accelerate 库是否可用
if is_accelerate_available():
    # 导入 accelerate 库
    import accelerate
    # 从 accelerate 中导入模型调度函数
    from accelerate import dispatch_model
    # 导入用于从模块中移除钩子的工具
    from accelerate.hooks import remove_hook_from_module
    # 从 accelerate 中导入计算模块大小和获取最大内存的工具
    from accelerate.utils import compute_module_sizes, get_max_memory

# 定义加载模型时使用的索引文件名
INDEX_FILE = "diffusion_pytorch_model.bin"
# 定义自定义管道文件名
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
# 定义虚拟模块的文件夹路径
DUMMY_MODULES_FOLDER = "diffusers.utils"
# 定义 transformers 虚拟模块的文件夹路径
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
# 定义连接管道的关键字列表
CONNECTED_PIPES_KEYS = ["prior"]

# 创建一个记录器实例,用于记录日志信息
logger = logging.get_logger(__name__)

# 定义可加载类的字典,映射库名到相应的类和方法
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

# 初始化一个空字典,用于存储所有可导入的类
ALL_IMPORTABLE_CLASSES = {}
# 遍历 LOADABLE_CLASSES 字典中的每个库
for library in LOADABLE_CLASSES:
    # 将指定库中的可加载类更新到所有可导入的类集合中
        ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
# 检查文件名是否与 safetensors 兼容,返回布尔值
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    检查 safetensors 兼容性:
    - 默认情况下,所有模型使用默认 pytorch 序列化保存,因此我们使用默认 pytorch 文件列表来了解所需的 safetensors 文件。
    - 仅当每个默认 pytorch 文件都有匹配的 safetensors 文件时,模型才与 safetensors 兼容。

    将默认 pytorch 序列化文件名转换为 safetensors 序列化文件名:
    - 对于来自 diffusers 库的模型,仅需将 ".bin" 扩展名替换为 ".safetensors"
    - 对于来自 transformers 库的模型,文件名从 "pytorch_model" 更改为 "model",并将 ".bin" 扩展名替换为 ".safetensors"
    """
    # 初始化一个空列表,用于存储默认 pytorch 文件名
    pt_filenames = []

    # 初始化一个空集合,用于存储 safetensors 文件名
    sf_filenames = set()

    # 如果未传递组件,则将其设置为空列表
    passed_components = passed_components or []

    # 遍历输入的文件名
    for filename in filenames:
        # 分离文件名和扩展名
        _, extension = os.path.splitext(filename)

        # 如果文件在传递的组件中,跳过处理
        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
            continue

        # 如果扩展名为 .bin,则添加到 pytorch 文件名列表中
        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        # 如果扩展名为 .safetensors,则添加到 safetensors 文件名集合中
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    # 遍历所有默认 pytorch 文件名
    for filename in pt_filenames:
        # 拆分路径和文件名
        path, filename = os.path.split(filename)
        filename, extension = os.path.splitext(filename)

        # 如果文件名以 "pytorch_model" 开头,则进行替换
        if filename.startswith("pytorch_model"):
            filename = filename.replace("pytorch_model", "model")
        else:
            filename = filename

        # 构建预期的 safetensors 文件名
        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        # 检查预期的 safetensors 文件名是否在集合中
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    # 如果所有检查通过,返回 True
    return True


# 检查文件名是否与 variant 兼容,返回文件路径列表或字符串
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    # 定义一个权重文件名的列表
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    # 如果 transformers 可用,添加更多权重文件名
    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # 从权重文件名中提取前缀
    weight_prefixes = [w.split(".")[0] for w in weight_names]
    # 从权重文件名中提取后缀
    weight_suffixs = [w.split(".")[-1] for w in weight_names]
    # 定义 transformers 索引格式的正则表达式
    transformers_index_format = r"\d{5}-of-\d{5}"
    # 如果 variant 不为 None,表示需要处理变体文件
    if variant is not None:
        # 定义一个正则表达式,匹配带有变体的权重文件名
        variant_file_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
        )
        # 定义一个正则表达式,匹配带有变体的索引文件名
        variant_index_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
        )

    # 定义一个正则表达式,匹配不带变体的权重文件名
    non_variant_file_re = re.compile(
        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
    )
    # 定义一个正则表达式,匹配不带变体的索引文件名
    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

    # 如果 variant 不为 None,获取所有变体权重和索引文件名
    if variant is not None:
        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
        # 合并变体权重和索引文件名
        variant_filenames = variant_weights | variant_indexes
    else:
        # 如果没有变体,则变体文件名集合为空
        variant_filenames = set()

    # 获取所有不带变体的权重文件名
    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
    # 获取所有不带变体的索引文件名
    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
    # 合并不带变体的权重和索引文件名
    non_variant_filenames = non_variant_weights | non_variant_indexes

    # 默认情况下使用所有变体文件名
    usable_filenames = set(variant_filenames)

    # 定义一个函数,将文件名转换为对应的变体文件名
    def convert_to_variant(filename):
        # 如果文件名中包含 'index',则替换为带变体的索引文件名
        if "index" in filename:
            variant_filename = filename.replace("index", f"index.{variant}")
        # 如果文件名符合特定格式,则转换为带变体的文件名
        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        # 否则默认按变体格式修改文件名
        else:
            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
        # 返回变体文件名
        return variant_filename

    # 遍历所有不带变体的文件名
    for f in non_variant_filenames:
        # 转换为对应的变体文件名
        variant_filename = convert_to_variant(f)
        # 如果该变体文件名不在可用文件名集合中,则添加
        if variant_filename not in usable_filenames:
            usable_filenames.add(f)

    # 返回可用文件名和变体文件名的集合
    return usable_filenames, variant_filenames
# 装饰器,用于验证 Hugging Face Hub 参数
@validate_hf_hub_args
# 定义一个函数,用于发出关于过时模型变体的警告
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    # 获取模型信息,包括预训练模型的路径和其他参数
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    # 从模型信息中提取所有文件名
    filenames = {sibling.rfilename for sibling in info.siblings}
    # 获取与指定变体兼容的模型文件名
    comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
    # 去掉文件名中的版本信息,生成新文件名列表
    comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]

    # 检查给定的模型文件名是否是兼容文件名的子集
    if set(model_filenames).issubset(set(comp_model_filenames)):
        # 发出关于通过修订加载模型变体的警告
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
            FutureWarning,
        )
    else:
        # 发出关于不正确加载模型变体的警告,并请求用户报告缺失文件的问题
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )


# 定义一个函数,用于解包模型
def _unwrap_model(model):
    """Unwraps a model."""
    # 检查模型是否为编译模块,如果是则解包
    if is_compiled_module(model):
        model = model._orig_mod

    # 检查 PEFT 是否可用
    if is_peft_available():
        from peft import PeftModel

        # 如果模型是 PeftModel 类型,则解包至基础模型
        if isinstance(model, PeftModel):
            model = model.base_model.model

    # 返回解包后的模型
    return model


# 定义一个简单的帮助函数,用于在不正确模块时抛出或发出警告
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """Simple helper method to raise or warn in case incorrect module has been passed"""
    # 如果当前模块不是管道模块
        if not is_pipeline_module:
            # 动态导入指定的库
            library = importlib.import_module(library_name)
            # 获取库中指定名称的类对象
            class_obj = getattr(library, class_name)
            # 遍历可导入类,构建文件名到类对象的字典
            class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
    
            # 初始化期望的类对象为 None
            expected_class_obj = None
            # 遍历类候选者,检查是否与目标类兼容
            for class_name, class_candidate in class_candidates.items():
                # 如果候选类不为 None 且是目标类的子类
                if class_candidate is not None and issubclass(class_obj, class_candidate):
                    # 将其设置为期望的类对象
                    expected_class_obj = class_candidate
    
            # Dynamo 将原始模型包装在一个私有类中。
            # 没有找到公共 API 获取原始类。
            # 从传入的类对象中获取子模型
            sub_model = passed_class_obj[name]
            # 解包子模型,获取原始模型
            unwrapped_sub_model = _unwrap_model(sub_model)
            # 获取解包后模型的类
            model_cls = unwrapped_sub_model.__class__
    
            # 检查解包模型类是否是期望类的子类
            if not issubclass(model_cls, expected_class_obj):
                # 如果不是,抛出值错误,指明类型不匹配
                raise ValueError(
                    f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
                )
        else:
            # 如果是管道模块,记录警告信息
            logger.warning(
                f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
                " has the correct type"
            )
# 定义一个获取类对象和候选类的辅助方法
def get_class_obj_and_candidates(
    # 传入库名、类名、可导入的类、管道、是否为管道模块、组件名及缓存目录
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """简单的辅助方法来检索模块的类对象以及潜在的父类对象"""
    # 构建组件文件夹路径
    component_folder = os.path.join(cache_dir, component_name)

    # 如果是管道模块
    if is_pipeline_module:
        # 获取指定库名的管道模块
        pipeline_module = getattr(pipelines, library_name)

        # 获取指定类对象
        class_obj = getattr(pipeline_module, class_name)
        # 创建类候选字典,键为可导入类名,值为类对象
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    # 如果组件文件存在
    elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # 从动态模块加载自定义组件
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        # 创建类候选字典
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        # 否则从库中导入
        library = importlib.import_module(library_name)

        # 获取指定类对象
        class_obj = getattr(library, class_name)
        # 创建类候选字典
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    # 返回类对象及候选类字典
    return class_obj, class_candidates


# 定义一个获取自定义管道类的辅助方法
def _get_custom_pipeline_class(
    # 传入自定义管道及其相关参数
    custom_pipeline,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    # 如果自定义管道是一个文件
    if custom_pipeline.endswith(".py"):
        path = Path(custom_pipeline)
        # 分解为文件夹和文件名
        file_name = path.name
        custom_pipeline = path.parent.absolute()
    # 如果提供了仓库 ID
    elif repo_id is not None:
        file_name = f"{custom_pipeline}.py"
        custom_pipeline = repo_id
    else:
        # 默认文件名
        file_name = CUSTOM_PIPELINE_FILE_NAME

    # 如果提供了仓库 ID 和修订版本
    if repo_id is not None and hub_revision is not None:
        # 从 Hub 加载管道代码时,确保覆盖修订版本
        revision = hub_revision

    # 返回从动态模块获取类
    return get_class_from_dynamic_module(
        custom_pipeline,
        module_file=file_name,
        class_name=class_name,
        cache_dir=cache_dir,
        revision=revision,
    )


# 定义一个获取管道类的辅助方法
def _get_pipeline_class(
    # 传入类对象及其他参数
    class_obj,
    config=None,
    load_connected_pipeline=False,
    custom_pipeline=None,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    # 如果提供了自定义管道
    if custom_pipeline is not None:
        # 调用获取自定义管道类的方法
        return _get_custom_pipeline_class(
            custom_pipeline,
            repo_id=repo_id,
            hub_revision=hub_revision,
            class_name=class_name,
            cache_dir=cache_dir,
            revision=revision,
        )

    # 如果类对象不是 DiffusionPipeline
    if class_obj.__name__ != "DiffusionPipeline":
        # 直接返回类对象
        return class_obj

    # 导入 diffusers 模块
    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
    # 获取类名,如果未提供则从配置中获取
    class_name = class_name or config["_class_name"]
    # 如果类名不存在,抛出错误
    if not class_name:
        raise ValueError(
            "在配置文件中找不到类名。请确保传入正确的 `class_name`。"
        )
    # 如果类名以 "Flax" 开头,则去掉前四个字符,否则保持原样
        class_name = class_name[4:] if class_name.startswith("Flax") else class_name
    
        # 从 diffusers_module 动态获取指定名称的类
        pipeline_cls = getattr(diffusers_module, class_name)
    
        # 如果需要加载连接的管道
        if load_connected_pipeline:
            # 从 auto_pipeline 导入获取连接管道的函数
            from .auto_pipeline import _get_connected_pipeline
    
            # 获取与管道类相关联的连接管道类
            connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
            # 如果找到了连接的管道类
            if connected_pipeline_cls is not None:
                # 记录加载的连接管道类的信息
                logger.info(
                    f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
                )
            else:
                # 记录没有找到连接管道类的信息
                logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
    
            # 使用找到的连接管道类或保持原管道类
            pipeline_cls = connected_pipeline_cls or pipeline_cls
    
        # 返回最终的管道类
        return pipeline_cls
# 加载一个空模型的函数,接受多种参数
def _load_empty_model(
    library_name: str,  # 库的名称
    class_name: str,  # 类的名称
    importable_classes: List[Any],  # 可导入的类列表
    pipelines: Any,  # 管道相关信息
    is_pipeline_module: bool,  # 是否为管道模块的布尔值
    name: str,  # 模型名称
    torch_dtype: Union[str, torch.dtype],  # Torch 数据类型
    cached_folder: Union[str, os.PathLike],  # 缓存文件夹路径
    **kwargs,  # 其他额外参数
):
    # 检索类对象
    class_obj, _ = get_class_obj_and_candidates(
        library_name,  # 库的名称
        class_name,  # 类的名称
        importable_classes,  # 可导入的类列表
        pipelines,  # 管道相关信息
        is_pipeline_module,  # 是否为管道模块的布尔值
        component_name=name,  # 组件名称
        cache_dir=cached_folder,  # 缓存目录
    )

    # 检查是否可用 transformers 库
    if is_transformers_available():
        # 解析 transformers 的版本号
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        # 如果不可用,版本设置为 "N/A"
        transformers_version = "N/A"

    # 确定库的类型
    is_transformers_model = (
        is_transformers_available()  # 检查 transformers 库是否可用
        and issubclass(class_obj, PreTrainedModel)  # 检查类是否为 PreTrainedModel 的子类
        and transformers_version >= version.parse("4.20.0")  # 检查版本号是否符合要求
    )
    # 导入 diffusers 模块
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    # 检查类是否为 diffusers 模型的子类
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    model = None  # 初始化模型为 None
    config_path = cached_folder  # 设置配置路径为缓存文件夹
    # 设置用户代理信息
    user_agent = {
        "diffusers": __version__,  # 当前 diffusers 的版本
        "file_type": "model",  # 文件类型为模型
        "framework": "pytorch",  # 框架为 PyTorch
    }

    # 如果是 diffusers 模型
    if is_diffusers_model:
        # 加载配置,然后在元信息上加载模型
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(config_path, name),  # 配置文件路径
            cache_dir=cached_folder,  # 缓存目录
            return_unused_kwargs=True,  # 返回未使用的关键字参数
            return_commit_hash=True,  # 返回提交哈希值
            force_download=kwargs.pop("force_download", False),  # 强制下载标志
            proxies=kwargs.pop("proxies", None),  # 代理设置
            local_files_only=kwargs.pop("local_files_only", False),  # 仅使用本地文件标志
            token=kwargs.pop("token", None),  # 认证令牌
            revision=kwargs.pop("revision", None),  # 版本修订信息
            subfolder=kwargs.pop("subfolder", None),  # 子文件夹
            user_agent=user_agent,  # 用户代理信息
        )
        # 初始化空权重
        with accelerate.init_empty_weights():
            model = class_obj.from_config(config, **unused_kwargs)  # 从配置中创建模型
    # 如果是 transformers 模型
    elif is_transformers_model:
        config_class = getattr(class_obj, "config_class", None)  # 获取配置类
        # 如果配置类为空,抛出错误
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        # 从预训练配置加载模型配置
        config = config_class.from_pretrained(
            cached_folder,  # 缓存文件夹
            subfolder=name,  # 子文件夹
            force_download=kwargs.pop("force_download", False),  # 强制下载标志
            proxies=kwargs.pop("proxies", None),  # 代理设置
            local_files_only=kwargs.pop("local_files_only", False),  # 仅使用本地文件标志
            token=kwargs.pop("token", None),  # 认证令牌
            revision=kwargs.pop("revision", None),  # 版本修订信息
            user_agent=user_agent,  # 用户代理信息
        )
        # 初始化空权重
        with accelerate.init_empty_weights():
            model = class_obj(config)  # 使用配置创建模型

    # 如果模型已创建
    if model is not None:
        model = model.to(dtype=torch_dtype)  # 将模型转换为指定的数据类型
    return model  # 返回加载的模型
    # 定义一个包含模块大小的字典,键为模块名称,值为模块大小(以浮点数表示)
        module_sizes: Dict[str, float], 
        # 定义一个包含设备内存的字典,键为设备名称,值为设备内存(以浮点数表示)
        device_memory: Dict[str, float], 
        # 定义设备映射策略的字符串,默认值为 "balanced"
        device_mapping_strategy: str = "balanced"
):
    # 获取设备内存字典的所有设备 ID,并转换为列表
    device_ids = list(device_memory.keys())
    # 创建一个设备循环列表,包含设备 ID 的正序和反序
    device_cycle = device_ids + device_ids[::-1]
    # 复制设备内存字典,以避免修改原始字典
    device_memory = device_memory.copy()

    # 初始化设备 ID 和组件的映射字典
    device_id_component_mapping = {}
    # 当前设备索引,初始化为 0
    current_device_index = 0
    # 遍历模块大小字典
    for component in module_sizes:
        # 根据当前索引获取对应的设备 ID
        device_id = device_cycle[current_device_index % len(device_cycle)]
        # 获取当前组件所需的内存大小
        component_memory = module_sizes[component]
        # 获取当前设备的可用内存
        curr_device_memory = device_memory[device_id]

        # 如果 GPU 的内存不足以容纳当前组件,则将其转移到 CPU
        if component_memory > curr_device_memory:
            device_id_component_mapping["cpu"] = [component]
        else:
            # 如果设备 ID 不在映射中,则初始化该设备的组件列表
            if device_id not in device_id_component_mapping:
                device_id_component_mapping[device_id] = [component]
            else:
                # 如果设备 ID 已存在,则将组件添加到该设备的组件列表
                device_id_component_mapping[device_id].append(component)

            # 更新设备的剩余内存
            device_memory[device_id] -= component_memory
            # 移动到下一个设备索引
            current_device_index += 1

    # 返回设备 ID 到组件的映射字典
    return device_id_component_mapping


def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    # 为了避免循环导入问题,导入管道模块
    from diffusers import pipelines

    # 从关键字参数中获取 torch 数据类型,默认值为 torch.float32
    torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # 在一个元设备上加载管道中的每个模块,以便推导出设备映射
    init_empty_modules = {}
    # 遍历初始化字典中的每个名称及其对应的库名称和类名称
        for name, (library_name, class_name) in init_dict.items():
            # 如果类名以 "Flax" 开头,抛出不支持的错误
            if class_name.startswith("Flax"):
                raise ValueError("Flax pipelines are not supported with `device_map`.")
    
            # 定义所有可导入的类
            is_pipeline_module = hasattr(pipelines, library_name)  # 检查 pipelines 是否有对应的库名称
            importable_classes = ALL_IMPORTABLE_CLASSES  # 获取所有可导入类的集合
            loaded_sub_model = None  # 初始化已加载的子模型为 None
    
            # 使用传入的子模型或从库名称加载类名称
            if name in passed_class_obj:  # 如果传入的类对象中有该名称
                # 如果模型在管道模块中,则从管道加载模型
                # 检查传入的类对象是否具有正确的父类
                maybe_raise_or_warn(
                    library_name,
                    library,
                    class_name,
                    importable_classes,
                    passed_class_obj,
                    name,
                    is_pipeline_module,
                )
                with accelerate.init_empty_weights():  # 在初始化空权重的上下文中
                    loaded_sub_model = passed_class_obj[name]  # 从传入的对象中加载模型
    
            else:
                # 加载空模型
                loaded_sub_model = _load_empty_model(
                    library_name=library_name,  # 库名称
                    class_name=class_name,  # 类名称
                    importable_classes=importable_classes,  # 可导入类集合
                    pipelines=pipelines,  # 管道
                    is_pipeline_module=is_pipeline_module,  # 是否为管道模块
                    pipeline_class=pipeline_class,  # 管道类
                    name=name,  # 名称
                    torch_dtype=torch_dtype,  # torch 数据类型
                    cached_folder=kwargs.get("cached_folder", None),  # 缓存文件夹
                    force_download=kwargs.get("force_download", None),  # 强制下载标志
                    proxies=kwargs.get("proxies", None),  # 代理设置
                    local_files_only=kwargs.get("local_files_only", None),  # 仅限本地文件标志
                    token=kwargs.get("token", None),  # 访问令牌
                    revision=kwargs.get("revision", None),  # 版本修订号
                )
    
            # 如果已加载子模型不为 None,将其添加到初始化空模块中
            if loaded_sub_model is not None:
                init_empty_modules[name] = loaded_sub_model
    
        # 确定设备映射
        # 获取一个按大小排序的字典,用于映射模型级组件
        # 到其大小。
        module_sizes = {
            module_name: compute_module_sizes(module, dtype=torch_dtype)[""]  # 计算每个模块的大小
            for module_name, module in init_empty_modules.items()  # 遍历初始化空模块中的每个模块
            if isinstance(module, torch.nn.Module)  # 仅考虑 PyTorch 模块
        }
        # 对模块大小字典进行排序,按值降序排列
        module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
    
        # 获取每个设备(仅限 GPU)的最大可用内存
        max_memory = get_max_memory(max_memory)  # 获取最大内存信息
        # 对最大内存字典进行排序,按值降序排列
        max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
        # 从最大内存字典中移除 CPU 条目
        max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
    
        # 获取一个字典,用于将模型级组件映射到基于最大内存和模型大小的可用设备
        final_device_map = None  # 初始化最终设备映射为 None
    # 检查最大内存的长度是否大于 0
        if len(max_memory) > 0:
            # 分配组件到设备,并返回设备与组件的映射
            device_id_component_mapping = _assign_components_to_devices(
                module_sizes, max_memory, device_mapping_strategy=device_map
            )
    
            # 初始化最终设备映射字典
            final_device_map = {}
            # 遍历设备 ID 和其对应的组件
            for device_id, components in device_id_component_mapping.items():
                # 遍历每个组件,将其映射到设备 ID
                for component in components:
                    final_device_map[component] = device_id
    
        # 返回最终的设备映射字典
        return final_device_map
# 定义一个加载子模型的辅助方法,接受多个参数来配置模型加载
def load_sub_model(
    # 模型所在库的名称
    library_name: str,
    # 模型类的名称
    class_name: str,
    # 可导入类的列表
    importable_classes: List[Any],
    # 管道相关参数
    pipelines: Any,
    # 是否为管道模块的标志
    is_pipeline_module: bool,
    # 管道类
    pipeline_class: Any,
    # 指定的 torch 数据类型
    torch_dtype: torch.dtype,
    # 提供者参数
    provider: Any,
    # 会话选项
    sess_options: Any,
    # 设备映射,可能是字典或字符串
    device_map: Optional[Union[Dict[str, torch.device], str]],
    # 最大内存使用配置
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    # 离线文件夹路径
    offload_folder: Optional[Union[str, os.PathLike]],
    # 是否离线保存状态字典
    offload_state_dict: bool,
    # 模型变体的字典
    model_variants: Dict[str, str],
    # 模型名称
    name: str,
    # 是否从 Flax 框架加载
    from_flax: bool,
    # 模型变体名称
    variant: str,
    # 是否使用低 CPU 内存使用模式
    low_cpu_mem_usage: bool,
    # 缓存文件夹路径
    cached_folder: Union[str, os.PathLike],
):
    """从指定库和类名加载模块 `name` 的辅助方法"""

    # 获取类对象及候选类列表
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    load_method_name = None
    # 获取加载方法名称
    for class_name, class_candidate in class_candidates.items():
        # 如果候选类不为 None,且 class_obj 是其子类
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            # 从可导入类中获取加载方法名称
            load_method_name = importable_classes[class_name][1]

    # 如果加载方法名称为 None,说明是一个虚拟模块 -> 抛出错误
    if load_method_name is None:
        none_module = class_obj.__module__
        # 检查模块路径是否属于虚拟模块
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        # 如果是虚拟模块,调用 class_obj 以获得友好的错误信息
        if is_dummy_path and "dummy" in none_module:
            class_obj()

        # 抛出值错误,说明没有定义任何加载方法
        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    # 获取加载方法
    load_method = getattr(class_obj, load_method_name)

    # 为加载方法添加关键字参数
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    # 如果类是 PyTorch 模块,则添加 torch_dtype 参数
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    # 如果类是 OnnxRuntimeModel,则添加 provider 和 sess_options 参数
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    # 检查类是否为 diffusers 模型
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    # 检查 transformers 是否可用
    if is_transformers_available():
        # 获取 transformers 的版本
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        # 如果不可用,则版本标记为 "N/A"
        transformers_version = "N/A"

    # 检查类是否为 transformers 模型,并且版本满足要求
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # 加载 transformers 模型时,如果 device_map 为 None,则权重将被初始化,而不是 diffusers.
    # 为了加快默认加载速度,设置 `low_cpu_mem_usage=low_cpu_mem_usage` 标志,默认为 `True`。
    # 确保权重不会被初始化,从而显著加快加载速度。
    if is_diffusers_model or is_transformers_model:
        # 将设备映射添加到加载参数中
        loading_kwargs["device_map"] = device_map
        # 设置最大内存使用限制
        loading_kwargs["max_memory"] = max_memory
        # 设置用于卸载的文件夹
        loading_kwargs["offload_folder"] = offload_folder
        # 设置卸载状态字典的标志
        loading_kwargs["offload_state_dict"] = offload_state_dict
        # 从模型变体中获取对应的变体
        loading_kwargs["variant"] = model_variants.pop(name, None)

        if from_flax:
            # 如果模型来自 Flax,设置标志
            loading_kwargs["from_flax"] = True

        # 以下内容可以在 `transformers` 的最低版本高于 4.27 时删除
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            # 如果版本不符合要求,抛出导入错误
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            # 如果变体为空,从加载参数中移除变体
            loading_kwargs.pop("variant")

        # 如果来自 Flax 并且模型是变换器模型,无法使用 `low_cpu_mem_usage` 加载
        if not (from_flax and is_transformers_model):
            # 设置低 CPU 内存使用标志
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            # 否则将标志设置为 False
            loading_kwargs["low_cpu_mem_usage"] = False

    # 检查模块是否在子目录中
    if os.path.isdir(os.path.join(cached_folder, name)):
        # 从子目录加载模型
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        # 否则从根目录加载模型
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
        # 移除模型中的钩子
        remove_hook_from_module(loaded_sub_model, recurse=True)
        # 检查是否需要将模型卸载到 CPU
        needs_offloading_to_cpu = device_map[""] == "cpu"

        if needs_offloading_to_cpu:
            # 如果需要卸载,将模型分发到 CPU
            dispatch_model(
                loaded_sub_model,
                state_dict=loaded_sub_model.state_dict(),
                device_map=device_map,
                force_hooks=True,
                main_device=0,
            )
        else:
            # 否则正常分发模型
            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)

    # 返回加载的子模型
    return loaded_sub_model
# 根据传入模块获取类库名称和类名的元组
def _fetch_class_library_tuple(module):
    # 在这里导入模块以避免循环导入问题
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    # 获取 diffusers_module 中的 pipelines 属性
    pipelines = getattr(diffusers_module, "pipelines")

    # 从原始模块注册配置,而不是从动态编译的模块
    not_compiled_module = _unwrap_model(module)
    # 获取模块的库名称
    library = not_compiled_module.__module__.split(".")[0]

    # 检查该模块是否为管道模块
    module_path_items = not_compiled_module.__module__.split(".")
    # 获取模块路径倒数第二项,作为管道目录
    pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None

    # 获取模块路径,并检查是否为管道模块
    path = not_compiled_module.__module__.split(".")
    is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

    # 如果库名称不在 LOADABLE_CLASSES 中,则认为它是自定义模块
    # 或者如果是管道模块,则库名称设为模块名称
    if is_pipeline_module:
        library = pipeline_dir
    elif library not in LOADABLE_CLASSES:
        library = not_compiled_module.__module__

    # 获取类名
    class_name = not_compiled_module.__class__.__name__

    # 返回库名称和类名的元组
    return (library, class_name)
posted @ 2024-10-22 12:34  绝不原创的飞龙  阅读(169)  评论(0)    收藏  举报