diffusers 源码解析(二十五)
.\diffusers\pipelines\deepfloyd_if\pipeline_if_img2img.py
# 导入处理 HTML 的库
import html
# 导入用于获取对象的源代码的库
import inspect
# 导入正则表达式库
import re
# 导入用于解析 URL 的库,并简化其引用
import urllib.parse as ul
# 导入类型提示所需的类型
from typing import Any, Callable, Dict, List, Optional, Union
# 导入 NumPy 库,常用于科学计算
import numpy as np
# 导入 PIL 库,用于处理图像
import PIL.Image
# 导入 PyTorch 库,深度学习框架
import torch
# 从 Transformers 库中导入图像处理器和模型
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
# 从 loaders 模块导入 StableDiffusionLoraLoaderMixin 类
from ...loaders import StableDiffusionLoraLoaderMixin
# 从 models 模块导入 UNet2DConditionModel 类
from ...models import UNet2DConditionModel
# 从 schedulers 模块导入 DDPMScheduler 类
from ...schedulers import DDPMScheduler
# 从 utils 模块导入多个工具函数和常量
from ...utils import (
    BACKENDS_MAPPING,  # 后端映射
    PIL_INTERPOLATION,  # PIL 图像插值方法
    is_bs4_available,  # 检查 BeautifulSoup 库是否可用
    is_ftfy_available,  # 检查 ftfy 库是否可用
    logging,  # 日志记录工具
    replace_example_docstring,  # 替换示例文档字符串的工具
)
# 从 torch_utils 模块导入 randn_tensor 函数
from ...utils.torch_utils import randn_tensor
# 从 pipeline_utils 模块导入 DiffusionPipeline 类
from ..pipeline_utils import DiffusionPipeline
# 从 pipeline_output 模块导入 IFPipelineOutput 类
from .pipeline_output import IFPipelineOutput
# 从 safety_checker 模块导入 IFSafetyChecker 类
from .safety_checker import IFSafetyChecker
# 从 watermark 模块导入 IFWatermarker 类
from .watermark import IFWatermarker
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 如果 BeautifulSoup 库可用,则导入其模块
if is_bs4_available():
    from bs4 import BeautifulSoup
# 如果 ftfy 库可用,则导入其模块
if is_ftfy_available():
    import ftfy
# 定义调整图像大小的函数,接受图像和目标大小作为参数
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    # 获取图像的宽度和高度
    w, h = images.size
    # 计算宽高比
    coef = w / h
    # 将宽度和高度初始化为目标大小
    w, h = img_size, img_size
    # 根据宽高比调整宽度或高度,使其为8的倍数
    if coef >= 1:
        w = int(round(img_size / 8 * coef) * 8)
    else:
        h = int(round(img_size / 8 / coef) * 8)
    # 使用指定的插值方法调整图像大小
    images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
    # 返回调整大小后的图像
    return images
# 示例文档字符串,用于说明功能
EXAMPLE_DOC_STRING = """
    # 示例代码
    Examples:
        ```py
        # 导入需要的库和模块
        >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from io import BytesIO
        # 定义图像的URL
        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        # 发送请求获取图像数据
        >>> response = requests.get(url)
        # 打开图像数据并转换为RGB格式
        >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
        # 调整图像大小为768x512
        >>> original_image = original_image.resize((768, 512))
        # 从预训练模型加载图像到图像管道
        >>> pipe = IFImg2ImgPipeline.from_pretrained(
        ...     "DeepFloyd/IF-I-XL-v1.0",
        ...     variant="fp16",
        ...     torch_dtype=torch.float16,
        ... )
        # 启用模型CPU卸载以节省内存
        >>> pipe.enable_model_cpu_offload()
        # 定义生成图像的提示语
        >>> prompt = "A fantasy landscape in style minecraft"
        # 编码提示语以获取正负嵌入
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
        # 生成图像
        >>> image = pipe(
        ...     image=original_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ...     output_type="pt",
        ... ).images
        # 保存中间生成的图像
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")
        # 从预训练模型加载超分辨率管道
        >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0",
        ...     text_encoder=None,
        ...     variant="fp16",
        ...     torch_dtype=torch.float16,
        ... )
        # 启用模型CPU卸载以节省内存
        >>> super_res_1_pipe.enable_model_cpu_offload()
        # 进行超分辨率处理
        >>> image = super_res_1_pipe(
        ...     image=image,
        ...     original_image=original_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ... ).images
        # 保存最终生成的超分辨率图像
        >>> image[0].save("./if_stage_II.png")
        ```py
# 定义一个图像到图像的扩散管道类,继承自DiffusionPipeline和StableDiffusionLoraLoaderMixin
class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    # 定义用于文本处理的分词器
    tokenizer: T5Tokenizer
    # 定义用于编码文本的模型
    text_encoder: T5EncoderModel
    # 定义条件生成的UNet模型
    unet: UNet2DConditionModel
    # 定义扩散调度器
    scheduler: DDPMScheduler
    # 可选的特征提取器
    feature_extractor: Optional[CLIPImageProcessor]
    # 可选的安全检查器
    safety_checker: Optional[IFSafetyChecker]
    # 可选的水印器
    watermarker: Optional[IFWatermarker]
    # 定义不良标点的正则表达式
    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"
        + r"\)"
        + r"\("
        + r"\]"
        + r"\["
        + r"\}"
        + r"\{"
        + r"\|"
        + "\\"
        + r"\/"
        + r"\*"
        + r"]{1,}"
    )  # noqa
    # 定义可选组件列表
    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
    # 定义模型CPU卸载的顺序
    model_cpu_offload_seq = "text_encoder->unet"
    # 定义不参与CPU卸载的组件
    _exclude_from_cpu_offload = ["watermarker"]
    # 初始化方法,接收多个参数以配置模型
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        # 调用父类构造函数
        super().__init__()
        # 检查是否禁用安全检查器并发出警告
        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )
        # 检查是否缺少特征提取器并引发错误
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )
        # 注册模块,包括分词器、文本编码器、UNet、调度器等
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        # 将需要的安全检查器配置注册到配置中
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # 设定不需要梯度计算的上下文
    @torch.no_grad()
    # 定义一个用于编码提示的函数,接收多种输入参数
        def encode_prompt(
            self,
            # 提示内容,可以是字符串或字符串列表
            prompt: Union[str, List[str]],
            # 是否使用无分类器自由引导的标志,默认为 True
            do_classifier_free_guidance: bool = True,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: int = 1,
            # 计算设备,默认为 None
            device: Optional[torch.device] = None,
            # 负面提示内容,可以是字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 提示的张量表示,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负面提示的张量表示,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 是否清理标题,默认为 False
            clean_caption: bool = False,
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker 复制而来
        def run_safety_checker(self, image, device, dtype):
            # 如果存在安全检查器,则执行安全检查
            if self.safety_checker is not None:
                # 将图像转换为 PIL 格式,并提取特征
                safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
                # 使用安全检查器检测图像中的不安全内容和水印
                image, nsfw_detected, watermark_detected = self.safety_checker(
                    images=image,
                    clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
                )
            else:
                # 如果没有安全检查器,则不安全检测返回 None
                nsfw_detected = None
                watermark_detected = None
    
            # 返回经过检查的图像及检测结果
            return image, nsfw_detected, watermark_detected
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs 复制而来
        def prepare_extra_step_kwargs(self, generator, eta):
            # 准备调度器步骤的额外参数,因为不是所有调度器的签名相同
            # eta (η) 仅用于 DDIMScheduler,其他调度器将被忽略
            # eta 对应于 DDIM 论文中的 η,范围在 [0, 1] 之间
    
            # 检查调度器步骤是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 初始化额外参数字典
            extra_step_kwargs = {}
            if accepts_eta:
                # 如果接受 eta,则将其添加到额外参数中
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器步骤是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            if accepts_generator:
                # 如果接受 generator,则将其添加到额外参数中
                extra_step_kwargs["generator"] = generator
            # 返回准备好的额外参数
            return extra_step_kwargs
    
        # 定义一个检查输入参数的函数
        def check_inputs(
            self,
            # 提示内容
            prompt,
            # 输入图像
            image,
            # 批处理大小
            batch_size,
            # 回调步骤
            callback_steps,
            # 负面提示,默认为 None
            negative_prompt=None,
            # 提示的张量表示,默认为 None
            prompt_embeds=None,
            # 负面提示的张量表示,默认为 None
            negative_prompt_embeds=None,
    # 该部分代码用于验证输入参数的有效性
        ):
            # 检查回调步数是否有效
            if (callback_steps is None) or (
                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
            ):
                # 抛出异常,提示回调步数必须为正整数
                raise ValueError(
                    f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                    f" {type(callback_steps)}."
                )
    
            # 检查是否同时提供了 prompt 和 prompt_embeds
            if prompt is not None and prompt_embeds is not None:
                # 抛出异常,提示不能同时提供这两个参数
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查是否两个参数都未提供
            elif prompt is None and prompt_embeds is None:
                # 抛出异常,提示至少需要提供一个参数
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查 prompt 的类型
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 抛出异常,提示 prompt 必须是字符串或列表类型
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查负面提示和其嵌入是否同时提供
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 抛出异常,提示不能同时提供这两个参数
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查嵌入的形状是否匹配
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 抛出异常,提示两个嵌入的形状必须一致
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
            # 检查图像类型
            if isinstance(image, list):
                check_image_type = image[0]
            else:
                check_image_type = image
    
            # 验证图像类型是否有效
            if (
                not isinstance(check_image_type, torch.Tensor)
                and not isinstance(check_image_type, PIL.Image.Image)
                and not isinstance(check_image_type, np.ndarray)
            ):
                # 抛出异常,提示图像类型无效
                raise ValueError(
                    "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                    f" {type(check_image_type)}"
                )
    
            # 根据图像类型确定批量大小
            if isinstance(image, list):
                image_batch_size = len(image)
            elif isinstance(image, torch.Tensor):
                image_batch_size = image.shape[0]
            elif isinstance(image, PIL.Image.Image):
                image_batch_size = 1
            elif isinstance(image, np.ndarray):
                image_batch_size = image.shape[0]
            else:
                # 断言无效的图像类型
                assert False
    
            # 检查批量大小是否一致
            if batch_size != image_batch_size:
                # 抛出异常,提示图像批量大小与提示批量大小不一致
                raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制的代码
    def _text_preprocessing(self, text, clean_caption=False):
        # 检查是否启用清理字幕,且 bs4 模块不可用
        if clean_caption and not is_bs4_available():
            # 记录警告,提示用户缺少 bs4 模块
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            # 记录警告,自动将 clean_caption 设置为 False
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False  # 更新 clean_caption 状态
        # 检查是否启用清理字幕,且 ftfy 模块不可用
        if clean_caption and not is_ftfy_available():
            # 记录警告,提示用户缺少 ftfy 模块
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            # 记录警告,自动将 clean_caption 设置为 False
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False  # 更新 clean_caption 状态
        # 如果 text 不是元组或列表,转换为列表
        if not isinstance(text, (tuple, list)):
            text = [text]  # 将单一文本包裹成列表
        # 定义处理文本的内部函数
        def process(text: str):
            # 如果启用清理字幕,进行两次清理操作
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                # 否则,将文本转换为小写并去除空白
                text = text.lower().strip()
            return text  # 返回处理后的文本
        # 对列表中的每个文本进行处理,并返回结果
        return [process(t) for t in text]  # 返回处理后的文本列表
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption 复制的代码
    def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
        # 如果输入不是列表,转换为列表
        if not isinstance(image, list):
            image = [image]  # 将单一图像包裹成列表
        # 定义将 NumPy 数组转换为 PyTorch 张量的内部函数
        def numpy_to_pt(images):
            # 如果图像维度为 3,增加最后一个维度
            if images.ndim == 3:
                images = images[..., None]
            # 转换为 PyTorch 张量并调整维度顺序
            images = torch.from_numpy(images.transpose(0, 3, 1, 2))
            return images  # 返回转换后的张量
        # 如果图像是 PIL 图像实例
        if isinstance(image[0], PIL.Image.Image):
            new_image = []  # 创建新的图像列表
            # 遍历每个图像进行处理
            for image_ in image:
                image_ = image_.convert("RGB")  # 转换为 RGB 格式
                image_ = resize(image_, self.unet.config.sample_size)  # 调整图像大小
                image_ = np.array(image_)  # 转换为 NumPy 数组
                image_ = image_.astype(np.float32)  # 转换数据类型为 float32
                image_ = image_ / 127.5 - 1  # 归一化到 [-1, 1] 范围
                new_image.append(image_)  # 添加处理后的图像到列表
            image = new_image  # 更新为处理后的图像列表
            # 将图像列表堆叠为 NumPy 数组
            image = np.stack(image, axis=0)  # to np
            # 转换为 PyTorch 张量
            image = numpy_to_pt(image)  # to pt
        # 如果输入图像是 NumPy 数组
        elif isinstance(image[0], np.ndarray):
            # 根据维度将图像合并或堆叠
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = numpy_to_pt(image)  # 转换为 PyTorch 张量
        # 如果输入图像是 PyTorch 张量
        elif isinstance(image[0], torch.Tensor):
            # 根据维度将图像合并或堆叠
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
        return image  # 返回处理后的图像
    # 获取时间步长的函数,基于推理步骤和强度参数
        def get_timesteps(self, num_inference_steps, strength):
            # 根据初始时间步和强度计算最小的时间步
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 计算开始的时间步,确保不小于0
            t_start = max(num_inference_steps - init_timestep, 0)
            # 从调度器中获取时间步,从t_start开始,按照调度器的顺序
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器有设置开始索引的方法,则调用它
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步和有效推理步骤的数量
            return timesteps, num_inference_steps - t_start
    
        # 准备中间图像的函数
        def prepare_intermediate_images(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
        ):
            # 获取输入图像的维度信息
            _, channels, height, width = image.shape
    
            # 计算有效的批量大小
            batch_size = batch_size * num_images_per_prompt
    
            # 设置图像的目标形状
            shape = (batch_size, channels, height, width)
    
            # 检查生成器的长度是否与批量大小匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 生成随机噪声张量
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    
            # 按照每个提示重复输入图像
            image = image.repeat_interleave(num_images_per_prompt, dim=0)
            # 向图像中添加噪声
            image = self.scheduler.add_noise(image, noise, timestep)
    
            # 返回处理后的图像
            return image
    
        # 调用函数,进行图像生成
        @torch.no_grad()
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            # 输入图像可以是多种格式
            image: Union[
                PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
            ] = None,
            # 设置强度参数
            strength: float = 0.7,
            # 设置推理步骤数量
            num_inference_steps: int = 80,
            # 时间步列表,默认值为None
            timesteps: List[int] = None,
            # 设置引导比例
            guidance_scale: float = 10.0,
            # 可选的负面提示
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量
            num_images_per_prompt: Optional[int] = 1,
            # 设置eta参数
            eta: float = 0.0,
            # 生成器参数,可选
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 提示的嵌入表示,可选
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负面提示的嵌入表示,可选
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 输出类型,默认为'pil'
            output_type: Optional[str] = "pil",
            # 是否返回字典格式的结果
            return_dict: bool = True,
            # 可选的回调函数
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回调的步长设置
            callback_steps: int = 1,
            # 是否清理提示
            clean_caption: bool = True,
            # 交叉注意力的额外参数,可选
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
.\diffusers\pipelines\deepfloyd_if\pipeline_if_img2img_superresolution.py
# 导入标准库中的 html 模块,用于处理 HTML 字符串
import html
# 导入 inspect 模块,用于获取对象的信息
import inspect
# 导入正则表达式模块,用于字符串模式匹配
import re
# 导入 urllib.parse 模块并命名为 ul,用于处理 URL
import urllib.parse as ul
# 从 typing 模块导入各种类型提示
from typing import Any, Callable, Dict, List, Optional, Union
# 导入 numpy 库,常用于数值计算
import numpy as np
# 导入 PIL.Image 模块,用于图像处理
import PIL.Image
# 导入 torch 库,深度学习框架
import torch
# 从 torch.nn.functional 导入 F,提供常用的神经网络功能
import torch.nn.functional as F
# 从 transformers 导入 CLIPImageProcessor, T5EncoderModel, T5Tokenizer,用于自然语言处理和图像处理
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
# 从 loaders 模块导入 StableDiffusionLoraLoaderMixin 类
from ...loaders import StableDiffusionLoraLoaderMixin
# 从 models 模块导入 UNet2DConditionModel 类
from ...models import UNet2DConditionModel
# 从 schedulers 模块导入 DDPMScheduler 类
from ...schedulers import DDPMScheduler
# 从 utils 模块导入多个工具函数和常量
from ...utils import (
    BACKENDS_MAPPING,  # 后端映射
    PIL_INTERPOLATION,  # PIL 插值方式
    is_bs4_available,  # 检查 BeautifulSoup 是否可用
    is_ftfy_available,  # 检查 ftfy 是否可用
    logging,  # 日志记录工具
    replace_example_docstring,  # 替换示例文档字符串的工具
)
# 从 torch_utils 模块导入 randn_tensor 函数
from ...utils.torch_utils import randn_tensor
# 从 pipeline_utils 模块导入 DiffusionPipeline 类
from ..pipeline_utils import DiffusionPipeline
# 从 pipeline_output 模块导入 IFPipelineOutput 类
from .pipeline_output import IFPipelineOutput
# 从 safety_checker 模块导入 IFSafetyChecker 类
from .safety_checker import IFSafetyChecker
# 从 watermark 模块导入 IFWatermarker 类
from .watermark import IFWatermarker
# 如果 bs4 可用,导入 BeautifulSoup 类用于解析 HTML 文档
if is_bs4_available():
    from bs4 import BeautifulSoup
# 如果 ftfy 可用,导入 ftfy 模块用于处理文本
if is_ftfy_available():
    import ftfy
# 创建一个 logger 实例,用于日志记录,禁用无效名称警告
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 定义一个函数,进行图像的大小调整
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    # 获取图像的宽和高
    w, h = images.size
    # 计算图像的宽高比
    coef = w / h
    # 将宽和高都设置为目标尺寸
    w, h = img_size, img_size
    # 根据宽高比调整宽或高,使其为 8 的倍数
    if coef >= 1:
        w = int(round(img_size / 8 * coef) * 8)
    else:
        h = int(round(img_size / 8 / coef) * 8)
    # 调整图像大小,使用双立方插值法
    images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
    # 返回调整后的图像
    return images
# 示例文档字符串,可能用于函数或类的文档说明
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        # 导入所需的库和模块
        >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from io import BytesIO
        # 定义图片的 URL
        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        # 发送 GET 请求以获取图像
        >>> response = requests.get(url)
        # 将响应内容转换为 PIL 图像并转换为 RGB 模式
        >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
        # 调整图像大小为 768x512 像素
        >>> original_image = original_image.resize((768, 512))
        # 从预训练模型加载图像到图像转换管道
        >>> pipe = IFImg2ImgPipeline.from_pretrained(
        ...     "DeepFloyd/IF-I-XL-v1.0",
        ...     variant="fp16",  # 使用半精度浮点格式
        ...     torch_dtype=torch.float16,  # 设置 PyTorch 的数据类型
        ... )
        # 启用模型的 CPU 离线处理以节省内存
        >>> pipe.enable_model_cpu_offload()
        # 定义生成图像的提示
        >>> prompt = "A fantasy landscape in style minecraft"
        # 对提示进行编码,获取正向和负向的嵌入向量
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
        # 使用管道生成图像
        >>> image = pipe(
        ...     image=original_image,  # 输入的原始图像
        ...     prompt_embeds=prompt_embeds,  # 正向提示嵌入
        ...     negative_prompt_embeds=negative_embeds,  # 负向提示嵌入
        ...     output_type="pt",  # 输出类型设置为 PyTorch 张量
        ... ).images  # 获取生成的图像列表
        # 将生成的中间图像保存为文件
        >>> pil_image = pt_to_pil(image)  # 将 PyTorch 张量转换为 PIL 图像
        >>> pil_image[0].save("./if_stage_I.png")  # 保存第一张图像
        # 从预训练模型加载超分辨率管道
        >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0",
        ...     text_encoder=None,  # 不使用文本编码器
        ...     variant="fp16",  # 使用半精度浮点格式
        ...     torch_dtype=torch.float16,  # 设置 PyTorch 的数据类型
        ... )
        # 启用模型的 CPU 离线处理以节省内存
        >>> super_res_1_pipe.enable_model_cpu_offload()
        # 使用超分辨率管道对生成的图像进行处理
        >>> image = super_res_1_pipe(
        ...     image=image,  # 输入的图像
        ...     original_image=original_image,  # 原始图像
        ...     prompt_embeds=prompt_embeds,  # 正向提示嵌入
        ...     negative_prompt_embeds=negative_embeds,  # 负向提示嵌入
        ... ).images  # 获取处理后的图像列表
        # 保存处理后的第一张图像
        >>> image[0].save("./if_stage_II.png")  # 保存为文件
        ```py
"""
# 文档字符串,通常用于描述类的功能和用途
class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    # 声明类的属性,包括 tokenizer 和 text_encoder 的类型
    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel
    # 声明 UNet 和调度器的类型
    unet: UNet2DConditionModel
    scheduler: DDPMScheduler
    image_noising_scheduler: DDPMScheduler
    # 可选的特征提取器和安全检查器
    feature_extractor: Optional[CLIPImageProcessor]
    safety_checker: Optional[IFSafetyChecker]
    # 可选的水印处理器
    watermarker: Optional[IFWatermarker]
    # 定义一个正则表达式,用于匹配不良标点符号
    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"
        + r"\)"
        + r"\("
        + r"\]"
        + r"\["
        + r"\}"
        + r"\{"
        + r"\|"
        + "\\"
        + r"\/"
        + r"\*"
        + r"]{1,}"
    )  # noqa
    # 可选组件列表,包含 tokenizer、text_encoder、safety_checker 和 feature_extractor
    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"]
    # 定义模型的 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->unet"
    # 排除水印器不参与 CPU 卸载
    _exclude_from_cpu_offload = ["watermarker"]
    # 构造函数,初始化类的各个属性
    def __init__(
        self,
        tokenizer: T5Tokenizer,  # tokenizer 用于文本处理
        text_encoder: T5EncoderModel,  # text_encoder 负责编码文本
        unet: UNet2DConditionModel,  # unet 处理图像生成
        scheduler: DDPMScheduler,  # scheduler 控制生成过程的时间步
        image_noising_scheduler: DDPMScheduler,  # image_noising_scheduler 处理图像噪声
        safety_checker: Optional[IFSafetyChecker],  # 可选的安全检查器
        feature_extractor: Optional[CLIPImageProcessor],  # 可选的特征提取器
        watermarker: Optional[IFWatermarker],  # 可选的水印处理器
        requires_safety_checker: bool = True,  # 是否需要安全检查器的布尔标志
    ):
        # 调用父类的构造函数进行初始化
        super().__init__()
        # 检查安全检查器是否为 None 且要求使用安全检查器
        if safety_checker is None and requires_safety_checker:
            # 记录警告信息,提示用户禁用了安全检查器
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )
        # 检查安全检查器是否不为 None 且特征提取器为 None
        if safety_checker is not None and feature_extractor is None:
            # 抛出错误,提示用户需要定义特征提取器
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )
        # 检查 UNet 配置的输入通道数是否不等于 6
        if unet.config.in_channels != 6:
            # 记录警告信息,提示用户加载的检查点不适用于超分辨率
            logger.warning(
                "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
            )
        # 注册多个模块,包括 tokenizer、text_encoder 等
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            image_noising_scheduler=image_noising_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        # 将需要的配置注册到对象中
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制的代码
    # 文本预处理方法,接受文本和清理标题的标志
    def _text_preprocessing(self, text, clean_caption=False):
        # 如果需要清理标题且未安装 BeautifulSoup4,则发出警告
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            # 发出警告,提示将清理标题设置为 False
            logger.warning("Setting `clean_caption` to False...")
            # 设置清理标题为 False
            clean_caption = False
        # 如果需要清理标题且未安装 ftfy,则发出警告
        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            # 发出警告,提示将清理标题设置为 False
            logger.warning("Setting `clean_caption` to False...")
            # 设置清理标题为 False
            clean_caption = False
        # 如果输入文本不是元组或列表,则将其转换为列表
        if not isinstance(text, (tuple, list)):
            text = [text]
        # 定义内部处理函数,接受一个字符串文本
        def process(text: str):
            # 如果需要清理标题,调用清理标题的方法
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)  # 再次调用以确保清理完全
            else:
                # 否则,将文本转换为小写并去除前后空格
                text = text.lower().strip()
            # 返回处理后的文本
            return text
        # 对输入文本列表中的每个文本应用处理函数,并返回结果列表
        return [process(t) for t in text]
    # 禁用梯度计算,提升性能
    @torch.no_grad()
    # 定义编码提示的方法,接受多个参数
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],  # 输入的提示,可以是字符串或字符串列表
        do_classifier_free_guidance: bool = True,  # 是否进行无分类器引导
        num_images_per_prompt: int = 1,  # 每个提示生成的图像数量
        device: Optional[torch.device] = None,  # 指定设备
        negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负提示
        prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负提示嵌入
        clean_caption: bool = False,  # 是否清理标题
    # 定义运行安全检查的方法,接受图像、设备和数据类型参数
    def run_safety_checker(self, image, device, dtype):
        # 如果存在安全检查器,则进行安全检查
        if self.safety_checker is not None:
            # 使用特征提取器处理图像,转换为适合模型输入的格式
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            # 进行安全检查,返回检查后的图像及检测结果
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            # 如果没有安全检查器,则设置检测结果为 None
            nsfw_detected = None
            watermark_detected = None
        # 返回检查后的图像及检测结果
        return image, nsfw_detected, watermark_detected
    # 准备额外步骤参数的方法,供其他方法使用
    # 准备额外的参数用于调度器步骤,不同调度器可能有不同的参数签名
        def prepare_extra_step_kwargs(self, generator, eta):
            # eta (η) 仅用于 DDIMScheduler,其他调度器将忽略该参数
            # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
            # 应在 [0, 1] 之间
    
            # 检查调度器是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 创建存放额外参数的字典
            extra_step_kwargs = {}
            # 如果调度器接受 eta,添加 eta 到额外参数字典
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果调度器接受 generator,添加 generator 到额外参数字典
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回额外参数字典
            return extra_step_kwargs
    
        # 检查输入参数的有效性
        def check_inputs(
            self,
            prompt,
            image,
            original_image,
            batch_size,
            callback_steps,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image 复制而来,将 preprocess_image 替换为 preprocess_original_image
        def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
            # 如果输入不是列表,则将其转换为列表
            if not isinstance(image, list):
                image = [image]
    
            # 定义将 numpy 数组转换为 PyTorch 张量的函数
            def numpy_to_pt(images):
                # 如果图像是 3 维,则添加一个维度
                if images.ndim == 3:
                    images = images[..., None]
                # 转换图像格式并创建 PyTorch 张量
                images = torch.from_numpy(images.transpose(0, 3, 1, 2))
                return images
    
            # 如果输入图像是 PIL 图像类型
            if isinstance(image[0], PIL.Image.Image):
                new_image = []
                # 遍历每个图像进行处理
                for image_ in image:
                    # 转换图像为 RGB 格式
                    image_ = image_.convert("RGB")
                    # 调整图像大小
                    image_ = resize(image_, self.unet.config.sample_size)
                    # 转换为 numpy 数组
                    image_ = np.array(image_)
                    # 转换数据类型为 float32
                    image_ = image_.astype(np.float32)
                    # 标准化图像数据
                    image_ = image_ / 127.5 - 1
                    # 将处理后的图像添加到新列表中
                    new_image.append(image_)
    
                # 将新图像列表堆叠为 numpy 数组
                image = np.stack(image, axis=0)  # 转换为 numpy 数组
                # 转换为 PyTorch 张量
                image = numpy_to_pt(image)  # 转换为张量
    
            # 如果输入是 numpy 数组类型
            elif isinstance(image[0], np.ndarray):
                # 根据维度合并多个 numpy 数组
                image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
                # 转换为 PyTorch 张量
                image = numpy_to_pt(image)
    
            # 如果输入是 PyTorch 张量类型
            elif isinstance(image[0], torch.Tensor):
                # 根据维度合并多个张量
                image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
    
            # 返回处理后的图像
            return image
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image 复制而来
    # 处理输入图像,将其预处理为适合模型的格式
        def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
            # 检查输入是否为张量或列表,如果不是,则将其转换为列表
            if not isinstance(image, torch.Tensor) and not isinstance(image, list):
                image = [image]
    
            # 如果列表中的第一个元素是 PIL 图像,转换为 NumPy 数组并归一化
            if isinstance(image[0], PIL.Image.Image):
                image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]
    
                # 将列表中的数组堆叠成一个 NumPy 数组
                image = np.stack(image, axis=0)  # to np
                # 将 NumPy 数组转换为 PyTorch 张量,并调整维度顺序
                image = torch.from_numpy(image.transpose(0, 3, 1, 2))
            # 如果列表中的第一个元素是 NumPy 数组,直接堆叠
            elif isinstance(image[0], np.ndarray):
                image = np.stack(image, axis=0)  # to np
                # 如果数组是五维,取第一维
                if image.ndim == 5:
                    image = image[0]
    
                # 将 NumPy 数组转换为 PyTorch 张量,并调整维度顺序
                image = torch.from_numpy(image.transpose(0, 3, 1, 2))
            # 如果列表中的第一个元素是 PyTorch 张量,检查维度
            elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
                dims = image[0].ndim
    
                # 三维张量堆叠
                if dims == 3:
                    image = torch.stack(image, dim=0)
                # 四维张量连接
                elif dims == 4:
                    image = torch.concat(image, dim=0)
                # 维度不匹配时引发错误
                else:
                    raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
    
            # 将图像移动到指定设备并设置数据类型
            image = image.to(device=device, dtype=self.unet.dtype)
    
            # 重复图像以匹配每个提示的图像数量
            image = image.repeat_interleave(num_images_per_prompt, dim=0)
    
            # 返回处理后的图像张量
            return image
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps 复制
        def get_timesteps(self, num_inference_steps, strength):
            # 计算初始时间步长,确保不超过总步长
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 计算开始时间步
            t_start = max(num_inference_steps - init_timestep, 0)
            # 获取调度器的时间步列表
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器具有设置开始索引的方法,则调用该方法
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步和有效的推理步长
            return timesteps, num_inference_steps - t_start
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images 复制
        def prepare_intermediate_images(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
        ):
            # 解构图像的维度信息
            _, channels, height, width = image.shape
    
            # 计算有效的批量大小
            batch_size = batch_size * num_images_per_prompt
    
            # 创建新的形状元组
            shape = (batch_size, channels, height, width)
    
            # 检查生成器列表的长度是否与批量大小匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 生成随机噪声张量
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    
            # 重复图像以匹配每个提示的图像数量
            image = image.repeat_interleave(num_images_per_prompt, dim=0)
            # 将噪声添加到图像中
            image = self.scheduler.add_noise(image, noise, timestep)
    
            # 返回添加噪声后的图像
            return image
    
        # 禁用梯度计算,以减少内存使用
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用类的 __call__ 方法,允许该类实例像函数一样被调用
        def __call__(
            self,
            # 输入的图像,可以是 PIL 图像、NumPy 数组或 PyTorch 张量
            image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
            # 原始图像,可以是多种格式,默认为 None
            original_image: Union[
                PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
            ] = None,
            # 图像处理的强度,默认为 0.8
            strength: float = 0.8,
            # 提示词,可以是单个字符串或字符串列表,默认为 None
            prompt: Union[str, List[str]] = None,
            # 推理步骤的数量,默认为 50
            num_inference_steps: int = 50,
            # 时间步的列表,默认为 None
            timesteps: List[int] = None,
            # 指导比例,默认为 4.0
            guidance_scale: float = 4.0,
            # 负提示词,可以是单个字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示词生成的图像数量,默认为 1
            num_images_per_prompt: Optional[int] = 1,
            # 控制随机性的参数,默认为 0.0
            eta: float = 0.0,
            # 生成器,用于控制随机性,可以是单个或多个 PyTorch 生成器,默认为 None
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 提示词的嵌入表示,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示词的嵌入表示,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 输出类型,默认为 "pil"(PIL 图像)
            output_type: Optional[str] = "pil",
            # 是否返回字典格式的输出,默认为 True
            return_dict: bool = True,
            # 回调函数,在处理过程中可以调用,默认为 None
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回调函数调用的步数,默认为 1
            callback_steps: int = 1,
            # 跨注意力的关键字参数,默认为 None
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 噪声水平,默认为 250
            noise_level: int = 250,
            # 是否清理提示词,默认为 True
            clean_caption: bool = True,
.\diffusers\pipelines\deepfloyd_if\pipeline_if_inpainting.py
# 导入 HTML 处理库
import html
# 导入检查对象信息的库
import inspect
# 导入正则表达式库
import re
# 导入 urllib 的解析模块并重命名为 ul
import urllib.parse as ul
# 从 typing 导入多种类型提示工具
from typing import Any, Callable, Dict, List, Optional, Union
# 导入 numpy 库并重命名为 np
import numpy as np
# 导入图像处理库 PIL
import PIL.Image
# 导入 PyTorch 库
import torch
# 从 transformers 导入 CLIP 图像处理器、T5 编码模型和 T5 分词器
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
# 从 loaders 导入 StableDiffusionLoraLoaderMixin 类
from ...loaders import StableDiffusionLoraLoaderMixin
# 从 models 导入 UNet2DConditionModel 类
from ...models import UNet2DConditionModel
# 从 schedulers 导入 DDPMScheduler 类
from ...schedulers import DDPMScheduler
# 从 utils 导入多个实用工具
from ...utils import (
    BACKENDS_MAPPING,          # 后端映射
    PIL_INTERPOLATION,        # PIL 插值方法
    is_bs4_available,         # 检查 bs4 可用性
    is_ftfy_available,        # 检查 ftfy 可用性
    logging,                  # 日志记录工具
    replace_example_docstring, # 替换示例文档字符串的工具
)
# 从 utils.torch_utils 导入 randn_tensor 函数
from ...utils.torch_utils import randn_tensor
# 从 pipeline_utils 导入 DiffusionPipeline 类
from ..pipeline_utils import DiffusionPipeline
# 从 pipeline_output 导入 IFPipelineOutput 类
from .pipeline_output import IFPipelineOutput
# 从 safety_checker 导入 IFSafetyChecker 类
from .safety_checker import IFSafetyChecker
# 从 watermark 导入 IFWatermarker 类
from .watermark import IFWatermarker
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 如果 bs4 可用,则导入 BeautifulSoup 类
if is_bs4_available():
    from bs4 import BeautifulSoup
# 如果 ftfy 可用,则导入 ftfy 库
if is_ftfy_available():
    import ftfy
# 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize 复制的 resize 函数
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    # 获取输入图像的宽度和高度
    w, h = images.size
    # 计算宽高比
    coef = w / h
    # 将宽度和高度都设置为目标大小
    w, h = img_size, img_size
    # 根据宽高比调整宽度或高度
    if coef >= 1:
        w = int(round(img_size / 8 * coef) * 8)  # 调整宽度为最接近的8的倍数
    else:
        h = int(round(img_size / 8 / coef) * 8)  # 调整高度为最接近的8的倍数
    # 调整图像大小,使用双三次插值法
    images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
    # 返回调整后的图像
    return images
# 示例文档字符串
EXAMPLE_DOC_STRING = """
    # 示例代码说明如何使用图像修复和超分辨率模型
        Examples:
            ```py
            # 导入所需的库和模块
            >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
            >>> from diffusers.utils import pt_to_pil
            >>> import torch
            >>> from PIL import Image
            >>> import requests
            >>> from io import BytesIO
    
            # 定义图像的URL
            >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
            # 发送GET请求获取图像
            >>> response = requests.get(url)
            # 打开图像并转换为RGB格式
            >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
            # 将原始图像赋值给变量
            >>> original_image = original_image
    
            # 定义掩膜图像的URL
            >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
            # 发送GET请求获取掩膜图像
            >>> response = requests.get(url)
            # 打开掩膜图像
            >>> mask_image = Image.open(BytesIO(response.content))
            # 将掩膜图像赋值给变量
            >>> mask_image = mask_image
    
            # 从预训练模型加载图像修复管道
            >>> pipe = IFInpaintingPipeline.from_pretrained(
            ...     "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
            ... )
            # 启用模型的CPU卸载功能以节省内存
            >>> pipe.enable_model_cpu_offload()
    
            # 定义图像修复的提示
            >>> prompt = "blue sunglasses"
            # 对提示进行编码,生成提示嵌入和负面嵌入
            >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
    
            # 使用管道进行图像修复
            >>> image = pipe(
            ...     image=original_image,
            ...     mask_image=mask_image,
            ...     prompt_embeds=prompt_embeds,
            ...     negative_prompt_embeds=negative_embeds,
            ...     output_type="pt",
            ... ).images
    
            # 保存中间图像为文件
            >>> # save intermediate image
            >>> pil_image = pt_to_pil(image)
            >>> pil_image[0].save("./if_stage_I.png")
    
            # 从预训练模型加载超分辨率管道
            >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
            ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
            ... )
            # 启用模型的CPU卸载功能
            >>> super_res_1_pipe.enable_model_cpu_offload()
    
            # 使用超分辨率管道处理图像
            >>> image = super_res_1_pipe(
            ...     image=image,
            ...     mask_image=mask_image,
            ...     original_image=original_image,
            ...     prompt_embeds=prompt_embeds,
            ...     negative_prompt_embeds=negative_embeds,
            ... ).images
            # 保存最终图像为文件
            >>> image[0].save("./if_stage_II.png")
            ```py
"""
# 文档字符串,通常用于描述类的用途和功能
class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    # 定义一个类,继承自 DiffusionPipeline 和 StableDiffusionLoraLoaderMixin
    tokenizer: T5Tokenizer
    # 声明一个属性 tokenizer,类型为 T5Tokenizer
    text_encoder: T5EncoderModel
    # 声明一个属性 text_encoder,类型为 T5EncoderModel
    unet: UNet2DConditionModel
    # 声明一个属性 unet,类型为 UNet2DConditionModel
    scheduler: DDPMScheduler
    # 声明一个属性 scheduler,类型为 DDPMScheduler
    feature_extractor: Optional[CLIPImageProcessor]
    # 声明一个可选属性 feature_extractor,类型为 CLIPImageProcessor
    safety_checker: Optional[IFSafetyChecker]
    # 声明一个可选属性 safety_checker,类型为 IFSafetyChecker
    watermarker: Optional[IFWatermarker]
    # 声明一个可选属性 watermarker,类型为 IFWatermarker
    # 使用正则表达式编译不良标点符号的模式
    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"
        + r"\)"
        + r"\("
        + r"\]"
        + r"\["
        + r"\}"
        + r"\{"
        + r"\|"
        + "\\"
        + r"\/"
        + r"\*"
        + r"]{1,}"
    )  # noqa
    # 定义一个可选组件列表,包含不同的组件名称
    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
    # 定义模型的 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->unet"
    # 定义需要从 CPU 卸载中排除的组件列表
    _exclude_from_cpu_offload = ["watermarker"]
    # 初始化方法,设置类的属性
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 如果安全检查器为 None 且需要安全检查器,发出警告
        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )
        # 如果安全检查器不为 None 且特征提取器为 None,抛出错误
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )
        # 注册模块,包括所有必要的组件
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        # 注册配置,指定是否需要安全检查器
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    @torch.no_grad()
    # 禁用梯度计算,通常用于推理阶段以节省内存
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
    # 定义一个方法用于编码输入提示
        def encode_prompt(
            self,
            # 输入的提示,可以是字符串或字符串列表
            prompt: Union[str, List[str]],
            # 是否进行无分类器自由引导,默认为 True
            do_classifier_free_guidance: bool = True,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: int = 1,
            # 指定设备,可选,默认为 None
            device: Optional[torch.device] = None,
            # 可选的负面提示,可以是字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 可选的提示嵌入,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 可选的负面提示嵌入,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 是否清理提示,默认为 False
            clean_caption: bool = False,
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker 复制
        def run_safety_checker(self, image, device, dtype):
            # 如果存在安全检查器
            if self.safety_checker is not None:
                # 将图像转换为 PIL 格式并提取特征,返回张量,移动到指定设备
                safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
                # 使用安全检查器检测图像,返回处理后的图像及检测结果
                image, nsfw_detected, watermark_detected = self.safety_checker(
                    images=image,
                    # 获取安全检查器输入的像素值并转换数据类型
                    clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
                )
            else:
                # 如果没有安全检查器,设置检测结果为 None
                nsfw_detected = None
                watermark_detected = None
    
            # 返回处理后的图像及检测结果
            return image, nsfw_detected, watermark_detected
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs 复制
        def prepare_extra_step_kwargs(self, generator, eta):
            # 准备调度器步骤的额外参数,因为并非所有调度器的签名相同
            # eta (η) 仅在 DDIMScheduler 中使用,其他调度器将忽略
            # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
            # 值应在 [0, 1] 之间
    
            # 检查调度器是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 初始化额外参数字典
            extra_step_kwargs = {}
            # 如果接受 eta,将其添加到额外参数字典
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果接受 generator,将其添加到额外参数字典
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回额外参数字典
            return extra_step_kwargs
    
        # 定义一个方法用于检查输入参数
        def check_inputs(
            self,
            # 输入的提示
            prompt,
            # 输入的图像
            image,
            # 输入的掩码图像
            mask_image,
            # 批处理大小
            batch_size,
            # 回调步骤
            callback_steps,
            # 可选的负面提示
            negative_prompt=None,
            # 可选的提示嵌入
            prompt_embeds=None,
            # 可选的负面提示嵌入
            negative_prompt_embeds=None,
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制
    # 定义文本预处理函数,接受文本和清理标题标志
        def _text_preprocessing(self, text, clean_caption=False):
            # 如果设置清理标题且未安装 bs4,则记录警告并将标志设置为 False
            if clean_caption and not is_bs4_available():
                logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
                logger.warning("Setting `clean_caption` to False...")
                clean_caption = False
    
            # 如果设置清理标题且未安装 ftfy,则记录警告并将标志设置为 False
            if clean_caption and not is_ftfy_available():
                logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
                logger.warning("Setting `clean_caption` to False...")
                clean_caption = False
    
            # 如果输入文本不是元组或列表,则将其转换为列表
            if not isinstance(text, (tuple, list)):
                text = [text]
    
            # 定义内部处理函数,清理或标准化文本
            def process(text: str):
                if clean_caption:
                    text = self._clean_caption(text)  # 清理标题
                    text = self._clean_caption(text)  # 再次清理标题
                else:
                    text = text.lower().strip()  # 转小写并去除首尾空格
                return text
    
            # 返回处理后的文本列表
            return [process(t) for t in text]
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption 复制
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image 复制
        def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
            # 如果输入不是列表,则将其转换为列表
            if not isinstance(image, list):
                image = [image]
    
            # 定义将 numpy 数组转换为 PyTorch 张量的内部函数
            def numpy_to_pt(images):
                if images.ndim == 3:
                    images = images[..., None]  # 如果是 3 维,增加一个维度
    
                images = torch.from_numpy(images.transpose(0, 3, 1, 2))  # 转换并调整维度
                return images
    
            # 如果第一个元素是 PIL 图像
            if isinstance(image[0], PIL.Image.Image):
                new_image = []  # 初始化新图像列表
    
                # 遍历图像列表进行处理
                for image_ in image:
                    image_ = image_.convert("RGB")  # 转换为 RGB 格式
                    image_ = resize(image_, self.unet.config.sample_size)  # 调整大小
                    image_ = np.array(image_)  # 转换为 numpy 数组
                    image_ = image_.astype(np.float32)  # 转换为浮点型
                    image_ = image_ / 127.5 - 1  # 归一化到 [-1, 1]
                    new_image.append(image_)  # 添加到新图像列表
    
                image = new_image  # 更新图像为新列表
    
                image = np.stack(image, axis=0)  # 将列表转换为 numpy 数组
                image = numpy_to_pt(image)  # 转换为 PyTorch 张量
    
            # 如果第一个元素是 numpy 数组
            elif isinstance(image[0], np.ndarray):
                # 根据维度进行拼接或堆叠
                image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
                image = numpy_to_pt(image)  # 转换为 PyTorch 张量
    
            # 如果第一个元素是 PyTorch 张量
            elif isinstance(image[0], torch.Tensor):
                # 根据维度进行拼接或堆叠
                image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
    
            # 返回处理后的图像
            return image
    # 处理掩码图像,返回一个张量
        def preprocess_mask_image(self, mask_image) -> torch.Tensor:
            # 检查输入是否为列表,如果不是则转换为单元素列表
            if not isinstance(mask_image, list):
                mask_image = [mask_image]
    
            # 检查第一个元素是否为张量
            if isinstance(mask_image[0], torch.Tensor):
                # 如果是四维张量则在第0轴上拼接,否则堆叠
                mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0)
    
                # 如果是二维张量,添加批次和通道维度
                if mask_image.ndim == 2:
                    mask_image = mask_image.unsqueeze(0).unsqueeze(0)
                # 如果是三维张量且批次大小为1,添加批次维度
                elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
                    mask_image = mask_image.unsqueeze(0)
                # 如果是三维张量且批次大小不为1,添加通道维度
                elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
                    mask_image = mask_image.unsqueeze(1)
    
                # 将小于0.5的值设为0
                mask_image[mask_image < 0.5] = 0
                # 将大于等于0.5的值设为1
                mask_image[mask_image >= 0.5] = 1
    
            # 检查第一个元素是否为PIL图像
            elif isinstance(mask_image[0], PIL.Image.Image):
                new_mask_image = []
    
                # 遍历每个掩码图像进行处理
                for mask_image_ in mask_image:
                    # 将图像转换为灰度模式
                    mask_image_ = mask_image_.convert("L")
                    # 调整图像大小
                    mask_image_ = resize(mask_image_, self.unet.config.sample_size)
                    # 转换为numpy数组
                    mask_image_ = np.array(mask_image_)
                    # 添加批次和通道维度
                    mask_image_ = mask_image_[None, None, :]
                    new_mask_image.append(mask_image_)
    
                # 将所有处理后的掩码合并
                mask_image = new_mask_image
    
                # 在第0轴上拼接所有掩码图像
                mask_image = np.concatenate(mask_image, axis=0)
                # 将像素值缩放到[0, 1]
                mask_image = mask_image.astype(np.float32) / 255.0
                # 将小于0.5的值设为0
                mask_image[mask_image < 0.5] = 0
                # 将大于等于0.5的值设为1
                mask_image[mask_image >= 0.5] = 1
                # 转换为PyTorch张量
                mask_image = torch.from_numpy(mask_image)
    
            # 检查第一个元素是否为numpy数组
            elif isinstance(mask_image[0], np.ndarray):
                # 在第0轴上拼接所有掩码图像
                mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
    
                # 将小于0.5的值设为0
                mask_image[mask_image < 0.5] = 0
                # 将大于等于0.5的值设为1
                mask_image[mask_image >= 0.5] = 1
                # 转换为PyTorch张量
                mask_image = torch.from_numpy(mask_image)
    
            # 返回处理后的掩码图像
            return mask_image
    
        # 从diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps复制的函数
        def get_timesteps(self, num_inference_steps, strength):
            # 根据初始化时间步长计算原始时间步长
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 计算时间步长的起始索引
            t_start = max(num_inference_steps - init_timestep, 0)
            # 获取调度器中相应的时间步长
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器具有设置起始索引的方法,则调用
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步长和剩余的推理步骤
            return timesteps, num_inference_steps - t_start
    
        # 准备中间图像的函数
        def prepare_intermediate_images(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
    ):
        # 获取输入图像的批大小、通道数、高度和宽度
        image_batch_size, channels, height, width = image.shape
        # 根据每个提示所需生成的图像数量调整批大小
        batch_size = batch_size * num_images_per_prompt
        # 定义图像的形状,包括调整后的批大小和通道、高度、宽度
        shape = (batch_size, channels, height, width)
        # 检查生成器是否为列表且其长度与批大小不匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 如果生成器的长度不等于批大小,抛出值错误
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        # 生成随机噪声张量,形状与输入图像匹配
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # 重复输入图像,生成所需的图像数量
        image = image.repeat_interleave(num_images_per_prompt, dim=0)
        # 将噪声添加到图像上,得到带噪声的图像
        noised_image = self.scheduler.add_noise(image, noise, timestep)
        # 根据掩膜图像合成原始图像和带噪声的图像
        image = (1 - mask_image) * image + mask_image * noised_image
        # 返回处理后的图像
        return image
    # 禁用梯度计算以节省内存和计算资源
    @torch.no_grad()
    # 替换示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        # 提示字符串或字符串列表,默认值为 None
        prompt: Union[str, List[str]] = None,
        # 输入图像,可以是多种格式,包括 PIL 图像、张量、数组等,默认值为 None
        image: Union[
            PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
        ] = None,
        # 掩膜图像,可以是多种格式,默认值为 None
        mask_image: Union[
            PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
        ] = None,
        # 强度参数,影响图像生成的程度,默认值为 1.0
        strength: float = 1.0,
        # 推理步骤数量,控制生成过程的迭代次数,默认值为 50
        num_inference_steps: int = 50,
        # 时间步列表,控制噪声添加的时刻,默认值为 None
        timesteps: List[int] = None,
        # 引导比例,影响生成图像与提示的相关性,默认值为 7.0
        guidance_scale: float = 7.0,
        # 负提示,指定不希望出现的提示,可以是字符串或字符串列表,默认值为 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 每个提示生成的图像数量,默认值为 1
        num_images_per_prompt: Optional[int] = 1,
        # η 参数,控制随机性,默认值为 0.0
        eta: float = 0.0,
        # 生成器,可以是单个生成器或生成器列表,默认值为 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 提示嵌入,预先计算的提示表示,默认值为 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负提示嵌入,预先计算的负提示表示,默认值为 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 输出类型,指定返回的图像格式,默认值为 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典形式的结果,默认值为 True
        return_dict: bool = True,
        # 回调函数,接受当前步骤和生成图像的函数,默认值为 None
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        # 回调步骤,控制回调函数调用的频率,默认值为 1
        callback_steps: int = 1,
        # 是否清理提示,默认值为 True
        clean_caption: bool = True,
        # 交叉注意力的额外参数,默认为 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,