diffusers-源码解析-四十五-

diffusers 源码解析(四十五)

.\diffusers\pipelines\stable_diffusion\pipeline_flax_stable_diffusion_img2img.py

# 版权声明,表明此文件的版权归 HuggingFace 团队所有
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行授权,用户必须遵守该许可证使用此文件
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 用户可以在以下链接获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用的法律要求或书面同意,否则此文件按“原样”提供,没有任何明示或暗示的担保
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言管辖权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings  # 导入警告模块,用于发出警告消息
from functools import partial  # 从 functools 导入 partial 函数,用于部分应用
from typing import Dict, List, Optional, Union  # 导入类型提示相关的模块

import jax  # 导入 JAX 库,用于加速数值计算
import jax.numpy as jnp  # 导入 JAX 的 NumPy 模块,作为 jnp 使用
import numpy as np  # 导入 NumPy 库,作为 np 使用
from flax.core.frozen_dict import FrozenDict  # 从 flax 导入 FrozenDict,用于创建不可变字典
from flax.jax_utils import unreplicate  # 从 flax 导入 unreplicate 函数,用于去除 JAX 复制
from flax.training.common_utils import shard  # 从 flax 导入 shard 函数,用于数据切分
from PIL import Image  # 从 PIL 导入 Image 模块,用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # 导入 transformers 中的处理器和模型

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel  # 导入模型
from ...schedulers import (  # 从调度器模块导入各种调度器
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring  # 导入工具函数和日志模块
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # 从 pipeline_flax_utils 导入 FlaxDiffusionPipeline
from .pipeline_output import FlaxStableDiffusionPipelineOutput  # 从 pipeline_output 导入 FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker  # 从 safety_checker_flax 导入安全检查器

logger = logging.get_logger(__name__)  # 创建一个日志记录器,使用当前模块的名称

# 设置为 True 时使用 Python 的 for 循环,而非 jax.fori_loop,以便于调试
DEBUG = False

# 示例文档字符串的模板
EXAMPLE_DOC_STRING = """
    # 示例代码块,用于演示如何使用库和函数
    Examples:
        ```py
        # 导入 JAX 库
        >>> import jax
        # 导入 NumPy 库
        >>> import numpy as np
        # 导入 JAX 的 NumPy 实现
        >>> import jax.numpy as jnp
        # 从 flax.jax_utils 导入复制函数
        >>> from flax.jax_utils import replicate
        # 从 flax.training.common_utils 导入分片函数
        >>> from flax.training.common_utils import shard
        # 导入 requests 库用于发送 HTTP 请求
        >>> import requests
        # 从 io 模块导入 BytesIO 用于处理字节流
        >>> from io import BytesIO
        # 从 PIL 库导入 Image 类用于图像处理
        >>> from PIL import Image
        # 从 diffusers 导入 FlaxStableDiffusionImg2ImgPipeline 类
        >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline


        # 定义一个创建随机数种子的函数
        >>> def create_key(seed=0):
        ...     # 返回一个基于给定种子的 JAX 随机数生成器密钥
        ...     return jax.random.PRNGKey(seed)


        # 使用种子 0 创建随机数生成器密钥
        >>> rng = create_key(0)

        # 定义要下载的图像 URL
        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        # 发送 GET 请求以获取图像
        >>> response = requests.get(url)
        # 从响应内容中读取图像,并转换为 RGB 模式
        >>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
        # 调整图像大小为 768x512 像素
        >>> init_img = init_img.resize((768, 512))

        # 定义提示词
        >>> prompts = "A fantasy landscape, trending on artstation"

        # 从预训练模型中加载图像到图像生成管道
        >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4",  # 模型名称
        ...     revision="flax",                  # 版本标识
        ...     dtype=jnp.bfloat16,               # 数据类型
        ... )

        # 获取设备数量以生成样本
        >>> num_samples = jax.device_count()
        # 根据设备数量拆分随机数生成器的密钥
        >>> rng = jax.random.split(rng, jax.device_count())
        # 准备输入的提示词和图像,复制 num_samples 次
        >>> prompt_ids, processed_image = pipeline.prepare_inputs(
        ...     prompt=[prompts] * num_samples,    # 创建提示词列表
        ...     image=[init_img] * num_samples     # 创建图像列表
        ... )
        # 复制参数以便在多个设备上使用
        >>> p_params = replicate(params)
        # 将提示词 ID 分片以适应设备
        >>> prompt_ids = shard(prompt_ids)
        # 将处理后的图像分片以适应设备
        >>> processed_image = shard(processed_image)

        # 调用管道生成图像
        >>> output = pipeline(
        ...     prompt_ids=prompt_ids,              # 提示词 ID
        ...     image=processed_image,              # 处理后的图像
        ...     params=p_params,                    # 复制的参数
        ...     prng_seed=rng,                      # 随机数种子
        ...     strength=0.75,                     # 强度参数
        ...     num_inference_steps=50,            # 推理步骤数
        ...     jit=True,                          # 启用 JIT 编译
        ...     height=512,                        # 输出图像高度
        ...     width=768,                         # 输出图像宽度
        ... ).images  # 获取生成的图像

        # 将输出的图像转换为 PIL 格式以便展示
        >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
        ```py 
# 定义一个基于 Flax 的文本引导图像生成管道类,用于图像到图像的生成
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
    r"""
    基于 Flax 的管道,用于使用 Stable Diffusion 进行文本引导的图像到图像生成。

    该模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道的通用方法的文档(下载、保存、在特定设备上运行等),请查看超类文档。

    参数:
        vae ([`FlaxAutoencoderKL`]):
            用于对图像进行编码和解码的变分自编码器(VAE)模型,将图像转换为潜在表示。
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
        tokenizer ([`~transformers.CLIPTokenizer`]):
            用于对文本进行标记化的 `CLIPTokenizer`。
        unet ([`FlaxUNet2DConditionModel`]):
            用于对编码图像潜在空间进行去噪的 `FlaxUNet2DConditionModel`。
        scheduler ([`SchedulerMixin`]):
            与 `unet` 结合使用的调度器,用于去噪编码的图像潜在空间。可以是
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
            [`FlaxDPMSolverMultistepScheduler`] 之一。
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            分类模块,用于评估生成的图像是否可能被认为是冒犯性或有害的。
            有关模型潜在危害的更多详细信息,请参阅 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5)。
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            用于从生成图像中提取特征的 `CLIPImageProcessor`;作为输入用于 `safety_checker`。
    """

    # 初始化方法,设置管道的各个组件
    def __init__(
        self,
        # 变分自编码器模型
        vae: FlaxAutoencoderKL,
        # 文本编码器模型
        text_encoder: FlaxCLIPTextModel,
        # 文本标记器
        tokenizer: CLIPTokenizer,
        # 去噪模型
        unet: FlaxUNet2DConditionModel,
        # 调度器,用于去噪处理
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        # 安全检查模块
        safety_checker: FlaxStableDiffusionSafetyChecker,
        # 特征提取器
        feature_extractor: CLIPImageProcessor,
        # 数据类型,默认为 float32
        dtype: jnp.dtype = jnp.float32,
    ):
        # 调用父类的构造函数
        super().__init__()
        # 设置数据类型属性
        self.dtype = dtype

        # 检查安全检查器是否为 None
        if safety_checker is None:
            # 记录警告,提醒用户禁用了安全检查器
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 注册模块,将各个组件进行初始化
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # 计算 VAE 的缩放因子,基于其配置的输出通道数
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # 准备输入,接受文本提示和图像
    def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
        # 检查 prompt 类型是否为字符串或列表
        if not isinstance(prompt, (str, list)):
            # 如果不符合类型,抛出错误
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 检查 image 类型是否为 PIL 图像或列表
        if not isinstance(image, (Image.Image, list)):
            # 如果不符合类型,抛出错误
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        # 如果 image 是单个图像,则转换为列表
        if isinstance(image, Image.Image):
            image = [image]

        # 预处理图像,并将它们拼接为一个数组
        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        # 将文本提示编码为模型输入格式
        text_input = self.tokenizer(
            prompt,
            padding="max_length",  # 填充到最大长度
            max_length=self.tokenizer.model_max_length,  # 最大长度设置
            truncation=True,  # 超出最大长度时截断
            return_tensors="np",  # 返回 NumPy 格式的张量
        )
        # 返回文本输入 ID 和处理后的图像
        return text_input.input_ids, processed_images

    # 获取是否包含不适宜内容的概念
    def _get_has_nsfw_concepts(self, features, params):
        # 使用安全检查器检查特征是否包含不适宜内容
        has_nsfw_concepts = self.safety_checker(features, params)
        # 返回检查结果
        return has_nsfw_concepts
    # 定义一个安全检查器的运行方法,处理输入的图像
    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # 当 jit 为 True 时,安全模型参数应已被复制
        pil_images = [Image.fromarray(image) for image in images]  # 将 NumPy 数组转换为 PIL 图像
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values  # 提取图像特征并返回像素值
    
        if jit:  # 如果启用 JIT 编译
            features = shard(features)  # 将特征分片以优化性能
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)  # 检查是否存在 NSFW 概念
            has_nsfw_concepts = unshard(has_nsfw_concepts)  # 将结果反分片
            safety_model_params = unreplicate(safety_model_params)  # 反复制安全模型参数
        else:  # 如果没有启用 JIT 编译
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)  # 获取 NSFW 概念的存在性
    
        images_was_copied = False  # 标记图像是否已被复制
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):  # 遍历 NSFW 概念的列表
            if has_nsfw_concept:  # 如果检测到 NSFW 概念
                if not images_was_copied:  # 如果尚未复制图像
                    images_was_copied = True  # 标记为已复制
                    images = images.copy()  # 复制图像数组
    
                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # 用黑色图像替换原图像
    
            if any(has_nsfw_concepts):  # 如果任一图像有 NSFW 概念
                warnings.warn(  # 发出警告
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )
    
        return images, has_nsfw_concepts  # 返回处理后的图像和 NSFW 概念的存在性
    
    # 定义获取开始时间步的方法
    def get_timestep_start(self, num_inference_steps, strength):
        # 使用初始时间步计算原始时间步
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 计算初始时间步,确保不超出总步骤
    
        t_start = max(num_inference_steps - init_timestep, 0)  # 计算开始时间步,确保不为负
    
        return t_start  # 返回开始时间步
    
    # 定义生成方法
    def _generate(
        self,
        prompt_ids: jnp.ndarray,  # 输入提示的 ID
        image: jnp.ndarray,  # 输入图像
        params: Union[Dict, FrozenDict],  # 模型参数
        prng_seed: jax.Array,  # 随机种子
        start_timestep: int,  # 开始时间步
        num_inference_steps: int,  # 推理步骤数量
        height: int,  # 生成图像的高度
        width: int,  # 生成图像的宽度
        guidance_scale: float,  # 引导比例
        noise: Optional[jnp.ndarray] = None,  # 噪声选项
        neg_prompt_ids: Optional[jnp.ndarray] = None,  # 负提示 ID 选项
        @replace_example_docstring(EXAMPLE_DOC_STRING)  # 用示例文档字符串替换
        def __call__(  # 定义可调用方法
            self,
            prompt_ids: jnp.ndarray,  # 输入提示的 ID
            image: jnp.ndarray,  # 输入图像
            params: Union[Dict, FrozenDict],  # 模型参数
            prng_seed: jax.Array,  # 随机种子
            strength: float = 0.8,  # 强度参数,默认为 0.8
            num_inference_steps: int = 50,  # 推理步骤数量,默认为 50
            height: Optional[int] = None,  # 生成图像的高度,默认为 None
            width: Optional[int] = None,  # 生成图像的宽度,默认为 None
            guidance_scale: Union[float, jnp.ndarray] = 7.5,  # 引导比例,默认为 7.5
            noise: jnp.ndarray = None,  # 噪声,默认为 None
            neg_prompt_ids: jnp.ndarray = None,  # 负提示 ID,默认为 None
            return_dict: bool = True,  # 是否返回字典,默认为 True
            jit: bool = False,  # 是否启用 JIT 编译,默认为 False
# 静态参数为 pipe, start_timestep, num_inference_steps, height, width。任何更改都会触发重新编译。
# 非静态参数为 (sharded) 输入张量,按其第一维映射 (因此为 `0`)。
@partial(
    jax.pmap,  # 使用 JAX 的 pmap 函数进行并行映射
    in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),  # 指定输入参数的维度
    static_broadcasted_argnums=(0, 5, 6, 7, 8),  # 静态广播的参数索引
)
def _p_generate(
    pipe,  # 生成管道对象
    prompt_ids,  # 输入的提示 ID
    image,  # 输入的图像数据
    params,  # 其他参数
    prng_seed,  # 随机数种子
    start_timestep,  # 开始的时间步
    num_inference_steps,  # 推理的步骤数
    height,  # 图像的高度
    width,  # 图像的宽度
    guidance_scale,  # 引导尺度
    noise,  # 噪声数据
    neg_prompt_ids,  # 负提示 ID
):
    # 调用管道的生成方法,传递所有必要的参数
    return pipe._generate(
        prompt_ids,  # 提示 ID
        image,  # 图像数据
        params,  # 其他参数
        prng_seed,  # 随机数种子
        start_timestep,  # 开始时间步
        num_inference_steps,  # 推理步骤数
        height,  # 图像高度
        width,  # 图像宽度
        guidance_scale,  # 引导尺度
        noise,  # 噪声数据
        neg_prompt_ids,  # 负提示 ID
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))  # 使用 JAX 的 pmap 函数进行并行映射
def _p_get_has_nsfw_concepts(pipe, features, params):
    # 调用管道的方法以获取是否包含 NSFW 概念的特征
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    # 将输入张量 x 重组为适合的形状,合并设备和批次维度
    num_devices, batch_size = x.shape[:2]  # 获取设备数量和批次大小
    rest = x.shape[2:]  # 获取剩余维度
    # 重新调整形状为 (num_devices * batch_size, 剩余维度)
    return x.reshape(num_devices * batch_size, *rest)


def preprocess(image, dtype):
    w, h = image.size  # 获取图像的宽度和高度
    # 调整宽度和高度为 32 的整数倍
    w, h = (x - x % 32 for x in (w, h))  
    # 重新调整图像大小,使用 Lanczos 插值法
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    # 将图像转换为 NumPy 数组并归一化到 [0, 1] 范围
    image = jnp.array(image).astype(dtype) / 255.0
    # 调整数组维度为 (1, 通道数, 高度, 宽度)
    image = image[None].transpose(0, 3, 1, 2)
    # 将图像值范围转换为 [-1, 1]
    return 2.0 * image - 1.0

.\diffusers\pipelines\stable_diffusion\pipeline_flax_stable_diffusion_inpaint.py

# 版权声明,指明该文件的版权信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证第 2.0 版许可使用本文件
# Licensed under the Apache License, Version 2.0 (the "License");
# 除非遵循许可证,否则不得使用本文件
# you may not use this file except in compliance with the License.
# 可以通过以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议规定,否则根据许可证分发的软件
# Unless required by applicable law or agreed to in writing, software
# 是按“原样”基础分发的,不提供任何形式的担保或条件
# distributed under the License is distributed on an "AS IS" BASIS,
# 不论是明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 参见许可证以了解适用权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings  # 导入 warnings 模块以处理警告
from functools import partial  # 从 functools 导入 partial,用于部分应用函数
from typing import Dict, List, Optional, Union  # 导入类型注解工具

import jax  # 导入 jax 库,用于高性能数值计算
import jax.numpy as jnp  # 导入 jax 的 numpy 作为 jnp
import numpy as np  # 导入 numpy 库以进行数组操作
from flax.core.frozen_dict import FrozenDict  # 从 flax 导入 FrozenDict 用于不可变字典
from flax.jax_utils import unreplicate  # 从 flax 导入 unreplicate,用于去除复制
from flax.training.common_utils import shard  # 从 flax 导入 shard,用于数据分片
from packaging import version  # 导入 version 用于版本比较
from PIL import Image  # 从 PIL 导入 Image 用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # 导入 transformers 库的相关组件

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel  # 导入自定义模型
from ...schedulers import (  # 从自定义调度器导入各类调度器
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring  # 导入工具函数
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # 导入 FlaxDiffusionPipeline 类
from .pipeline_output import FlaxStableDiffusionPipelineOutput  # 导入输出类
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker  # 导入安全检查器类

logger = logging.get_logger(__name__)  # 创建日志记录器,使用当前模块名称

# 设置为 True 时使用 Python 的 for 循环而不是 jax.fori_loop,以便于调试
DEBUG = False

EXAMPLE_DOC_STRING = """  # 定义示例文档字符串,通常用于文档生成
# 示例代码块,用于展示如何使用 JAX 和 Flax 进行图像处理
    Examples:
        ```py
        # 导入必要的库
        >>> import jax
        >>> import numpy as np
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> import PIL
        >>> import requests
        >>> from io import BytesIO
        >>> from diffusers import FlaxStableDiffusionInpaintPipeline

        # 定义一个函数,用于下载图像并转换为 RGB 格式
        >>> def download_image(url):
        ...     # 发送 GET 请求以获取图像内容
        ...     response = requests.get(url)
        ...     # 打开下载的内容并转换为 RGB 图像
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")

        # 定义图像和掩码的 URL
        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        # 下载并调整初始图像和掩码图像的大小
        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))

        # 从预训练模型加载管道和参数
        >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
        ...     "xvjiarui/stable-diffusion-2-inpainting"
        ... )

        # 定义处理图像时使用的提示
        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        # 初始化随机种子
        >>> prng_seed = jax.random.PRNGKey(0)
        # 定义推理步骤的数量
        >>> num_inference_steps = 50

        # 获取设备数量以便并行处理
        >>> num_samples = jax.device_count()
        # 将提示、初始图像和掩码图像扩展为设备数量的列表
        >>> prompt = num_samples * [prompt]
        >>> init_image = num_samples * [init_image]
        >>> mask_image = num_samples * [mask_image]
        # 准备输入,得到提示 ID 和处理后的图像
        >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
        ...     prompt, init_image, mask_image
        ... )
        # 分割输入和随机数生成器

        # 复制参数以适应每个设备
        >>> params = replicate(params)
        # 根据设备数量分割随机种子
        >>> prng_seed = jax.random.split(prng_seed, jax.device_count())
        # 将提示 ID 和处理后的图像分割以适应每个设备
        >>> prompt_ids = shard(prompt_ids)
        >>> processed_masked_images = shard(processed_masked_images)
        >>> processed_masks = shard(processed_masks)

        # 运行管道以生成图像
        >>> images = pipeline(
        ...     prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
        ... ).images
        # 将生成的图像数组转换为 PIL 图像格式
        >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
        ```  

FlaxStableDiffusionInpaintPipeline 类定义,继承自 FlaxDiffusionPipeline

class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r"""
Flax 基于 Stable Diffusion 的文本引导图像修补的管道。

<Tip warning={true}>

🧪 这是一个实验性功能!

</Tip>

该模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道通用方法(下载、保存、在特定设备上运行等)的实现,请查看父类文档。

参数:
    vae ([`FlaxAutoencoderKL`]):
        用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
    text_encoder ([`~transformers.FlaxCLIPTextModel`]):
        冻结的文本编码器 ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
    tokenizer ([`~transformers.CLIPTokenizer`]):
        用于标记化文本的 `CLIPTokenizer`。
    unet ([`FlaxUNet2DConditionModel`]):
        用于去噪编码图像潜在表示的 `FlaxUNet2DConditionModel`。
    scheduler ([`SchedulerMixin`]):
        与 `unet` 结合使用以去噪编码图像潜在表示的调度器。可以是以下之一
        [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
        [`FlaxDPMSolverMultistepScheduler`]。
    safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
        估计生成图像是否可能被认为是冒犯性或有害的分类模块。
        请参考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以获取有关模型潜在危害的更多详细信息。
    feature_extractor ([`~transformers.CLIPImageProcessor`]):
        从生成图像中提取特征的 `CLIPImageProcessor`;用作 `safety_checker` 的输入。
"""

# 构造函数初始化
def __init__(
    # 变分自编码器(VAE)模型实例
    vae: FlaxAutoencoderKL,
    # 文本编码器模型实例
    text_encoder: FlaxCLIPTextModel,
    # 标记器实例
    tokenizer: CLIPTokenizer,
    # 去噪模型实例
    unet: FlaxUNet2DConditionModel,
    # 调度器实例,指定可用的调度器类型
    scheduler: Union[
        FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
    ],
    # 安全检查模块实例
    safety_checker: FlaxStableDiffusionSafetyChecker,
    # 特征提取器实例
    feature_extractor: CLIPImageProcessor,
    # 数据类型,默认为 float32
    dtype: jnp.dtype = jnp.float32,
# 定义初始化方法,接收多个参数
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置数据类型属性
        self.dtype = dtype

        # 检查安全检查器是否为 None
        if safety_checker is None:
            # 记录警告信息,提醒用户禁用安全检查器的风险
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 检查 UNet 版本是否小于 0.9.0
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        # 检查 UNet 的样本大小是否小于 64
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        # 如果满足两个条件,构造弃用警告信息
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            # 调用弃用函数,传递警告信息
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            # 创建新配置字典,并更新样本大小为 64
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            # 将新配置赋值给 UNet 的内部字典
            unet._internal_dict = FrozenDict(new_config)

        # 注册多个模块以供使用
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # 计算 VAE 的缩放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # 定义准备输入的方法,接收多个参数
    def prepare_inputs(
        self,
        # 输入提示,可以是字符串或字符串列表
        prompt: Union[str, List[str]],
        # 输入图像,可以是单张图像或图像列表
        image: Union[Image.Image, List[Image.Image]],
        # 输入掩码,可以是单张掩码或掩码列表
        mask: Union[Image.Image, List[Image.Image]],
):
    # 检查 prompt 是否为字符串或列表类型,不符合则抛出异常
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # 检查 image 是否为 PIL 图像或列表类型,不符合则抛出异常
    if not isinstance(image, (Image.Image, list)):
        raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

    # 如果 image 是单个 PIL 图像,则将其转为列表
    if isinstance(image, Image.Image):
        image = [image]

    # 检查 mask 是否为 PIL 图像或列表类型,不符合则抛出异常
    if not isinstance(mask, (Image.Image, list)):
        raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

    # 如果 mask 是单个 PIL 图像,则将其转为列表
    if isinstance(mask, Image.Image):
        mask = [mask]

    # 对图像进行预处理,并合并为一个数组
    processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
    # 对掩膜进行预处理,并合并为一个数组
    processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
    # 将处理后的掩膜中小于0.5的值设为0
    processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
    # 将处理后的掩膜中大于等于0.5的值设为1
    processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)

    # 根据掩膜对图像进行遮罩处理
    processed_masked_images = processed_images * (processed_masks < 0.5)

    # 将 prompt 进行编码,并设置最大长度、填充和截断
    text_input = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    )
    # 返回编码后的输入 ID、处理后的图像和掩膜
    return text_input.input_ids, processed_masked_images, processed_masks

def _get_has_nsfw_concepts(self, features, params):
    # 使用安全检查器检查特征中是否存在 NSFW 概念
    has_nsfw_concepts = self.safety_checker(features, params)
    # 返回 NSFW 概念的检测结果
    return has_nsfw_concepts

def _run_safety_checker(self, images, safety_model_params, jit=False):
    # 将传入的图像数组转换为 PIL 图像
    pil_images = [Image.fromarray(image) for image in images]
    # 提取图像特征并返回张量形式的像素值
    features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

    # 如果开启 JIT 优化,则对特征进行分片
    if jit:
        features = shard(features)
        # 使用 NSFW 概念检测函数获取结果
        has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
        # 对结果进行反分片处理
        has_nsfw_concepts = unshard(has_nsfw_concepts)
        safety_model_params = unreplicate(safety_model_params)
    else:
        # 否则直接调用获取 NSFW 概念的函数
        has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

    images_was_copied = False
    # 遍历每个 NSFW 概念的检测结果
    for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
        if has_nsfw_concept:
            # 如果发现 NSFW 概念且尚未复制图像,则进行复制
            if not images_was_copied:
                images_was_copied = True
                images = images.copy()

            # 将对应图像替换为黑色图像
            images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        # 如果检测到任何 NSFW 概念,则发出警告
        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

    # 返回处理后的图像和 NSFW 概念的检测结果
    return images, has_nsfw_concepts
# 定义一个生成函数,处理图像生成的相关操作
    def _generate(
        # 输入的提示ID数组,通常用于模型输入
        self,
        prompt_ids: jnp.ndarray,
        # 输入的掩码数组,指示哪些部分需要处理
        mask: jnp.ndarray,
        # 被掩码的图像数组,作为生成过程的基础
        masked_image: jnp.ndarray,
        # 模型参数,可以是字典或冻结字典类型
        params: Union[Dict, FrozenDict],
        # 随机数种子,用于生成可重复的结果
        prng_seed: jax.Array,
        # 推理步骤的数量,控制生成的细致程度
        num_inference_steps: int,
        # 生成图像的高度
        height: int,
        # 生成图像的宽度
        width: int,
        # 指导比例,用于调整生成图像与提示的相关性
        guidance_scale: float,
        # 可选的潜在表示,用于进一步控制生成过程
        latents: Optional[jnp.ndarray] = None,
        # 可选的负提示ID数组,用于增强生成效果
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    # 使用装饰器替换示例文档字符串,提供函数的文档说明
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义调用函数,进行图像生成操作
    def __call__(
        # 输入的提示ID数组
        self,
        prompt_ids: jnp.ndarray,
        # 输入的掩码数组
        mask: jnp.ndarray,
        # 被掩码的图像数组
        masked_image: jnp.ndarray,
        # 模型参数
        params: Union[Dict, FrozenDict],
        # 随机数种子
        prng_seed: jax.Array,
        # 推理步骤的数量,默认为50
        num_inference_steps: int = 50,
        # 生成图像的高度,默认为None(可选)
        height: Optional[int] = None,
        # 生成图像的宽度,默认为None(可选)
        width: Optional[int] = None,
        # 指导比例,默认为7.5
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        # 可选的潜在表示,默认为None
        latents: jnp.ndarray = None,
        # 可选的负提示ID数组,默认为None
        neg_prompt_ids: jnp.ndarray = None,
        # 返回字典格式的结果,默认为True
        return_dict: bool = True,
        # 是否使用JIT编译,默认为False
        jit: bool = False,

静态参数为管道、推理步骤数、高度和宽度。更改会触发重新编译。

非静态参数为在其第一维度(因此为0)映射的(分片)输入张量。

@partial(
jax.pmap, # 使用 JAX 的并行映射功能
in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), # 指定输入张量的维度映射
static_broadcasted_argnums=(0, 6, 7, 8), # 静态广播参数的索引
)
def _p_generate(
pipe, # 管道对象
prompt_ids, # 提示 ID
mask, # 掩码
masked_image, # 被掩码的图像
params, # 参数
prng_seed, # 随机种子
num_inference_steps, # 推理步骤数
height, # 图像高度
width, # 图像宽度
guidance_scale, # 引导比例
latents, # 潜在表示
neg_prompt_ids, # 负提示 ID
):
return pipe._generate( # 调用管道的生成方法
prompt_ids, # 提示 ID
mask, # 掩码
masked_image, # 被掩码的图像
params, # 参数
prng_seed, # 随机种子
num_inference_steps, # 推理步骤数
height, # 图像高度
width, # 图像宽度
guidance_scale, # 引导比例
latents, # 潜在表示
neg_prompt_ids, # 负提示 ID
)

@partial(jax.pmap, static_broadcasted_argnums=(0,)) # 使用 JAX 的并行映射功能
def _p_get_has_nsfw_concepts(pipe, features, params): # 检查特征是否包含 NSFW 概念
return pipe._get_has_nsfw_concepts(features, params) # 调用管道的方法

def unshard(x: jnp.ndarray): # 定义 unshard 函数,接受一个 ndarray
# einops.rearrange(x, 'd b ... -> (d b) ...') # 用于调整张量的形状
num_devices, batch_size = x.shape[:2] # 获取设备数量和批次大小
rest = x.shape[2:] # 获取其余维度
return x.reshape(num_devices * batch_size, rest) # 重新调整形状为 (db, ...)

def preprocess_image(image, dtype): # 定义预处理图像的函数
w, h = image.size # 获取图像的宽度和高度
w, h = (x - x % 32 for x in (w, h)) # 调整宽度和高度为 32 的整数倍
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) # 按新大小调整图像
image = jnp.array(image).astype(dtype) / 255.0 # 转换为 ndarray 并归一化
image = image[None].transpose(0, 3, 1, 2) # 调整维度顺序
return 2.0 * image - 1.0 # 将图像值范围调整到 [-1, 1]

def preprocess_mask(mask, dtype): # 定义预处理掩码的函数
w, h = mask.size # 获取掩码的宽度和高度
w, h = (x - x % 32 for x in (w, h)) # 调整宽度和高度为 32 的整数倍
mask = mask.resize((w, h)) # 按新大小调整掩码
mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 # 转换为灰度并归一化
mask = jnp.expand_dims(mask, axis=(0, 1)) # 扩展维度以适应模型输入

return mask  # 返回处理后的掩码

# `.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py`

```py
# 版权声明,表明该代码的版权所有者及相关条款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件在 "按原样" 基础上分发,
# 不提供任何明示或暗示的担保或条件。
# 请参见许可证以获取有关特定语言治理权限和
# 限制的更多信息。

# 导入 inspect 模块以进行获取对象的文档字符串和源代码
import inspect
# 从 typing 模块导入类型提示所需的工具
from typing import Callable, List, Optional, Union

# 导入 numpy 库用于数值计算
import numpy as np
# 导入 torch 库用于深度学习模型的构建和训练
import torch
# 从 transformers 库导入 CLIP 图像处理器和 CLIP 分词器
from transformers import CLIPImageProcessor, CLIPTokenizer

# 从配置工具导入 FrozenDict 用于处理不可变字典
from ...configuration_utils import FrozenDict
# 从调度器导入不同类型的调度器
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
# 从工具模块导入去deprecated功能和日志记录
from ...utils import deprecate, logging
# 从 onnx_utils 导入 ONNX 相关的类型和模型
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
# 从 pipeline_utils 导入扩散管道
from ..pipeline_utils import DiffusionPipeline
# 导入 StableDiffusionPipelineOutput 模块
from . import StableDiffusionPipelineOutput

# 创建一个日志记录器,用于记录该模块的日志信息
logger = logging.get_logger(__name__)

# 定义 OnnxStableDiffusionPipeline 类,继承自 DiffusionPipeline
class OnnxStableDiffusionPipeline(DiffusionPipeline):
    # 声明类的各个成员变量,表示使用的模型组件
    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPImageProcessor

    # 定义可选组件的列表,包括安全检查器和特征提取器
    _optional_components = ["safety_checker", "feature_extractor"]
    # 标记该管道为 ONNX 格式
    _is_onnx = True

    # 初始化函数,设置各个组件的参数
    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,  # VAE 编码器模型
        vae_decoder: OnnxRuntimeModel,  # VAE 解码器模型
        text_encoder: OnnxRuntimeModel,  # 文本编码器模型
        tokenizer: CLIPTokenizer,        # CLIP 分词器
        unet: OnnxRuntimeModel,          # U-Net 模型
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],  # 调度器
        safety_checker: OnnxRuntimeModel,  # 安全检查器模型
        feature_extractor: CLIPImageProcessor,  # 特征提取器
        requires_safety_checker: bool = True,  # 是否需要安全检查器
    ):
    # 定义用于编码提示的私有方法
    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],  # 输入的提示文本,可以是字符串或字符串列表
        num_images_per_prompt: Optional[int],  # 每个提示生成的图像数量
        do_classifier_free_guidance: bool,  # 是否使用无分类器引导
        negative_prompt: Optional[str],  # 负面提示文本
        prompt_embeds: Optional[np.ndarray] = None,  # 提示的嵌入表示
        negative_prompt_embeds: Optional[np.ndarray] = None,  # 负面提示的嵌入表示
    ):
    # 定义检查输入有效性的私有方法
    def check_inputs(
        self,
        prompt: Union[str, List[str]],  # 输入的提示文本
        height: Optional[int],  # 图像高度
        width: Optional[int],  # 图像宽度
        callback_steps: int,  # 回调步骤数量
        negative_prompt: Optional[str] = None,  # 负面提示文本
        prompt_embeds: Optional[np.ndarray] = None,  # 提示的嵌入表示
        negative_prompt_embeds: Optional[np.ndarray] = None,  # 负面提示的嵌入表示
    # 进行一系列参数检查,以确保输入值的有效性
        ):
            # 检查高度和宽度是否都能被8整除
            if height % 8 != 0 or width % 8 != 0:
                # 如果不能整除,抛出值错误异常,提示当前高度和宽度
                raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    
            # 检查回调步骤是否为正整数
            if (callback_steps is None) or (
                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
            ):
                # 如果条件不满足,抛出值错误异常,提示当前回调步骤的类型和值
                raise ValueError(
                    f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                    f" {type(callback_steps)}."
                )
    
            # 检查同时传入 prompt 和 prompt_embeds
            if prompt is not None and prompt_embeds is not None:
                # 如果同时传入,抛出值错误异常,提示只能传入其中一个
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查 prompt 和 prompt_embeds 是否均未提供
            elif prompt is None and prompt_embeds is None:
                # 抛出值错误异常,提示必须提供至少一个
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查 prompt 的类型
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 如果类型不匹配,抛出值错误异常,提示类型不符合
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查同时传入 negative_prompt 和 negative_prompt_embeds
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 如果同时传入,抛出值错误异常,提示只能传入其中一个
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查 prompt_embeds 和 negative_prompt_embeds 是否都提供
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                # 确保它们的形状相同
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 如果形状不匹配,抛出值错误异常,提示它们的形状
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
        # 定义调用方法,接受多个参数
        def __call__(
            # 提供的提示,类型为字符串或字符串列表
            prompt: Union[str, List[str]] = None,
            # 图像高度,默认为512
            height: Optional[int] = 512,
            # 图像宽度,默认为512
            width: Optional[int] = 512,
            # 推理步骤的数量,默认为50
            num_inference_steps: Optional[int] = 50,
            # 指导尺度,默认为7.5
            guidance_scale: Optional[float] = 7.5,
            # 负提示,类型为字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为1
            num_images_per_prompt: Optional[int] = 1,
            # 额外的随机因素,默认为0.0
            eta: Optional[float] = 0.0,
            # 随机生成器,默认为None
            generator: Optional[np.random.RandomState] = None,
            # 潜在变量,默认为None
            latents: Optional[np.ndarray] = None,
            # 提示的嵌入表示,默认为None
            prompt_embeds: Optional[np.ndarray] = None,
            # 负提示的嵌入表示,默认为None
            negative_prompt_embeds: Optional[np.ndarray] = None,
            # 输出类型,默认为"pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为True
            return_dict: bool = True,
            # 回调函数,默认为None
            callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
            # 回调步骤,默认为1
            callback_steps: int = 1,
# 定义一个名为 StableDiffusionOnnxPipeline 的类,继承自 OnnxStableDiffusionPipeline
class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
    # 初始化方法,接受多个模型和处理器作为参数
    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,  # VAE 编码器模型
        vae_decoder: OnnxRuntimeModel,  # VAE 解码器模型
        text_encoder: OnnxRuntimeModel,  # 文本编码器模型
        tokenizer: CLIPTokenizer,        # 分词器
        unet: OnnxRuntimeModel,          # U-Net 模型
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],  # 调度器,可以是多种类型
        safety_checker: OnnxRuntimeModel,  # 安全检查模型
        feature_extractor: CLIPImageProcessor,  # 特征提取器
    ):
        # 定义弃用消息,提醒用户使用替代类
        deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
        # 调用弃用函数,记录弃用警告
        deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
        # 调用父类的初始化方法,传入所有参数
        super().__init__(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_img2img.py

# 版权信息,表示该代码的所有权归 HuggingFace 团队所有
# 许可信息,表明该文件在 Apache 2.0 许可下分发
# 除非遵守该许可,否则不得使用此文件
# 提供许可的获取地址
# 除非适用法律或书面同意,否则按照 "按现状" 基础分发软件,没有任何明示或暗示的担保
# 详细信息见许可中关于权限和限制的部分

import inspect  # 导入 inspect 模块,用于获取对象的信息
from typing import Callable, List, Optional, Union  # 导入类型提示,定义函数参数和返回值类型

import numpy as np  # 导入 numpy,用于数组和矩阵操作
import PIL.Image  # 导入 PIL.Image,用于图像处理
import torch  # 导入 PyTorch,支持深度学习计算
from transformers import CLIPImageProcessor, CLIPTokenizer  # 从 transformers 库导入 CLIP 图像处理器和分词器

from ...configuration_utils import FrozenDict  # 从配置工具导入 FrozenDict,用于不可变字典
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler  # 导入调度器类,用于模型训练调度
from ...utils import PIL_INTERPOLATION, deprecate, logging  # 导入工具类,处理 PIL 插值、弃用警告和日志记录
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel  # 从 ONNX 工具导入转换类型和模型类
from ..pipeline_utils import DiffusionPipeline  # 导入 DiffusionPipeline,基础管道类
from . import StableDiffusionPipelineOutput  # 导入稳定扩散管道的输出类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 中复制的 preprocess 函数,调整尺寸从 8 变为 64
def preprocess(image):
    # 弃用消息,通知用户该方法将在未来版本中删除
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    # 触发弃用警告,提醒用户使用替代方法
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    # 如果输入是 PyTorch 张量,则直接返回
    if isinstance(image, torch.Tensor):
        return image
    # 如果输入是 PIL 图像,将其封装到列表中
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    # 如果输入的第一个元素是 PIL 图像
    if isinstance(image[0], PIL.Image.Image):
        # 获取图像的宽度和高度
        w, h = image[0].size
        # 将宽和高调整为64的整数倍
        w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64

        # 调整图像大小并转换为 NumPy 数组
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        # 沿着第0维连接图像数组
        image = np.concatenate(image, axis=0)
        # 将图像数据类型转换为浮点型并归一化到 [0, 1]
        image = np.array(image).astype(np.float32) / 255.0
        # 调整数组的维度顺序
        image = image.transpose(0, 3, 1, 2)
        # 将图像数据缩放到 [-1, 1]
        image = 2.0 * image - 1.0
        # 将 NumPy 数组转换为 PyTorch 张量
        image = torch.from_numpy(image)
    # 如果输入的第一个元素是 PyTorch 张量
    elif isinstance(image[0], torch.Tensor):
        # 沿着第0维连接多个张量
        image = torch.cat(image, dim=0)
    # 返回处理后的图像
    return image

# 定义一个用于文本引导的图像到图像生成的管道类,使用稳定扩散模型
class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
    r"""
    用于文本引导的图像到图像生成的管道,基于稳定扩散模型。

    该模型继承自 [`DiffusionPipeline`]。查看超类文档,了解库为所有管道实现的通用方法
    (例如下载或保存、在特定设备上运行等)。
    # 参数说明
    Args:
        vae ([`AutoencoderKL`]):  # 变分自编码器模型,用于将图像编码和解码为潜在表示。
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):  # 冻结的文本编码器,Stable Diffusion 使用 CLIP 的文本部分。
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),具体是
            [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 变体。
        tokenizer (`CLIPTokenizer`):  # CLIPTokenizer 类的分词器。
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]):  # 条件 U-Net 结构,用于去噪编码的图像潜在表示。
            Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):  # 与 `unet` 结合使用的调度器,用于去噪编码的图像潜在表示。
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):  # 分类模块,估计生成的图像是否可能被视为冒犯或有害。
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):  # 从生成的图像中提取特征,以作为 `safety_checker` 的输入。
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # 定义变量类型
    vae_encoder: OnnxRuntimeModel  # VAE 编码器的类型
    vae_decoder: OnnxRuntimeModel  # VAE 解码器的类型
    text_encoder: OnnxRuntimeModel  # 文本编码器的类型
    tokenizer: CLIPTokenizer  # 分词器的类型
    unet: OnnxRuntimeModel  # U-Net 模型的类型
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]  # 调度器的类型
    safety_checker: OnnxRuntimeModel  # 安全检查器的类型
    feature_extractor: CLIPImageProcessor  # 特征提取器的类型

    # 可选组件列表
    _optional_components = ["safety_checker", "feature_extractor"]  # 包含可选组件的列表
    _is_onnx = True  # 指示是否使用 ONNX 模型

    # 构造函数初始化各个组件
    def __init__(  # 初始化方法
        self,
        vae_encoder: OnnxRuntimeModel,  # 传入的 VAE 编码器
        vae_decoder: OnnxRuntimeModel,  # 传入的 VAE 解码器
        text_encoder: OnnxRuntimeModel,  # 传入的文本编码器
        tokenizer: CLIPTokenizer,  # 传入的分词器
        unet: OnnxRuntimeModel,  # 传入的 U-Net 模型
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],  # 传入的调度器
        safety_checker: OnnxRuntimeModel,  # 传入的安全检查器
        feature_extractor: CLIPImageProcessor,  # 传入的特征提取器
        requires_safety_checker: bool = True,  # 是否需要安全检查器的标志
    # 从 diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt 复制
    def _encode_prompt(  # 编码提示的方法
        self,
        prompt: Union[str, List[str]],  # 提示文本,可以是字符串或字符串列表
        num_images_per_prompt: Optional[int],  # 每个提示生成的图像数量
        do_classifier_free_guidance: bool,  # 是否进行无分类器引导
        negative_prompt: Optional[str],  # 可选的负面提示文本
        prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入
        negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入
    # 定义一个检查输入参数的函数
        def check_inputs(
            self,  # 类实例自身
            prompt: Union[str, List[str]],  # 提示信息,可以是字符串或字符串列表
            callback_steps: int,  # 回调步骤的整数值
            negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负面提示,字符串或列表
            prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入,NumPy 数组
            negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入,NumPy 数组
        ):
            # 检查回调步骤是否为 None 或非正整数
            if (callback_steps is None) or (
                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
            ):
                # 如果回调步骤不符合要求,则抛出值错误
                raise ValueError(
                    f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                    f" {type(callback_steps)}."
                )
    
            # 检查是否同时提供了提示和提示嵌入
            if prompt is not None and prompt_embeds is not None:
                # 如果同时提供了,则抛出值错误
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查提示和提示嵌入是否都未提供
            elif prompt is None and prompt_embeds is None:
                # 如果都未提供,则抛出值错误
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查提示类型是否为字符串或列表
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 如果类型不符合,则抛出值错误
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查是否同时提供了负面提示和负面提示嵌入
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 如果同时提供了,则抛出值错误
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查提示嵌入和负面提示嵌入的形状是否一致
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 如果形状不一致,则抛出值错误
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
        # 定义可调用的方法,处理提示和生成图像
        def __call__(
            self,  # 类实例自身
            prompt: Union[str, List[str]],  # 提示信息,可以是字符串或字符串列表
            image: Union[np.ndarray, PIL.Image.Image] = None,  # 可选的图像输入,可以是 NumPy 数组或 PIL 图像
            strength: float = 0.8,  # 图像强度的浮点值,默认为 0.8
            num_inference_steps: Optional[int] = 50,  # 可选的推理步骤数,默认为 50
            guidance_scale: Optional[float] = 7.5,  # 可选的引导尺度,默认为 7.5
            negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负面提示,字符串或列表
            num_images_per_prompt: Optional[int] = 1,  # 每个提示生成的图像数量,默认为 1
            eta: Optional[float] = 0.0,  # 可选的 eta 值,默认为 0.0
            generator: Optional[np.random.RandomState] = None,  # 可选的随机数生成器
            prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入,NumPy 数组
            negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入,NumPy 数组
            output_type: Optional[str] = "pil",  # 输出类型,默认为 'pil'
            return_dict: bool = True,  # 是否返回字典格式,默认为 True
            callback: Optional[Callable[[int, int, np.ndarray], None]] = None,  # 可选的回调函数
            callback_steps: int = 1,  # 回调步骤的整数值,默认为 1

.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_inpaint.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件按“现状”基础进行分发,
# 不附有任何明示或暗示的担保或条件。
# 请参见许可证以获取管理权限的特定语言和
# 限制条款。

import inspect  # 导入 inspect 模块,用于检查对象的属性和方法
from typing import Callable, List, Optional, Union  # 导入类型注释,便于类型检查

import numpy as np  # 导入 NumPy 库,用于数组和矩阵操作
import PIL.Image  # 导入 PIL 图像处理库
import torch  # 导入 PyTorch 库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTokenizer  # 导入 Transformers 库中的图像处理和标记器

from ...configuration_utils import FrozenDict  # 导入 FrozenDict,用于不可变字典
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler  # 导入调度器类
from ...utils import PIL_INTERPOLATION, deprecate, logging  # 导入工具函数和日志模块
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel  # 导入 ONNX 相关工具
from ..pipeline_utils import DiffusionPipeline  # 导入扩散管道基类
from . import StableDiffusionPipelineOutput  # 导入稳定扩散管道输出类


logger = logging.get_logger(__name__)  # 创建一个记录器,使用当前模块名称进行日志记录

NUM_UNET_INPUT_CHANNELS = 9  # 定义 UNet 输入通道的数量
NUM_LATENT_CHANNELS = 4  # 定义潜在通道的数量


def prepare_mask_and_masked_image(image, mask, latents_shape):  # 定义准备掩模和掩模图像的函数
    # 将输入图像转换为 RGB 格式,并调整大小以适应潜在形状
    image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
    # 调整数组形状以适配深度学习模型的输入要求
    image = image[None].transpose(0, 3, 1, 2)
    # 将图像数据类型转换为 float32,并归一化到 [-1, 1] 范围
    image = image.astype(np.float32) / 127.5 - 1.0

    # 将掩模图像转换为灰度并调整大小
    image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
    # 应用掩模到图像,得到掩模图像
    masked_image = image * (image_mask < 127.5)

    # 调整掩模大小以匹配潜在形状,并转换为灰度格式
    mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
    mask = np.array(mask.convert("L"))
    # 将掩模数据类型转换为 float32,并归一化到 [0, 1] 范围
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]  # 添加维度以匹配模型输入要求
    # 将小于 0.5 的值设为 0
    mask[mask < 0.5] = 0
    # 将大于等于 0.5 的值设为 1
    mask[mask >= 0.5] = 1

    return mask, masked_image  # 返回处理后的掩模和掩模图像


class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):  # 定义用于图像修补的扩散管道类
    r"""
    使用稳定扩散进行文本引导的图像修补管道。*这是一个实验特性*。

    此模型继承自 [`DiffusionPipeline`]。请查看超类文档,以获取库为所有管道实现的通用方法
    (例如下载或保存,在特定设备上运行等)。
    # 文档字符串,定义类的参数和它们的类型
    Args:
        vae ([`AutoencoderKL`]):  # 变分自编码器模型,用于编码和解码图像及其潜在表示
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):  # 冻结的文本编码器,用于处理文本输入
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):  # 处理文本的标记器
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]):  # 条件 U-Net 架构,用于去噪编码的图像潜在表示
            Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):  # 调度器,与 `unet` 一起用于去噪图像潜在表示
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):  # 分类模块,评估生成图像是否可能被视为冒犯或有害
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):  # 从生成图像中提取特征的模型,用于 `safety_checker` 的输入
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # 定义多个模型的类型,用于保存各个组件的实例
    vae_encoder: OnnxRuntimeModel  # 编码器模型的类型
    vae_decoder: OnnxRuntimeModel  # 解码器模型的类型
    text_encoder: OnnxRuntimeModel  # 文本编码器模型的类型
    tokenizer: CLIPTokenizer  # 文本标记器的类型
    unet: OnnxRuntimeModel  # U-Net 模型的类型
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]  # 调度器的类型,可以是多种类型之一
    safety_checker: OnnxRuntimeModel  # 安全检查器模型的类型
    feature_extractor: CLIPImageProcessor  # 特征提取器模型的类型

    _optional_components = ["safety_checker", "feature_extractor"]  # 可选组件的列表
    _is_onnx = True  # 指示当前模型是否为 ONNX 格式

    # 构造函数,用于初始化类的实例
    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,  # 传入编码器模型实例
        vae_decoder: OnnxRuntimeModel,  # 传入解码器模型实例
        text_encoder: OnnxRuntimeModel,  # 传入文本编码器模型实例
        tokenizer: CLIPTokenizer,  # 传入文本标记器实例
        unet: OnnxRuntimeModel,  # 传入 U-Net 模型实例
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],  # 传入调度器实例
        safety_checker: OnnxRuntimeModel,  # 传入安全检查器模型实例
        feature_extractor: CLIPImageProcessor,  # 传入特征提取器模型实例
        requires_safety_checker: bool = True,  # 指示是否需要安全检查器的布尔参数
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
    # 编码提示的函数
    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],  # 输入的提示,可以是字符串或字符串列表
        num_images_per_prompt: Optional[int],  # 每个提示生成的图像数量,默认为可选
        do_classifier_free_guidance: bool,  # 是否执行无分类器自由引导的布尔参数
        negative_prompt: Optional[str],  # 可选的负面提示
        prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入,默认为 None
        negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入,默认为 None
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
    # 定义一个检查输入参数的函数,确保所有输入都符合预期
        def check_inputs(
            self,  # 指向类实例的引用
            prompt: Union[str, List[str]],  # 输入的提示,类型为字符串或字符串列表
            height: Optional[int],  # 图像高度,类型为可选整数
            width: Optional[int],  # 图像宽度,类型为可选整数
            callback_steps: int,  # 回调步骤数,类型为整数
            negative_prompt: Optional[str] = None,  # 负提示,类型为可选字符串
            prompt_embeds: Optional[np.ndarray] = None,  # 提示的嵌入表示,类型为可选numpy数组
            negative_prompt_embeds: Optional[np.ndarray] = None,  # 负提示的嵌入表示,类型为可选numpy数组
        ):
            # 检查高度和宽度是否都能被8整除
            if height % 8 != 0 or width % 8 != 0:
                # 如果不能整除,抛出值错误
                raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    
            # 检查回调步骤是否为正整数
            if (callback_steps is None) or (
                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
            ):
                # 如果不是正整数,抛出值错误
                raise ValueError(
                    f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                    f" {type(callback_steps)}."
                )
    
            # 检查提示和提示嵌入是否同时存在
            if prompt is not None and prompt_embeds is not None:
                # 如果两者都存在,抛出值错误
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查提示和提示嵌入是否同时为None
            elif prompt is None and prompt_embeds is None:
                # 如果都是None,抛出值错误
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查提示的类型是否为字符串或列表
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 如果不是,抛出值错误
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查负提示和负提示嵌入是否同时存在
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 如果同时存在,抛出值错误
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查提示嵌入和负提示嵌入的形状是否一致
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 如果形状不一致,抛出值错误
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
        # 在不计算梯度的情况下进行操作,节省内存和计算资源
        @torch.no_grad()
    # 定义一个可调用的类方法,用于处理图像生成
        def __call__(
            self,
            # 用户输入的提示,可以是字符串或字符串列表
            prompt: Union[str, List[str]],
            # 输入的图像,类型为 PIL.Image.Image
            image: PIL.Image.Image,
            # 掩模图像,类型为 PIL.Image.Image
            mask_image: PIL.Image.Image,
            # 输出图像的高度,默认为 512
            height: Optional[int] = 512,
            # 输出图像的宽度,默认为 512
            width: Optional[int] = 512,
            # 推理步骤的数量,默认为 50
            num_inference_steps: int = 50,
            # 指导尺度,默认为 7.5
            guidance_scale: float = 7.5,
            # 负提示,可以是字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: Optional[int] = 1,
            # 噪声的 eta 值,默认为 0.0
            eta: float = 0.0,
            # 随机数生成器,默认为 None
            generator: Optional[np.random.RandomState] = None,
            # 潜在表示,默认为 None
            latents: Optional[np.ndarray] = None,
            # 提示嵌入,默认为 None
            prompt_embeds: Optional[np.ndarray] = None,
            # 负提示嵌入,默认为 None
            negative_prompt_embeds: Optional[np.ndarray] = None,
            # 输出类型,默认为 "pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为 True
            return_dict: bool = True,
            # 回调函数,默认为 None
            callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
            # 回调步骤的间隔,默认为 1
            callback_steps: int = 1,

.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_upscale.py

# 版权声明,标明版权和许可信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证版本 2.0("许可证")授权; 
# 除非遵循该许可证,否则不可使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非根据适用法律或书面协议另有约定, 
# 否则根据许可证分发的软件是基于“原样”提供的,
# 不附带任何明示或暗示的担保或条件。
# 有关许可证具体条款的信息,见许可证。
import inspect  # 导入inspect模块,用于获取对象的信息
from typing import Any, Callable, List, Optional, Union  # 导入类型注解

import numpy as np  # 导入numpy库,用于数值计算
import PIL.Image  # 导入PIL库,用于图像处理
import torch  # 导入PyTorch库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTokenizer  # 从transformers库导入图像处理和标记器

from ...configuration_utils import FrozenDict  # 导入FrozenDict,用于配置管理
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers  # 导入调度器,用于模型训练
from ...utils import deprecate, logging  # 导入工具模块,用于日志记录和弃用警告
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel  # 导入ONNX相关工具和模型类
from ..pipeline_utils import DiffusionPipeline  # 导入扩散管道类
from . import StableDiffusionPipelineOutput  # 导入稳定扩散管道输出类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def preprocess(image):  # 定义预处理函数,接收图像作为输入
    if isinstance(image, torch.Tensor):  # 检查图像是否为PyTorch张量
        return image  # 如果是,则直接返回
    elif isinstance(image, PIL.Image.Image):  # 检查图像是否为PIL图像
        image = [image]  # 将其封装为列表

    if isinstance(image[0], PIL.Image.Image):  # 检查列表中的第一个元素是否为PIL图像
        w, h = image[0].size  # 获取图像的宽度和高度
        w, h = (x - x % 64 for x in (w, h))  # 调整宽高,使其为64的整数倍

        image = [np.array(i.resize((w, h)))[None, :] for i in image]  # 调整所有图像大小并转为数组
        image = np.concatenate(image, axis=0)  # 将所有图像数组沿第0轴合并
        image = np.array(image).astype(np.float32) / 255.0  # 转换为浮点数并归一化到[0, 1]
        image = image.transpose(0, 3, 1, 2)  # 变换数组维度为[batch, channels, height, width]
        image = 2.0 * image - 1.0  # 将值归一化到[-1, 1]
        image = torch.from_numpy(image)  # 转换为PyTorch张量
    elif isinstance(image[0], torch.Tensor):  # 如果列表中的第一个元素是PyTorch张量
        image = torch.cat(image, dim=0)  # 在第0维连接所有张量

    return image  # 返回处理后的图像


class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):  # 定义ONNX稳定扩散上采样管道类,继承自DiffusionPipeline
    vae: OnnxRuntimeModel  # 定义变分自编码器模型
    text_encoder: OnnxRuntimeModel  # 定义文本编码器模型
    tokenizer: CLIPTokenizer  # 定义CLIP标记器
    unet: OnnxRuntimeModel  # 定义U-Net模型
    low_res_scheduler: DDPMScheduler  # 定义低分辨率调度器
    scheduler: KarrasDiffusionSchedulers  # 定义Karras扩散调度器
    safety_checker: OnnxRuntimeModel  # 定义安全检查模型
    feature_extractor: CLIPImageProcessor  # 定义特征提取器

    _optional_components = ["safety_checker", "feature_extractor"]  # 可选组件列表
    _is_onnx = True  # 指示该类是否为ONNX格式

    def __init__(  # 定义构造函数
        self,
        vae: OnnxRuntimeModel,  # 变分自编码器模型
        text_encoder: OnnxRuntimeModel,  # 文本编码器模型
        tokenizer: Any,  # 任意类型的标记器
        unet: OnnxRuntimeModel,  # U-Net模型
        low_res_scheduler: DDPMScheduler,  # 低分辨率调度器
        scheduler: KarrasDiffusionSchedulers,  # Karras调度器
        safety_checker: Optional[OnnxRuntimeModel] = None,  # 可选的安全检查模型
        feature_extractor: Optional[CLIPImageProcessor] = None,  # 可选的特征提取器
        max_noise_level: int = 350,  # 最大噪声级别
        num_latent_channels=4,  # 潜在通道数量
        num_unet_input_channels=7,  # U-Net输入通道数量
        requires_safety_checker: bool = True,  # 是否需要安全检查器
    # 定义一个检查输入参数的函数,确保输入有效
        def check_inputs(
            self,  # 表示该方法属于某个类
            prompt: Union[str, List[str]],  # 输入的提示,支持字符串或字符串列表
            image,  # 输入的图像,类型不固定
            noise_level,  # 噪声级别,通常用于控制生成图像的噪声程度
            callback_steps,  # 回调步数,用于更新或监控生成过程
            negative_prompt=None,  # 可选的负面提示,控制生成内容的方向
            prompt_embeds=None,  # 可选的提示嵌入,直接传入嵌入向量
            negative_prompt_embeds=None,  # 可选的负面提示嵌入,直接传入嵌入向量
        # 定义一个准备潜在变量的函数,用于生成图像的潜在表示
        def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
            # 定义潜在变量的形状,根据批大小、通道数、高度和宽度
            shape = (batch_size, num_channels_latents, height, width)
            # 如果没有提供潜在变量,则生成新的随机潜在变量
            if latents is None:
                latents = generator.randn(*shape).astype(dtype)  # 从生成器中生成随机潜在变量并转换为指定数据类型
            # 如果提供的潜在变量形状不符合预期,则引发错误
            elif latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
    
            return latents  # 返回准备好的潜在变量
    
        # 定义一个解码潜在变量的函数,将潜在表示转换为图像
        def decode_latents(self, latents):
            # 调整潜在变量的尺度,以匹配解码器的输入要求
            latents = 1 / 0.08333 * latents
            # 使用变分自编码器(VAE)解码潜在变量,获取生成的图像
            image = self.vae(latent_sample=latents)[0]
            # 将图像值缩放到 [0, 1] 范围内,并进行裁剪
            image = np.clip(image / 2 + 0.5, 0, 1)
            # 调整图像的维度顺序,从 (N, C, H, W) 转换为 (N, H, W, C)
            image = image.transpose((0, 2, 3, 1))
            return image  # 返回解码后的图像
    
        # 定义一个编码提示的函数,将文本提示转换为嵌入向量
        def _encode_prompt(
            self,
            prompt: Union[str, List[str]],  # 输入的提示,支持字符串或字符串列表
            num_images_per_prompt: Optional[int],  # 每个提示生成的图像数量
            do_classifier_free_guidance: bool,  # 是否进行无分类器引导
            negative_prompt: Optional[str],  # 可选的负面提示
            prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入
        # 定义一个调用函数,用于生成图像
        def __call__(
            self,
            prompt: Union[str, List[str]],  # 输入的提示,支持字符串或字符串列表
            image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],  # 输入的图像,可以是 ndarray 或 PIL 图像
            num_inference_steps: int = 75,  # 推理步骤的数量,默认设置为 75
            guidance_scale: float = 9.0,  # 引导的缩放因子,控制生成图像的质量
            noise_level: int = 20,  # 噪声级别,影响生成图像的随机性
            negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负面提示
            num_images_per_prompt: Optional[int] = 1,  # 每个提示生成的图像数量,默认设置为 1
            eta: float = 0.0,  # 控制随机性和确定性的超参数
            generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None,  # 随机数生成器
            latents: Optional[np.ndarray] = None,  # 可选的潜在变量
            prompt_embeds: Optional[np.ndarray] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[np.ndarray] = None,  # 可选的负面提示嵌入
            output_type: Optional[str] = "pil",  # 输出类型,默认设置为 PIL 图像
            return_dict: bool = True,  # 是否以字典形式返回结果
            callback: Optional[Callable[[int, int, np.ndarray], None]] = None,  # 可选的回调函数,用于处理生成过程中的状态
            callback_steps: Optional[int] = 1,  # 回调的步数,控制回调的频率

.\diffusers\pipelines\stable_diffusion\pipeline_output.py

# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入 List, Optional, Union 类型注解
from typing import List, Optional, Union

# 导入 numpy 库并重命名为 np
import numpy as np
# 导入 PIL.Image 模块
import PIL.Image

# 从上层模块导入 BaseOutput 和 is_flax_available 函数
from ...utils import BaseOutput, is_flax_available


# 定义一个数据类,作为 Stable Diffusion 管道的输出
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
    """
    Stable Diffusion 管道的输出类。

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            包含去噪后的 PIL 图像列表,长度为 `batch_size`,或形状为 `(batch_size, height, width,
            num_channels)` 的 NumPy 数组。
        nsfw_content_detected (`List[bool]`)
            指示对应生成图像是否包含“不可安全观看” (nsfw) 内容的列表,若无法进行安全检查则为 `None`。
    """

    # 存储图像,类型为 PIL 图像列表或 NumPy 数组
    images: Union[List[PIL.Image.Image], np.ndarray]
    # 存储 nsfw 内容检测结果的可选列表
    nsfw_content_detected: Optional[List[bool]]


# 检查是否可用 Flax 库
if is_flax_available():
    # 导入 flax 库
    import flax

    # 定义一个数据类,作为 Flax 基于 Stable Diffusion 管道的输出
    @flax.struct.dataclass
    class FlaxStableDiffusionPipelineOutput(BaseOutput):
        """
        Flax 基于 Stable Diffusion 管道的输出类。

        Args:
            images (`np.ndarray`):
                形状为 `(batch_size, height, width, num_channels)` 的去噪图像数组。
            nsfw_content_detected (`List[bool]`):
                指示对应生成图像是否包含“不可安全观看” (nsfw) 内容的列表,
                或 `None` 如果无法进行安全检查。
        """

        # 存储图像,类型为 NumPy 数组
        images: np.ndarray
        # 存储 nsfw 内容检测结果的列表
        nsfw_content_detected: List[bool]

.\diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion.py

# 版权声明,表明此文件的版权归 HuggingFace 团队所有
# 
# 根据 Apache License 2.0 许可协议进行授权;
# 除非遵循此许可协议,否则不得使用此文件。
# 可以通过以下网址获取许可的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,根据该许可协议分发的软件在“按原样”基础上分发,
# 不提供任何明示或暗示的担保或条件。
# 详细信息请参见许可协议中关于权限和限制的具体条款。
import inspect  # 导入用于检查对象的模块
from typing import Any, Callable, Dict, List, Optional, Union  # 导入类型注解工具

import torch  # 导入 PyTorch 库
from packaging import version  # 导入版本管理工具
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection  # 导入特定的转换器模型

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 导入多管道回调相关类
from ...configuration_utils import FrozenDict  # 导入不可变字典工具
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理相关类
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  # 导入加载器混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel  # 导入不同模型
from ...models.lora import adjust_lora_scale_text_encoder  # 导入用于调整文本编码器的 LoRA 函数
from ...schedulers import KarrasDiffusionSchedulers  # 导入 Karras 扩散调度器
from ...utils import (  # 导入多个实用工具
    USE_PEFT_BACKEND,  # 指定使用 PEFT 后端的标志
    deprecate,  # 导入弃用装饰器
    logging,  # 导入日志工具
    replace_example_docstring,  # 导入替换示例文档字符串的工具
    scale_lora_layers,  # 导入缩放 LoRA 层的工具
    unscale_lora_layers,  # 导入反缩放 LoRA 层的工具
)
from ...utils.torch_utils import randn_tensor  # 导入生成随机张量的工具
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道相关类
from .pipeline_output import StableDiffusionPipelineOutput  # 导入稳定扩散管道输出类
from .safety_checker import StableDiffusionSafetyChecker  # 导入稳定扩散安全检查器类


logger = logging.get_logger(__name__)  # 创建当前模块的日志记录器,禁止 pylint 检查命名

EXAMPLE_DOC_STRING = """  # 示例文档字符串,展示使用示例
    Examples:
        ```py
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import StableDiffusionPipeline  # 从 diffusers 导入稳定扩散管道

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)  # 从预训练模型加载管道并设置数据类型
        >>> pipe = pipe.to("cuda")  # 将管道转移到 GPU

        >>> prompt = "a photo of an astronaut riding a horse on mars"  # 定义文本提示
        >>> image = pipe(prompt).images[0]  # 生成图像并提取第一张图像
        ```py
"""


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):  # 定义噪声配置重标定函数
    """
    根据 `guidance_rescale` 对 `noise_cfg` 进行重标定。基于论文[Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)的发现。参见第 3.4 节
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)  # 计算文本噪声的标准差,保持维度
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)  # 计算噪声配置的标准差,保持维度
    # 使用文本标准差调整噪声配置,以修正曝光过度的问题
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)  # 进行重标定
    # 按照指导比例将重标定的噪声与原始噪声混合,避免生成“平面”图像
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg  # 更新噪声配置
    # 返回噪声配置对象
        return noise_cfg
# 定义一个函数用于检索时间步,接受多个参数
def retrieve_timesteps(
    # 调度器实例,用于获取时间步
    scheduler,
    # 可选的推理步骤数量,默认为 None
    num_inference_steps: Optional[int] = None,
    # 可选的设备参数,可以是字符串或 torch.device,默认为 None
    device: Optional[Union[str, torch.device]] = None,
    # 可选的自定义时间步,默认为 None
    timesteps: Optional[List[int]] = None,
    # 可选的自定义 sigma 值,默认为 None
    sigmas: Optional[List[float]] = None,
    # 其他可选参数,将传递给调度器的 set_timesteps 方法
    **kwargs,
):
    """
    调用调度器的 `set_timesteps` 方法,并在调用后从调度器获取时间步。处理自定义时间步。
    所有 kwargs 将传递给 `scheduler.set_timesteps`。

    参数:
        scheduler (`SchedulerMixin`):
            用于获取时间步的调度器。
        num_inference_steps (`int`):
            生成样本时使用的扩散步骤数量。如果使用此参数,则 `timesteps` 必须为 `None`。
        device (`str` 或 `torch.device`, *可选*):
            时间步应移动到的设备。如果为 `None`,则不移动时间步。
        timesteps (`List[int]`, *可选*):
            用于覆盖调度器时间步间距策略的自定义时间步。如果传入 `timesteps`,
            `num_inference_steps` 和 `sigmas` 必须为 `None`。
        sigmas (`List[float]`, *可选*):
            用于覆盖调度器时间步间距策略的自定义 sigma 值。如果传入 `sigmas`,
            `num_inference_steps` 和 `timesteps` 必须为 `None`。

    返回:
        `Tuple[torch.Tensor, int]`: 一个元组,第一个元素是调度器的时间步调度,
        第二个元素是推理步骤的数量。
    """
    # 检查是否同时传入了自定义时间步和 sigma 值,如果是则抛出异常
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    # 如果传入了自定义时间步
    if timesteps is not None:
        # 检查调度器的 set_timesteps 方法是否接受时间步参数
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,抛出异常
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 set_timesteps 方法设置自定义时间步
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        # 获取调度器中的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤数量
        num_inference_steps = len(timesteps)
    # 如果传入了自定义 sigma 值
    elif sigmas is not None:
        # 检查调度器的 set_timesteps 方法是否接受 sigma 参数
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,抛出异常
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 set_timesteps 方法设置自定义 sigma 值
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        # 获取调度器中的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤数量
        num_inference_steps = len(timesteps)
    # 如果没有传入自定义时间步或 sigma
    else:
        # 调用调度器的 set_timesteps 方法,使用推理步骤数量
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        # 获取调度器中的时间步
        timesteps = scheduler.timesteps
    # 返回时间步长和推理步骤的数量
        return timesteps, num_inference_steps
# 定义一个名为 StableDiffusionPipeline 的类,继承多个混入类
class StableDiffusionPipeline(
    # 继承 DiffusionPipeline 基础功能
    DiffusionPipeline,
    # 继承稳定扩散特有的功能
    StableDiffusionMixin,
    # 继承文本反演加载功能
    TextualInversionLoaderMixin,
    # 继承 LoRA 加载功能
    StableDiffusionLoraLoaderMixin,
    # 继承 IP 适配器功能
    IPAdapterMixin,
    # 继承从单一文件加载功能
    FromSingleFileMixin,
):
    # 文档字符串,描述该类用于文本到图像生成
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    # 说明此模型继承自 DiffusionPipeline,并指出可以查看超类文档获取通用方法
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    # 说明该管道也继承了多种加载方法
    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    # 参数说明,定义构造函数需要的各类参数及其类型
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    # 定义一个字符串,表示模型在 CPU 上的卸载顺序
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # 定义可选组件的列表
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    # 定义不包含在 CPU 卸载中的组件
    _exclude_from_cpu_offload = ["safety_checker"]
    # 定义回调张量输入的列表
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    # 初始化方法,构造类的实例并接受多个参数
        def __init__(
            self,
            vae: AutoencoderKL,  # 变分自编码器,用于图像生成
            text_encoder: CLIPTextModel,  # 文本编码器,用于将文本转换为嵌入
            tokenizer: CLIPTokenizer,  # 分词器,用于处理文本数据
            unet: UNet2DConditionModel,  # UNet模型,用于条件生成
            scheduler: KarrasDiffusionSchedulers,  # 调度器,控制生成过程中的步伐
            safety_checker: StableDiffusionSafetyChecker,  # 安全检查器,确保生成内容符合安全标准
            feature_extractor: CLIPImageProcessor,  # 特征提取器,用于处理图像
            image_encoder: CLIPVisionModelWithProjection = None,  # 可选图像编码器,用于图像嵌入
            requires_safety_checker: bool = True,  # 是否需要安全检查器,默认为True
        def _encode_prompt(
            self,
            prompt,  # 输入的提示文本
            device,  # 设备信息,指定运行的硬件
            num_images_per_prompt,  # 每个提示生成的图像数量
            do_classifier_free_guidance,  # 是否使用无分类器的引导
            negative_prompt=None,  # 可选的负面提示文本
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入
            lora_scale: Optional[float] = None,  # 可选的Lora缩放因子
            **kwargs,  # 其他可选参数
        ):
            # 生成弃用警告信息,提示用户该方法将被移除
            deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
            # 发出弃用警告,通知版本号和信息
            deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
    
            # 调用新的编码方法,获取提示嵌入的元组
            prompt_embeds_tuple = self.encode_prompt(
                prompt=prompt,  # 输入提示
                device=device,  # 设备信息
                num_images_per_prompt=num_images_per_prompt,  # 图像数量
                do_classifier_free_guidance=do_classifier_free_guidance,  # 引导选项
                negative_prompt=negative_prompt,  # 负面提示
                prompt_embeds=prompt_embeds,  # 提示嵌入
                negative_prompt_embeds=negative_prompt_embeds,  # 负面提示嵌入
                lora_scale=lora_scale,  # Lora缩放因子
                **kwargs,  # 其他参数
            )
    
            # 将元组中的提示嵌入连接为一个张量,便于后续处理
            prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
    
            # 返回连接后的提示嵌入
            return prompt_embeds
    
        # 新的编码提示方法,接受多个参数以生成提示嵌入
        def encode_prompt(
            self,
            prompt,  # 输入的提示文本
            device,  # 设备信息,指定运行的硬件
            num_images_per_prompt,  # 每个提示生成的图像数量
            do_classifier_free_guidance,  # 是否使用无分类器的引导
            negative_prompt=None,  # 可选的负面提示文本
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入
            lora_scale: Optional[float] = None,  # 可选的Lora缩放因子
            clip_skip: Optional[int] = None,  # 可选的跳过参数,用于调节处理流程
    # 定义一个方法用于编码图像,接受图像、设备、每个提示的图像数量及可选的隐藏状态输出
        def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
            # 获取图像编码器参数的数据类型
            dtype = next(self.image_encoder.parameters()).dtype
    
            # 如果输入的图像不是张量,则通过特征提取器转换为张量
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(image, return_tensors="pt").pixel_values
    
            # 将图像移动到指定设备,并转换为正确的数据类型
            image = image.to(device=device, dtype=dtype)
            # 如果需要输出隐藏状态
            if output_hidden_states:
                # 通过图像编码器处理图像,获取倒数第二层的隐藏状态
                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
                # 将隐藏状态重复,以适应每个提示的图像数量
                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
                # 创建一个与输入图像大小相同的零张量,通过图像编码器获取未条件化的隐藏状态
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True
                ).hidden_states[-2]
                # 将未条件化的隐藏状态重复,以适应每个提示的图像数量
                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
                # 返回编码后的图像隐藏状态和未条件化图像的隐藏状态
                return image_enc_hidden_states, uncond_image_enc_hidden_states
            else:
                # 通过图像编码器处理图像,获取图像嵌入
                image_embeds = self.image_encoder(image).image_embeds
                # 将图像嵌入重复,以适应每个提示的图像数量
                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
                # 创建一个与图像嵌入相同形状的零张量作为未条件化的图像嵌入
                uncond_image_embeds = torch.zeros_like(image_embeds)
    
                # 返回图像嵌入和未条件化图像嵌入
                return image_embeds, uncond_image_embeds
    
        # 定义一个方法用于准备 IP 适配器的图像嵌入,接受 IP 适配器图像、图像嵌入、设备、每个提示的图像数量及是否进行分类自由引导
        def prepare_ip_adapter_image_embeds(
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        # 初始化图像嵌入列表
        image_embeds = []
        # 如果启用无分类器自由引导,则初始化负图像嵌入列表
        if do_classifier_free_guidance:
            negative_image_embeds = []
        # 如果输入适配器图像嵌入为 None
        if ip_adapter_image_embeds is None:
            # 如果输入适配器图像不是列表,则将其转换为列表
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # 检查输入适配器图像的长度是否与 IP 适配器数量相同
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    # 抛出错误,说明输入适配器图像长度不匹配
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            # 遍历输入适配器图像和对应的图像投影层
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # 判断是否输出隐藏状态
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                # 编码单个图像,返回图像嵌入和负图像嵌入
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # 将图像嵌入添加到列表中
                image_embeds.append(single_image_embeds[None, :])
                # 如果启用无分类器自由引导,则将负图像嵌入添加到列表中
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            # 遍历已存在的输入适配器图像嵌入
            for single_image_embeds in ip_adapter_image_embeds:
                # 如果启用无分类器自由引导,分离负图像嵌入和图像嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                # 将图像嵌入添加到列表中
                image_embeds.append(single_image_embeds)

        # 初始化适配器图像嵌入列表
        ip_adapter_image_embeds = []
        # 遍历图像嵌入列表
        for i, single_image_embeds in enumerate(image_embeds):
            # 将单个图像嵌入复制指定次数以适应每个提示的图像数量
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            # 如果启用无分类器自由引导,复制负图像嵌入
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                # 将负图像嵌入和图像嵌入拼接
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            # 将图像嵌入移动到指定设备
            single_image_embeds = single_image_embeds.to(device=device)
            # 将处理后的图像嵌入添加到列表中
            ip_adapter_image_embeds.append(single_image_embeds)

        # 返回适配器图像嵌入列表
        return ip_adapter_image_embeds

    # 安全检查器的执行方法
    def run_safety_checker(self, image, device, dtype):
        # 如果没有安全检查器,初始化不适合的概念为 None
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # 如果输入图像是张量,则进行后处理
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                # 如果输入图像是 numpy 数组,则转换为 PIL 格式
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            # 将处理后的输入图像提取特征并移动到设备
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            # 运行安全检查器,获取处理后的图像和不适合的概念标志
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        # 返回处理后的图像和不适合的概念标志
        return image, has_nsfw_concept
    # 解码潜在表示
    def decode_latents(self, latents):
        # 定义弃用提示信息
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        # 调用弃用函数,提示用户
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
    
        # 根据缩放因子调整潜在表示
        latents = 1 / self.vae.config.scaling_factor * latents
        # 解码潜在表示,返回图像
        image = self.vae.decode(latents, return_dict=False)[0]
        # 将图像值缩放到[0, 1]范围内
        image = (image / 2 + 0.5).clamp(0, 1)
        # 转换图像为 float32 类型以确保兼容性
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回最终图像
        return image
    
    # 准备额外的步骤参数
    def prepare_extra_step_kwargs(self, generator, eta):
        # 为调度器步骤准备额外的关键字参数,因不同调度器的参数签名不同
        # eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
        # eta 对应 DDIM 论文中的 η,应在 [0, 1] 之间
    
        # 检查调度器是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            # 如果接受,则将 eta 添加到额外参数中
            extra_step_kwargs["eta"] = eta
    
        # 检查调度器是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            # 如果接受,则将 generator 添加到额外参数中
            extra_step_kwargs["generator"] = generator
        # 返回准备好的额外参数
        return extra_step_kwargs
    
    # 检查输入参数的有效性
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        # 方法体未提供,此处无进一步操作
    
    # 准备潜在表示
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 根据批大小和图像尺寸定义潜在表示的形状
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 如果生成器列表长度与批大小不匹配,抛出错误
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
    
        # 如果未提供潜在表示,则随机生成
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 将已提供的潜在表示转移到指定设备
            latents = latents.to(device)
    
        # 按调度器要求的标准差缩放初始噪声
        latents = latents * self.scheduler.init_noise_sigma
        # 返回准备好的潜在表示
        return latents
    
    # 从 latent_consistency_models 获取指导尺度嵌入的方法复制
    # 定义生成指导缩放嵌入的函数,接受张量 w 和其他参数
    def get_guidance_scale_embedding(
        # 输入参数 w,为一维的张量
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        参考链接,提供生成嵌入向量的信息
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                用指定的指导缩放生成嵌入向量,以此丰富时间步嵌入。
            embedding_dim (`int`, *optional*, defaults to 512):
                要生成的嵌入的维度。
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                生成的嵌入的数据类型。

        Returns:
            `torch.Tensor`: 嵌入向量,形状为 `(len(w), embedding_dim)`。
        """
        # 确保输入张量 w 是一维的
        assert len(w.shape) == 1
        # 将 w 的值乘以 1000.0
        w = w * 1000.0

        # 计算嵌入的半维度
        half_dim = embedding_dim // 2
        # 计算常量 emb,用于后续的指数计算
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        # 生成半维度的指数衰减嵌入
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        # 将 w 转换为指定数据类型,并与 emb 进行广播相乘
        emb = w.to(dtype)[:, None] * emb[None, :]
        # 将正弦和余弦嵌入拼接在一起
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        # 如果嵌入维度是奇数,则在最后添加零填充
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        # 确保最终嵌入的形状符合预期
        assert emb.shape == (w.shape[0], embedding_dim)
        # 返回生成的嵌入
        return emb

    # 定义属性,返回指导缩放值
    @property
    def guidance_scale(self):
        return self._guidance_scale

    # 定义属性,返回指导重缩放值
    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    # 定义属性,返回跨注意力的值
    @property
    def clip_skip(self):
        return self._clip_skip

    # 定义属性,判断是否进行无分类器引导
    # 这里 `guidance_scale` 是类似于 Imagen 论文中方程 (2) 的指导权重 `w`
    # 当 `guidance_scale = 1` 时,相当于不进行分类器无引导
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    # 定义属性,返回跨注意力的关键字参数
    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    # 定义属性,返回时间步数
    @property
    def num_timesteps(self):
        return self._num_timesteps

    # 定义属性,返回中断标志
    @property
    def interrupt(self):
        return self._interrupt

    # 指定在此上下文中不计算梯度
    @torch.no_grad()
    # 替换示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义一个可调用的类方法
        def __call__(
            # 输入提示,可以是字符串或字符串列表,默认为 None
            self,
            prompt: Union[str, List[str]] = None,
            # 输出图像的高度,默认为 None
            height: Optional[int] = None,
            # 输出图像的宽度,默认为 None
            width: Optional[int] = None,
            # 进行推理的步骤数,默认为 50
            num_inference_steps: int = 50,
            # 指定的时间步列表,默认为 None
            timesteps: List[int] = None,
            # 指定的 sigma 值列表,默认为 None
            sigmas: List[float] = None,
            # 指导尺度,默认为 7.5
            guidance_scale: float = 7.5,
            # 负面提示,可以是字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: Optional[int] = 1,
            # 控制生成过程中的随机性,默认为 0.0
            eta: float = 0.0,
            # 随机数生成器,可以是单个或多个 PyTorch 生成器,默认为 None
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量张量,默认为 None
            latents: Optional[torch.Tensor] = None,
            # 提示嵌入的张量,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负面提示嵌入的张量,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 输入适配器图像,默认为 None
            ip_adapter_image: Optional[PipelineImageInput] = None,
            # 输入适配器图像的嵌入列表,默认为 None
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
            # 输出类型,默认为 "pil"(Python Imaging Library)
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为 True
            return_dict: bool = True,
            # 交叉注意力的额外参数,默认为 None
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 指导重缩放因子,默认为 0.0
            guidance_rescale: float = 0.0,
            # 跳过的剪辑层,默认为 None
            clip_skip: Optional[int] = None,
            # 步骤结束时的回调函数,可以是多种类型,默认为 None
            callback_on_step_end: Optional[
                Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
            ] = None,
            # 步骤结束时的张量输入列表,默认为 ["latents"]
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 额外的关键字参数
            **kwargs,
posted @ 2024-10-22 12:34  绝不原创的飞龙  阅读(124)  评论(0)    收藏  举报