diffusers-源码解析-二十六-

diffusers 源码解析(二十六)

.\diffusers\pipelines\deepfloyd_if\pipeline_if_inpainting_superresolution.py

# 导入 html 模块,用于处理 HTML 文本
import html
# 导入 inspect 模块,用于获取对象的信息
import inspect
# 导入 re 模块,用于正则表达式匹配
import re
# 导入 urllib.parse 模块并重命名为 ul,用于处理 URL 编码
import urllib.parse as ul
# 从 typing 模块导入类型提示相关的类
from typing import Any, Callable, Dict, List, Optional, Union

# 导入 numpy 库并重命名为 np,用于数组和数学计算
import numpy as np
# 导入 PIL.Image 模块,用于图像处理
import PIL.Image
# 导入 torch 库,用于深度学习操作
import torch
# 从 torch.nn.functional 导入 F,用于神经网络的功能操作
import torch.nn.functional as F
# 从 transformers 库导入 CLIPImageProcessor, T5EncoderModel 和 T5Tokenizer,用于处理图像和文本
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

# 从本地模块导入 StableDiffusionLoraLoaderMixin,用于加载稳定扩散模型
from ...loaders import StableDiffusionLoraLoaderMixin
# 从本地模块导入 UNet2DConditionModel,用于2D条件模型
from ...models import UNet2DConditionModel
# 从本地模块导入 DDPMScheduler,用于扩散调度
from ...schedulers import DDPMScheduler
# 从本地模块导入多个实用工具函数
from ...utils import (
    BACKENDS_MAPPING,        # 后端映射
    PIL_INTERPOLATION,      # PIL 插值方式
    is_bs4_available,       # 检查 BeautifulSoup 是否可用
    is_ftfy_available,      # 检查 ftfy 是否可用
    logging,                # 日志记录模块
    replace_example_docstring,  # 替换示例文档字符串的函数
)
# 从本地工具模块导入 randn_tensor 函数,用于生成随机张量
from ...utils.torch_utils import randn_tensor
# 从本地模块导入 DiffusionPipeline,用于处理扩散管道
from ..pipeline_utils import DiffusionPipeline
# 从本地模块导入 IFPipelineOutput,用于扩散管道的输出
from .pipeline_output import IFPipelineOutput
# 从本地模块导入 IFSafetyChecker,用于安全检查
from .safety_checker import IFSafetyChecker
# 从本地模块导入 IFWatermarker,用于添加水印
from .watermark import IFWatermarker

# 如果 BeautifulSoup 可用,则导入 BeautifulSoup 类
if is_bs4_available():
    from bs4 import BeautifulSoup

# 如果 ftfy 可用,则导入 ftfy 模块
if is_ftfy_available():
    import ftfy

# 创建一个 logger 实例用于记录日志,命名为当前模块名
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize 复制的函数
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    # 获取输入图像的宽度和高度
    w, h = images.size

    # 计算宽高比
    coef = w / h

    # 初始化宽高为目标尺寸
    w, h = img_size, img_size

    # 如果宽高比大于等于 1,则按比例调整宽度
    if coef >= 1:
        w = int(round(img_size / 8 * coef) * 8)  # 调整宽度为 8 的倍数
    else:
        # 如果宽高比小于 1,则按比例调整高度
        h = int(round(img_size / 8 / coef) * 8)  # 调整高度为 8 的倍数

    # 调整图像大小,使用双三次插值法
    images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)

    # 返回调整后的图像
    return images

# 示例文档字符串
EXAMPLE_DOC_STRING = """
    # 示例用法
    Examples:
        ```py
        # 导入必要的库和模块
        >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from io import BytesIO

        # 定义原始图像的 URL
        >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
        # 发送 GET 请求以获取图像
        >>> response = requests.get(url)
        # 将响应内容转换为 RGB 格式的图像
        >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
        # 将原始图像赋值给原始图像变量
        >>> original_image = original_image

        # 定义掩膜图像的 URL
        >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
        # 发送 GET 请求以获取掩膜图像
        >>> response = requests.get(url)
        # 将响应内容转换为图像
        >>> mask_image = Image.open(BytesIO(response.content))
        # 将掩膜图像赋值给掩膜图像变量
        >>> mask_image = mask_image

        # 从预训练模型加载图像修复管道
        >>> pipe = IFInpaintingPipeline.from_pretrained(
        ...     "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
        ... )
        # 启用模型 CPU 卸载功能以节省内存
        >>> pipe.enable_model_cpu_offload()

        # 定义提示文本
        >>> prompt = "blue sunglasses"

        # 编码提示文本为嵌入向量,包括正面和负面提示
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
        # 使用原始图像、掩膜图像和提示嵌入生成图像
        >>> image = pipe(
        ...     image=original_image,
        ...     mask_image=mask_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ...     output_type="pt",
        ... ).images

        # 保存中间生成的图像
        >>> pil_image = pt_to_pil(image)
        # 将中间图像保存到文件
        >>> pil_image[0].save("./if_stage_I.png")

        # 从预训练模型加载超分辨率管道
        >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
        ... )
        # 启用模型 CPU 卸载功能以节省内存
        >>> super_res_1_pipe.enable_model_cpu_offload()

        # 使用超分辨率管道生成最终图像
        >>> image = super_res_1_pipe(
        ...     image=image,
        ...     mask_image=mask_image,
        ...     original_image=original_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ... ).images
        # 将最终图像保存到文件
        >>> image[0].save("./if_stage_II.png")
        ```py
    """
# 定义一个名为 IFInpaintingSuperResolutionPipeline 的类,继承自 DiffusionPipeline 和 StableDiffusionLoraLoaderMixin
class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    # 定义一个 tokenizer 属性,类型为 T5Tokenizer
    tokenizer: T5Tokenizer
    # 定义一个 text_encoder 属性,类型为 T5EncoderModel
    text_encoder: T5EncoderModel

    # 定义一个 unet 属性,类型为 UNet2DConditionModel
    unet: UNet2DConditionModel
    # 定义一个调度器 scheduler,类型为 DDPMScheduler
    scheduler: DDPMScheduler
    # 定义一个图像噪声调度器 image_noising_scheduler,类型为 DDPMScheduler
    image_noising_scheduler: DDPMScheduler

    # 可选的特征提取器,类型为 CLIPImageProcessor
    feature_extractor: Optional[CLIPImageProcessor]
    # 可选的安全检查器,类型为 IFSafetyChecker
    safety_checker: Optional[IFSafetyChecker]

    # 可选的水印处理器,类型为 IFWatermarker
    watermarker: Optional[IFWatermarker]

    # 定义一个正则表达式,用于匹配不良标点
    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"  # 包含特定特殊字符
        + r"\)"  # 匹配右括号
        + r"\("  # 匹配左括号
        + r"\]"  # 匹配右中括号
        + r"\["  # 匹配左中括号
        + r"\}"  # 匹配右花括号
        + r"\{"  # 匹配左花括号
        + r"\|"  # 匹配竖线
        + "\\"
        + r"\/"  # 匹配斜杠
        + r"\*"  # 匹配星号
        + r"]{1,}"  # 至少匹配一个以上的字符
    )  # noqa

    # 定义一个字符串,用于表示 CPU 卸载顺序
    model_cpu_offload_seq = "text_encoder->unet"
    # 定义一个可选组件列表
    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
    # 定义一个不参与 CPU 卸载的组件列表
    _exclude_from_cpu_offload = ["watermarker"]

    # 初始化方法,接收多个参数以构建类的实例
    def __init__(
        self,
        # tokenizer 参数,类型为 T5Tokenizer
        tokenizer: T5Tokenizer,
        # text_encoder 参数,类型为 T5EncoderModel
        text_encoder: T5EncoderModel,
        # unet 参数,类型为 UNet2DConditionModel
        unet: UNet2DConditionModel,
        # scheduler 参数,类型为 DDPMScheduler
        scheduler: DDPMScheduler,
        # image_noising_scheduler 参数,类型为 DDPMScheduler
        image_noising_scheduler: DDPMScheduler,
        # 可选的安全检查器参数,类型为 IFSafetyChecker
        safety_checker: Optional[IFSafetyChecker],
        # 可选的特征提取器参数,类型为 CLIPImageProcessor
        feature_extractor: Optional[CLIPImageProcessor],
        # 可选的水印处理器参数,类型为 IFWatermarker
        watermarker: Optional[IFWatermarker],
        # 指示是否需要安全检查器的布尔值,默认为 True
        requires_safety_checker: bool = True,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 检查安全检查器是否为 None 且要求使用安全检查器
        if safety_checker is None and requires_safety_checker:
            # 记录警告,提示用户已禁用安全检查器
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 检查安全检查器是否不为 None 且特征提取器为 None
        if safety_checker is not None and feature_extractor is None:
            # 抛出值错误,提示必须定义特征提取器
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # 检查 UNet 配置的输入通道数是否不等于 6
        if unet.config.in_channels != 6:
            # 记录警告,提示用户加载了不适用于超分辨率的检查点
            logger.warning(
                "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
            )

        # 注册各个模块以便后续使用
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            image_noising_scheduler=image_noising_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        # 将配置注册到实例中,记录是否需要安全检查器
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制
    # 定义一个文本预处理的私有方法,接收文本和是否清理标题的标志
    def _text_preprocessing(self, text, clean_caption=False):
        # 如果设置了清理标题但 bs4 库不可用,发出警告并将 clean_caption 设置为 False
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        # 如果设置了清理标题但 ftfy 库不可用,发出警告并将 clean_caption 设置为 False
        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        # 如果文本不是元组或列表,则将其包装成列表
        if not isinstance(text, (tuple, list)):
            text = [text]

        # 定义一个处理单个文本的内部函数
        def process(text: str):
            # 如果需要清理标题,则调用清理标题的方法
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                # 否则将文本转换为小写并去除首尾空白
                text = text.lower().strip()
            # 返回处理后的文本
            return text

        # 对文本列表中的每个元素应用处理函数,并返回处理后的结果列表
        return [process(t) for t in text]

    # 禁用梯度计算,以节省内存和提高性能
    @torch.no_grad()
    # 定义编码提示的方法,接收多个参数
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        clean_caption: bool = False,
    # 定义运行安全检查的方法,接收图像、设备和数据类型
    def run_safety_checker(self, image, device, dtype):
        # 如果存在安全检查器
        if self.safety_checker is not None:
            # 使用特征提取器处理图像,并将结果转换为指定设备的张量
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            # 运行安全检查器,检测图像中的不当内容和水印
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            # 如果没有安全检查器,则将检测结果设置为 None
            nsfw_detected = None
            watermark_detected = None

        # 返回处理后的图像和检测结果
        return image, nsfw_detected, watermark_detected

    # 这里是从其他模块复制的准备额外步骤参数的方法
    # 准备调度器步骤的额外参数,因为并非所有调度器都有相同的签名
    def prepare_extra_step_kwargs(self, generator, eta):
        # eta (η) 仅在 DDIMScheduler 中使用,其他调度器将忽略它
        # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
        # 应在 [0, 1] 范围内

        # 检查调度器的 step 方法参数中是否接受 eta
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化额外步骤参数字典
        extra_step_kwargs = {}
        # 如果接受 eta,则将 eta 添加到额外参数字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 检查调度器的 step 方法参数中是否接受 generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受 generator,则将 generator 添加到额外参数字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回准备好的额外参数字典
        return extra_step_kwargs

    # 检查输入参数的有效性
    def check_inputs(
        self,
        prompt,
        image,
        original_image,
        mask_image,
        batch_size,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image 复制而来,将 preprocess_image 修改为 preprocess_original_image
    def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
        # 如果输入的 image 不是列表,则将其转换为列表
        if not isinstance(image, list):
            image = [image]

        # 定义将 NumPy 数组转换为 PyTorch 张量的函数
        def numpy_to_pt(images):
            # 如果输入图像是 3D 数组,则在最后添加一个维度
            if images.ndim == 3:
                images = images[..., None]

            # 将 NumPy 数组转换为 PyTorch 张量并调整维度顺序
            images = torch.from_numpy(images.transpose(0, 3, 1, 2))
            return images

        # 如果第一个图像是 PIL 图像类型
        if isinstance(image[0], PIL.Image.Image):
            new_image = []

            # 遍历每个图像,进行转换和处理
            for image_ in image:
                # 将图像转换为 RGB 格式
                image_ = image_.convert("RGB")
                # 调整图像大小
                image_ = resize(image_, self.unet.config.sample_size)
                # 将图像转换为 NumPy 数组
                image_ = np.array(image_)
                # 转换数据类型为 float32
                image_ = image_.astype(np.float32)
                # 归一化图像数据到 [-1, 1] 范围
                image_ = image_ / 127.5 - 1
                # 将处理后的图像添加到新图像列表中
                new_image.append(image_)

            # 将新图像列表转换为 NumPy 数组
            image = new_image

            # 将图像堆叠成一个 NumPy 数组
            image = np.stack(image, axis=0)  # 转换为 NumPy 数组
            # 将 NumPy 数组转换为 PyTorch 张量
            image = numpy_to_pt(image)  # 转换为 PyTorch 张量

        # 如果第一个图像是 NumPy 数组
        elif isinstance(image[0], np.ndarray):
            # 如果数组是 4 维,则进行拼接,否则堆叠成一个数组
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            # 将 NumPy 数组转换为 PyTorch 张量
            image = numpy_to_pt(image)

        # 如果第一个图像是 PyTorch 张量
        elif isinstance(image[0], torch.Tensor):
            # 如果张量是 4 维,则进行拼接,否则堆叠成一个张量
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

        # 返回处理后的图像张量
        return image

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image 复制而来
    # 定义图像预处理函数,接受图像、每个提示的图像数量和设备作为参数,返回处理后的张量
    def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
        # 检查输入的图像是否为张量或列表,如果不是则将其转换为列表
        if not isinstance(image, torch.Tensor) and not isinstance(image, list):
            image = [image]

        # 如果列表中的第一个元素是 PIL 图像
        if isinstance(image[0], PIL.Image.Image):
            # 将 PIL 图像转换为 NumPy 数组,并归一化到 [-1, 1] 范围
            image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]

            # 将列表中的图像堆叠为 NumPy 数组,增加一个维度
            image = np.stack(image, axis=0)  # to np
            # 将 NumPy 数组转换为 PyTorch 张量,并调整维度顺序
            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        # 如果列表中的第一个元素是 NumPy 数组
        elif isinstance(image[0], np.ndarray):
            # 将列表中的图像堆叠为 NumPy 数组,增加一个维度
            image = np.stack(image, axis=0)  # to np
            # 如果图像是 5 维,取第一个元素
            if image.ndim == 5:
                image = image[0]

            # 将 NumPy 数组转换为 PyTorch 张量,并调整维度顺序
            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        # 如果输入是列表且第一个元素是张量
        elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
            # 获取第一个张量的维度
            dims = image[0].ndim

            # 如果是 3 维,沿第 0 维堆叠张量
            if dims == 3:
                image = torch.stack(image, dim=0)
            # 如果是 4 维,沿第 0 维连接张量
            elif dims == 4:
                image = torch.concat(image, dim=0)
            # 如果维度不是 3 或 4,抛出错误
            else:
                raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")

        # 将图像张量移动到指定设备,并设置数据类型
        image = image.to(device=device, dtype=self.unet.dtype)

        # 重复图像以匹配每个提示的图像数量
        image = image.repeat_interleave(num_images_per_prompt, dim=0)

        # 返回处理后的图像张量
        return image

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline 复制的预处理掩码图像的代码
    # 预处理掩码图像,返回处理后的 PyTorch 张量
    def preprocess_mask_image(self, mask_image) -> torch.Tensor:
        # 检查掩码图像是否为列表,如果不是,则将其包装为列表
        if not isinstance(mask_image, list):
            mask_image = [mask_image]
    
        # 如果掩码图像的第一个元素是 PyTorch 张量
        if isinstance(mask_image[0], torch.Tensor):
            # 根据第一个张量的维度,选择合并(cat)或堆叠(stack)操作
            mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0)
    
            # 如果处理后的张量是二维
            if mask_image.ndim == 2:
                # 对于单个掩码,增加批次维度和通道维度
                mask_image = mask_image.unsqueeze(0).unsqueeze(0)
            # 如果处理后的张量是三维,且第一维大小为1
            elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
                # 单个掩码,认为第0维是批次大小为1
                mask_image = mask_image.unsqueeze(0)
            # 如果处理后的张量是三维,且第一维大小不为1
            elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
                # 一批掩码,认为第0维是批次维度
                mask_image = mask_image.unsqueeze(1)
    
            # 将掩码图像中小于0.5的值设为0
            mask_image[mask_image < 0.5] = 0
            # 将掩码图像中大于等于0.5的值设为1
            mask_image[mask_image >= 0.5] = 1
    
        # 如果掩码图像的第一个元素是 PIL 图像
        elif isinstance(mask_image[0], PIL.Image.Image):
            new_mask_image = []  # 创建一个新的列表以存储处理后的掩码图像
    
            # 遍历每个掩码图像
            for mask_image_ in mask_image:
                # 将掩码图像转换为灰度模式
                mask_image_ = mask_image_.convert("L")
                # 调整掩码图像的大小
                mask_image_ = resize(mask_image_, self.unet.config.sample_size)
                # 将图像转换为 NumPy 数组
                mask_image_ = np.array(mask_image_)
                # 增加批次和通道维度
                mask_image_ = mask_image_[None, None, :]
                # 将处理后的掩码图像添加到新列表中
                new_mask_image.append(mask_image_)
    
            # 将新处理的掩码图像列表合并为一个张量
            mask_image = new_mask_image
            # 沿第0维连接所有掩码图像
            mask_image = np.concatenate(mask_image, axis=0)
            # 将值转换为浮点数并归一化到[0, 1]
            mask_image = mask_image.astype(np.float32) / 255.0
            # 将掩码图像中小于0.5的值设为0
            mask_image[mask_image < 0.5] = 0
            # 将掩码图像中大于等于0.5的值设为1
            mask_image[mask_image >= 0.5] = 1
            # 将 NumPy 数组转换为 PyTorch 张量
            mask_image = torch.from_numpy(mask_image)
    
        # 如果掩码图像的第一个元素是 NumPy 数组
        elif isinstance(mask_image[0], np.ndarray):
            # 沿第0维连接每个掩码数组,并增加批次和通道维度
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
    
            # 将掩码图像中小于0.5的值设为0
            mask_image[mask_image < 0.5] = 0
            # 将掩码图像中大于等于0.5的值设为1
            mask_image[mask_image >= 0.5] = 1
            # 将 NumPy 数组转换为 PyTorch 张量
            mask_image = torch.from_numpy(mask_image)
    
        # 返回处理后的掩码图像
        return mask_image
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps 复制的代码
        def get_timesteps(self, num_inference_steps, strength):
            # 根据初始时间步计算原始时间步
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 计算开始时间步,确保不小于0
            t_start = max(num_inference_steps - init_timestep, 0)
            # 从调度器中获取时间步
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器有设置开始索引的方法,则调用它
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步和剩余的推理步骤数量
            return timesteps, num_inference_steps - t_start
    
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images 复制的代码
    # 准备中间图像,为图像处理生成噪声并应用遮罩
    def prepare_intermediate_images(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
        ):
            # 获取输入图像的批量大小、通道数、高度和宽度
            image_batch_size, channels, height, width = image.shape
    
            # 更新批量大小为每个提示生成的图像数量
            batch_size = batch_size * num_images_per_prompt
    
            # 定义新的形状用于生成噪声
            shape = (batch_size, channels, height, width)
    
            # 检查生成器列表的长度是否与请求的批量大小匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                # 抛出错误以确保生成器数量与批量大小一致
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 生成与输入形状匹配的随机噪声
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    
            # 将输入图像按提示的数量重复,以匹配新批量大小
            image = image.repeat_interleave(num_images_per_prompt, dim=0)
            # 向图像添加噪声
            noised_image = self.scheduler.add_noise(image, noise, timestep)
    
            # 应用遮罩以混合原始图像和带噪声图像
            image = (1 - mask_image) * image + mask_image * noised_image
    
            # 返回处理后的图像
            return image
    
        # 禁用梯度计算以节省内存
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        def __call__(
            self,
            # 接受图像,支持多种输入类型
            image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
            # 可选参数:原始图像,支持多种输入类型
            original_image: Union[
                PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
            ] = None,
            # 可选参数:遮罩图像,支持多种输入类型
            mask_image: Union[
                PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
            ] = None,
            # 设置强度参数,默认值为0.8
            strength: float = 0.8,
            # 可选参数:提示信息,支持字符串或字符串列表
            prompt: Union[str, List[str]] = None,
            # 设置推理步骤的数量,默认值为100
            num_inference_steps: int = 100,
            # 可选参数:时间步列表
            timesteps: List[int] = None,
            # 设置指导尺度,默认值为4.0
            guidance_scale: float = 4.0,
            # 可选参数:负面提示信息
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 可选参数:每个提示生成的图像数量,默认值为1
            num_images_per_prompt: Optional[int] = 1,
            # 设置η值,默认值为0.0
            eta: float = 0.0,
            # 可选参数:随机数生成器
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 可选参数:提示嵌入
            prompt_embeds: Optional[torch.Tensor] = None,
            # 可选参数:负面提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 可选参数:输出类型,默认值为"pil"
            output_type: Optional[str] = "pil",
            # 可选参数:是否返回字典,默认值为True
            return_dict: bool = True,
            # 可选参数:回调函数
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 设置回调步骤的数量,默认值为1
            callback_steps: int = 1,
            # 可选参数:交叉注意力关键字参数
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 设置噪声水平,默认值为0
            noise_level: int = 0,
            # 可选参数:清理标题的标志,默认值为True
            clean_caption: bool = True,

.\diffusers\pipelines\deepfloyd_if\pipeline_if_superresolution.py

# 导入用于处理 HTML 的模块
import html
# 导入用于获取对象信息的模块
import inspect
# 导入用于正则表达式处理的模块
import re
# 导入 URL 解析模块并命名为 ul
import urllib.parse as ul
# 从 typing 模块导入类型提示相关的类
from typing import Any, Callable, Dict, List, Optional, Union

# 导入 NumPy 库并命名为 np
import numpy as np
# 导入图像处理库 PIL
import PIL.Image
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的功能性模块
import torch.nn.functional as F
# 从 transformers 库导入图像处理器和 T5 模型及其标记器
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

# 从相对路径导入所需的混合类和模型
from ...loaders import StableDiffusionLoraLoaderMixin
from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
# 从 utils 模块导入多个工具函数和常量
from ...utils import (
    BACKENDS_MAPPING,
    is_bs4_available,
    is_ftfy_available,
    logging,
    replace_example_docstring,
)
# 从 torch_utils 模块导入生成随机张量的函数
from ...utils.torch_utils import randn_tensor
# 从 pipeline_utils 导入扩散管道类
from ..pipeline_utils import DiffusionPipeline
# 从 pipeline_output 导入输出类
from .pipeline_output import IFPipelineOutput
# 从安全检查器模块导入安全检查器类
from .safety_checker import IFSafetyChecker
# 从水印模块导入水印处理类
from .watermark import IFWatermarker

# 如果 bs4 可用,则导入 BeautifulSoup 类
if is_bs4_available():
    from bs4 import BeautifulSoup

# 如果 ftfy 可用,则导入该模块
if is_ftfy_available():
    import ftfy

# 创建一个日志记录器以记录模块内的日志信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文档字符串,提供了使用该类的示例
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch

        >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

        >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")

        >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
        ... )
        >>> super_res_1_pipe.enable_model_cpu_offload()

        >>> image = super_res_1_pipe(
        ...     image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
        ... ).images
        >>> image[0].save("./if_stage_II.png")
        ```py
"""

# 定义 IFSuperResolutionPipeline 类,继承自 DiffusionPipeline 和 StableDiffusionLoraLoaderMixin
class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    # 定义标记器属性,类型为 T5Tokenizer
    tokenizer: T5Tokenizer
    # 定义文本编码器属性,类型为 T5EncoderModel
    text_encoder: T5EncoderModel

    # 定义 UNet 模型属性,类型为 UNet2DConditionModel
    unet: UNet2DConditionModel
    # 定义调度器属性,类型为 DDPMScheduler
    scheduler: DDPMScheduler
    # 定义图像噪声调度器属性,类型为 DDPMScheduler
    image_noising_scheduler: DDPMScheduler

    # 定义可选的特征提取器属性,类型为 CLIPImageProcessor
    feature_extractor: Optional[CLIPImageProcessor]
    # 定义可选的安全检查器属性,类型为 IFSafetyChecker
    safety_checker: Optional[IFSafetyChecker]

    # 定义可选的水印处理器属性,类型为 IFWatermarker
    watermarker: Optional[IFWatermarker]

    # 定义用于匹配不良标点符号的正则表达式
    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"
        + r"\)"
        + r"\("
        + r"\]"
        + r"\["
        + r"\}"
        + r"\{"
        + r"\|"
        + "\\"
        + r"\/"
        + r"\*"
        + r"]{1,}"
    )  # noqa
    # 定义可选组件的列表
    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
    # 定义模型在 CPU 上的卸载顺序
    model_cpu_offload_seq = "text_encoder->unet"
    # 定义不参与 CPU 卸载的组件
    _exclude_from_cpu_offload = ["watermarker"]

    # 初始化方法,定义模型的各个组件
    def __init__(
        self,
        tokenizer: T5Tokenizer,  # 用于文本分词的模型
        text_encoder: T5EncoderModel,  # 文本编码器模型
        unet: UNet2DConditionModel,  # UNet 模型,用于图像生成
        scheduler: DDPMScheduler,  # 调度器,用于控制生成过程
        image_noising_scheduler: DDPMScheduler,  # 图像噪声调度器
        safety_checker: Optional[IFSafetyChecker],  # 安全检查器(可选)
        feature_extractor: Optional[CLIPImageProcessor],  # 特征提取器(可选)
        watermarker: Optional[IFWatermarker],  # 水印器(可选)
        requires_safety_checker: bool = True,  # 是否需要安全检查器的标志
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 检查安全检查器是否为 None,且要求使用安全检查器
        if safety_checker is None and requires_safety_checker:
            # 记录警告信息,提醒用户安全检查器未启用
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 检查安全检查器存在且特征提取器为 None
        if safety_checker is not None and feature_extractor is None:
            # 抛出值错误,提醒用户需要定义特征提取器
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # 检查 UNet 的输入通道数是否为 6
        if unet.config.in_channels != 6:
            # 记录警告,提醒用户所加载的检查点不适合超分辨率
            logger.warning(
                "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)."
            )

        # 注册各个组件到模型
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            image_noising_scheduler=image_noising_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        # 将是否需要安全检查器的信息注册到配置中
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing 复制的内容
    # 定义文本预处理函数,接收文本和可选参数 clean_caption
    def _text_preprocessing(self, text, clean_caption=False):
        # 如果 clean_caption 为真且 bs4 库不可用,则记录警告信息
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            # 记录将 clean_caption 设置为 False 的警告信息
            logger.warning("Setting `clean_caption` to False...")
            # 将 clean_caption 设置为 False
            clean_caption = False

        # 如果 clean_caption 为真且 ftfy 库不可用,则记录警告信息
        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            # 记录将 clean_caption 设置为 False 的警告信息
            logger.warning("Setting `clean_caption` to False...")
            # 将 clean_caption 设置为 False
            clean_caption = False

        # 如果文本不是元组或列表,则将其转换为列表
        if not isinstance(text, (tuple, list)):
            text = [text]

        # 定义处理单个文本的内部函数
        def process(text: str):
            # 如果 clean_caption 为真,执行清理操作
            if clean_caption:
                text = self._clean_caption(text)
                # 再次清理文本
                text = self._clean_caption(text)
            else:
                # 将文本转为小写并去除首尾空白
                text = text.lower().strip()
            # 返回处理后的文本
            return text

        # 对文本列表中的每个文本应用处理函数并返回结果列表
        return [process(t) for t in text]

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption 复制而来
    @torch.no_grad()
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt 复制而来
    def encode_prompt(
        # 定义提示的类型为字符串或字符串列表
        prompt: Union[str, List[str]],
        # 是否使用无分类器自由引导,默认为 True
        do_classifier_free_guidance: bool = True,
        # 每个提示生成的图像数量,默认为 1
        num_images_per_prompt: int = 1,
        # 可选参数,指定设备
        device: Optional[torch.device] = None,
        # 可选参数,负面提示,可以是字符串或字符串列表
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 可选参数,提示嵌入
        prompt_embeds: Optional[torch.Tensor] = None,
        # 可选参数,负面提示嵌入
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 是否清理提示的选项,默认为 False
        clean_caption: bool = False,
    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker 复制而来
    def run_safety_checker(self, image, device, dtype):
        # 如果存在安全检查器
        if self.safety_checker is not None:
            # 使用特征提取器处理图像并转换为张量
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            # 执行安全检查,返回处理后的图像和检测结果
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            # 如果没有安全检查器,设置检测结果为 None
            nsfw_detected = None
            watermark_detected = None

        # 返回处理后的图像及其检测结果
        return image, nsfw_detected, watermark_detected

    # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs 复制而来
    # 准备额外的参数用于调度器步骤,因为并非所有调度器具有相同的参数签名
        def prepare_extra_step_kwargs(self, generator, eta):
            # eta (η) 仅在 DDIMScheduler 中使用,其他调度器将忽略
            # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
            # 应该在 [0, 1] 范围内
    
            # 检查调度器的步骤是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 创建额外参数字典
            extra_step_kwargs = {}
            # 如果接受 eta,添加到额外参数中
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器的步骤是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果接受 generator,添加到额外参数中
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回包含额外参数的字典
            return extra_step_kwargs
    
        # 检查输入参数的有效性
        def check_inputs(
            self,
            prompt,
            image,
            batch_size,
            noise_level,
            callback_steps,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        # 从 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images 复制
        def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
            # 定义中间图像的形状
            shape = (batch_size, num_channels, height, width)
            # 检查生成器列表的长度是否与请求的批大小匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"您传入了长度为 {len(generator)} 的生成器列表,但请求的有效批大小为 {batch_size}。"
                    f" 请确保批大小与生成器的长度匹配。"
                )
    
            # 生成随机噪声张量作为初始中间图像
            intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    
            # 按调度器所需的标准差缩放初始噪声
            intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
            # 返回生成的中间图像
            return intermediate_images
    # 预处理图像,使其适合于后续模型处理
    def preprocess_image(self, image, num_images_per_prompt, device):
        # 检查输入是否为张量或列表,若不是,则将其转换为列表
        if not isinstance(image, torch.Tensor) and not isinstance(image, list):
            image = [image]
    
        # 如果输入图像是 PIL 图像,则转换为 NumPy 数组并归一化
        if isinstance(image[0], PIL.Image.Image):
            image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]
    
            # 将图像列表堆叠为 NumPy 数组
            image = np.stack(image, axis=0)  # to np
            # 转换 NumPy 数组为 PyTorch 张量,并调整维度顺序
            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        # 如果输入图像是 NumPy 数组,则直接堆叠并处理
        elif isinstance(image[0], np.ndarray):
            image = np.stack(image, axis=0)  # to np
            # 如果数组有五个维度,则只取第一个
            if image.ndim == 5:
                image = image[0]
    
            # 转换 NumPy 数组为 PyTorch 张量,并调整维度顺序
            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        # 如果输入是张量列表,则根据维度进行堆叠或连接
        elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
            dims = image[0].ndim
    
            # 三维张量则堆叠,四维张量则连接
            if dims == 3:
                image = torch.stack(image, dim=0)
            elif dims == 4:
                image = torch.concat(image, dim=0)
            # 若维度不符合要求,抛出错误
            else:
                raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
    
        # 将图像移动到指定设备,并设置数据类型
        image = image.to(device=device, dtype=self.unet.dtype)
    
        # 根据每个提示生成的图像数量重复图像
        image = image.repeat_interleave(num_images_per_prompt, dim=0)
    
        # 返回处理后的图像
        return image
    
    # 装饰器,用于禁用梯度计算,节省内存和计算
    @torch.no_grad()
    # 替换示例文档字符串的装饰器
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用的类方法
    def __call__(
        # 提示内容,可以是字符串或字符串列表
        prompt: Union[str, List[str]] = None,
        # 输出图像的高度
        height: int = None,
        # 输出图像的宽度
        width: int = None,
        # 输入的图像,可以是多种类型
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor] = None,
        # 推理步骤的数量
        num_inference_steps: int = 50,
        # 时间步列表
        timesteps: List[int] = None,
        # 指导缩放因子
        guidance_scale: float = 4.0,
        # 负提示,可以是字符串或字符串列表
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 每个提示生成的图像数量
        num_images_per_prompt: Optional[int] = 1,
        # 噪声级别
        eta: float = 0.0,
        # 随机数生成器,可以是单个或列表
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 提示嵌入张量
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负提示嵌入张量
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 输出类型,默认为 PIL
        output_type: Optional[str] = "pil",
        # 是否返回字典格式
        return_dict: bool = True,
        # 可选的回调函数
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        # 回调步骤的数量
        callback_steps: int = 1,
        # 交叉注意力的参数
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 噪声级别
        noise_level: int = 250,
        # 是否清理标题
        clean_caption: bool = True,

.\diffusers\pipelines\deepfloyd_if\pipeline_output.py

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 从 typing 模块导入类型注解,用于类型提示
from typing import List, Optional, Union

# 导入 numpy 库,通常用于数组操作和数值计算
import numpy as np
# 导入 PIL.Image,用于处理图像
import PIL.Image

# 从上级模块导入 BaseOutput 基类,用于输出类的继承
from ...utils import BaseOutput


# 定义 IFPipelineOutput 类,继承自 BaseOutput
@dataclass
class IFPipelineOutput(BaseOutput):
    """
    Args:
    Output class for Stable Diffusion pipelines.
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content or a watermark. `None` if safety checking could not be performed.
        watermark_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
            checking could not be performed.
    """

    # 定义 images 属性,可以是 PIL 图像列表或 numpy 数组
    images: Union[List[PIL.Image.Image], np.ndarray]
    # 定义 nsfw_detected 属性,可选的布尔列表,用于标记是否检测到不安全内容
    nsfw_detected: Optional[List[bool]]
    # 定义 watermark_detected 属性,可选的布尔列表,用于标记是否检测到水印
    watermark_detected: Optional[List[bool]]

.\diffusers\pipelines\deepfloyd_if\safety_checker.py

# 导入必要的库和模块
import numpy as np  # 导入 NumPy 库,用于数值计算
import torch  # 导入 PyTorch 库,用于深度学习
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel  # 导入 CLIP 配置和模型

from ...utils import logging  # 从父级模块导入日志工具


# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)


# 定义一个安全检查器类,继承自预训练模型
class IFSafetyChecker(PreTrainedModel):
    # 指定配置类为 CLIPConfig
    config_class = CLIPConfig

    # 指定不进行拆分的模块
    _no_split_modules = ["CLIPEncoderLayer"]

    # 初始化方法,接受一个配置对象
    def __init__(self, config: CLIPConfig):
        # 调用父类的初始化方法
        super().__init__(config)

        # 创建视觉模型,使用配置中的视觉部分
        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)

        # 定义线性层用于 NSFW 检测
        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        # 定义线性层用于水印检测
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    # 无梯度计算的前向传播方法
    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        # 获取图像的嵌入向量
        image_embeds = self.vision_model(clip_input)[0]

        # 检测 NSFW 内容
        nsfw_detected = self.p_head(image_embeds)
        # 将输出展平为一维
        nsfw_detected = nsfw_detected.flatten()
        # 根据阈值判断是否检测到 NSFW 内容
        nsfw_detected = nsfw_detected > p_threshold
        # 转换为列表格式
        nsfw_detected = nsfw_detected.tolist()

        # 如果检测到 NSFW 内容,记录警告日志
        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        # 遍历每个图像,处理检测到 NSFW 内容的图像
        for idx, nsfw_detected_ in enumerate(nsfw_detected):
            if nsfw_detected_:
                # 将检测到的 NSFW 图像替换为全黑图像
                images[idx] = np.zeros(images[idx].shape)

        # 检测水印内容
        watermark_detected = self.w_head(image_embeds)
        # 将输出展平为一维
        watermark_detected = watermark_detected.flatten()
        # 根据阈值判断是否检测到水印内容
        watermark_detected = watermark_detected > w_threshold
        # 转换为列表格式
        watermark_detected = watermark_detected.tolist()

        # 如果检测到水印内容,记录警告日志
        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        # 遍历每个图像,处理检测到水印的图像
        for idx, watermark_detected_ in enumerate(watermark_detected):
            if watermark_detected_:
                # 将检测到水印的图像替换为全黑图像
                images[idx] = np.zeros(images[idx].shape)

        # 返回处理后的图像和检测结果
        return images, nsfw_detected, watermark_detected

.\diffusers\pipelines\deepfloyd_if\timesteps.py

# 定义一个包含 27 个时间步长的列表,表示快速模式下的时间点
fast27_timesteps = [
    # 具体的时间步长值
    999,
    800,
    799,
    600,
    599,
    500,
    400,
    399,
    377,
    355,
    333,
    311,
    288,
    266,
    244,
    222,
    200,
    199,
    177,
    155,
    133,
    111,
    88,
    66,
    44,
    22,
    0,
]

# 定义一个包含 27 个时间步长的列表,表示智能模式下的时间点
smart27_timesteps = [
    # 具体的时间步长值
    999,
    976,
    952,
    928,
    905,
    882,
    858,
    857,
    810,
    762,
    715,
    714,
    572,
    429,
    428,
    286,
    285,
    238,
    190,
    143,
    142,
    118,
    95,
    71,
    47,
    24,
    0,
]

# 定义一个包含 100 个时间步长的列表,表示智能模式下的时间点
smart50_timesteps = [
    # 具体的时间步长值
    999,
    988,
    977,
    966,
    955,
    944,
    933,
    922,
    911,
    900,
    899,
    879,
    859,
    840,
    820,
    800,
    799,
    766,
    733,
    700,
    699,
    650,
    600,
    599,
    500,
    499,
    400,
    399,
    350,
    300,
    299,
    266,
    233,
    200,
    199,
    179,
    159,
    140,
    120,
    100,
    99,
    88,
    77,
    66,
    55,
    44,
    33,
    22,
    11,
    0,
]

# 定义一个包含 185 个时间步长的列表,表示智能模式下的时间点
smart100_timesteps = [
    # 具体的时间步长值
    999,
    995,
    992,
    989,
    985,
    981,
    978,
    975,
    971,
    967,
    964,
    961,
    957,
    956,
    951,
    947,
    942,
    937,
    933,
    928,
    923,
    919,
    914,
    913,
    908,
    903,
    897,
    892,
    887,
    881,
    876,
    871,
    870,
    864,
    858,
    852,
    846,
    840,
    834,
    828,
    827,
    820,
    813,
    806,
    799,
    792,
    785,
    784,
    777,
    770,
    763,
    756,
    749,
    742,
    741,
    733,
    724,
    716,
    707,
    699,
    698,
    688,
    677,
    666,
    656,
    655,
    645,
    634,
    623,
    613,
    612,
    598,
    584,
    570,
    569,
    555,
    541,
    527,
    526,
    505,
    484,
    483,
    462,
    440,
    439,
    396,
    395,
    352,
    351,
    308,
    307,
    264,
    263,
    220,
    219,
    176,
    132,
    88,
    44,
    0,
]

# 定义一个包含 185 个时间步长的列表,表示智能模式下的时间点
smart185_timesteps = [
    # 具体的时间步长值
    999,
    997,
    995,
    992,
    990,
    988,
    986,
    984,
    981,
    979,
    977,
    975,
    972,
    970,
    968,
    966,
    964,
    961,
    959,
    957,
    956,
    954,
    951,
    949,
    946,
    944,
    941,
    939,
    936,
    934,
    931,
    929,
    926,
    924,
    921,
    919,
    916,
    914,
    913,
    910,
    907,
    905,
    902,
    899,
    896,
    893,
    891,
    888,
    885,
    882,
    879,
    877,
    874,
    871,
    870,
    867,
    864,
    861,
    858,
    855,
    852,
    849,
    846,
    843,
    840,
    837,
    834,
    831,
    828,
    827,
    824,
    821,
    817,
    814,
    811,
    808,
    804,
    801,
    798,
    795,
    791,
    788,
    785,
    784,
    780,
    777,
    774,
    770,
    766,
    763,
    760,
    756,
    752,
    749,
    746,
    742,
    741,
    737,
    733,
    730,
    726,
    722,
    718,
    714,
    710,
    707,
    703,
    699,
    698,
    694,
    690,
    685,
    681,
    677,
    673,
    669,
    664,
    660,
    # 添加数字 656 到列表
    656,
    # 添加数字 655 到列表
    655,
    # 添加数字 650 到列表
    650,
    # 添加数字 646 到列表
    646,
    # 添加数字 641 到列表
    641,
    # 添加数字 636 到列表
    636,
    # 添加数字 632 到列表
    632,
    # 添加数字 627 到列表
    627,
    # 添加数字 622 到列表
    622,
    # 添加数字 618 到列表
    618,
    # 添加数字 613 到列表
    613,
    # 添加数字 612 到列表
    612,
    # 添加数字 607 到列表
    607,
    # 添加数字 602 到列表
    602,
    # 添加数字 596 到列表
    596,
    # 添加数字 591 到列表
    591,
    # 添加数字 586 到列表
    586,
    # 添加数字 580 到列表
    580,
    # 添加数字 575 到列表
    575,
    # 添加数字 570 到列表
    570,
    # 添加数字 569 到列表
    569,
    # 添加数字 563 到列表
    563,
    # 添加数字 557 到列表
    557,
    # 添加数字 551 到列表
    551,
    # 添加数字 545 到列表
    545,
    # 添加数字 539 到列表
    539,
    # 添加数字 533 到列表
    533,
    # 添加数字 527 到列表
    527,
    # 添加数字 526 到列表
    526,
    # 添加数字 519 到列表
    519,
    # 添加数字 512 到列表
    512,
    # 添加数字 505 到列表
    505,
    # 添加数字 498 到列表
    498,
    # 添加数字 491 到列表
    491,
    # 添加数字 484 到列表
    484,
    # 添加数字 483 到列表
    483,
    # 添加数字 474 到列表
    474,
    # 添加数字 466 到列表
    466,
    # 添加数字 457 到列表
    457,
    # 添加数字 449 到列表
    449,
    # 添加数字 440 到列表
    440,
    # 添加数字 439 到列表
    439,
    # 添加数字 428 到列表
    428,
    # 添加数字 418 到列表
    418,
    # 添加数字 407 到列表
    407,
    # 添加数字 396 到列表
    396,
    # 添加数字 395 到列表
    395,
    # 添加数字 381 到列表
    381,
    # 添加数字 366 到列表
    366,
    # 添加数字 352 到列表
    352,
    # 添加数字 351 到列表
    351,
    # 添加数字 330 到列表
    330,
    # 添加数字 308 到列表
    308,
    # 添加数字 307 到列表
    307,
    # 添加数字 286 到列表
    286,
    # 添加数字 264 到列表
    264,
    # 添加数字 263 到列表
    263,
    # 添加数字 242 到列表
    242,
    # 添加数字 220 到列表
    220,
    # 添加数字 219 到列表
    219,
    # 添加数字 176 到列表
    176,
    # 添加数字 175 到列表
    175,
    # 添加数字 132 到列表
    132,
    # 添加数字 131 到列表
    131,
    # 添加数字 88 到列表
    88,
    # 添加数字 44 到列表
    44,
    # 添加数字 0 到列表
    0,
# 定义一个空列表,后续将填充时间步长数据
super27_timesteps = [
    # 定义时间步长,值递减
    999,
    991,
    982,
    974,
    966,
    958,
    950,
    941,
    933,
    925,
    916,
    908,
    900,
    899,
    874,
    850,
    825,
    800,
    799,
    700,
    600,
    500,
    400,
    300,
    200,
    100,
    0,
]

# 定义另一个时间步长列表,值同样递减
super40_timesteps = [
    # 各时间步长的值
    999,
    992,
    985,
    978,
    971,
    964,
    957,
    949,
    942,
    935,
    928,
    921,
    914,
    907,
    900,
    899,
    879,
    859,
    840,
    820,
    800,
    799,
    766,
    733,
    700,
    699,
    650,
    600,
    599,
    500,
    499,
    400,
    399,
    300,
    299,
    200,
    199,
    100,
    99,
    0,
]

# 定义第三个时间步长列表,值继续递减
super100_timesteps = [
    # 包含一系列时间步长的值
    999,
    996,
    992,
    989,
    985,
    982,
    979,
    975,
    972,
    968,
    965,
    961,
    958,
    955,
    951,
    948,
    944,
    941,
    938,
    934,
    931,
    927,
    924,
    920,
    917,
    914,
    910,
    907,
    903,
    900,
    899,
    891,
    884,
    876,
    869,
    861,
    853,
    846,
    838,
    830,
    823,
    815,
    808,
    800,
    799,
    788,
    777,
    766,
    755,
    744,
    733,
    722,
    711,
    700,
    699,
    688,
    677,
    666,
    655,
    644,
    633,
    622,
    611,
    600,
    599,
    585,
    571,
    557,
    542,
    528,
    514,
    500,
    499,
    485,
    471,
    457,
    442,
    428,
    414,
    400,
    399,
    379,
    359,
    340,
    320,
    300,
    299,
    279,
    259,
    240,
    220,
    200,
    199,
    166,
    133,
    100,
    99,
    66,
    33,
    0,
]

.\diffusers\pipelines\deepfloyd_if\watermark.py

# 从 typing 模块导入 List 类型,用于类型注释
from typing import List

# 导入 PIL.Image 库以处理图像
import PIL.Image
# 导入 torch 库用于张量操作
import torch
# 从 PIL 导入 Image 类以创建和处理图像
from PIL import Image

# 从配置工具模块导入 ConfigMixin 类
from ...configuration_utils import ConfigMixin
# 从模型工具模块导入 ModelMixin 类
from ...models.modeling_utils import ModelMixin
# 从工具模块导入 PIL_INTERPOLATION 以获取插值方法
from ...utils import PIL_INTERPOLATION


# 定义 IFWatermarker 类,继承自 ModelMixin 和 ConfigMixin
class IFWatermarker(ModelMixin, ConfigMixin):
    # 初始化方法
    def __init__(self):
        # 调用父类初始化方法
        super().__init__()

        # 注册一个形状为 (62, 62, 4) 的零张量作为水印图像
        self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
        # 初始化水印图像的 PIL 表示为 None
        self.watermark_image_as_pil = None

    # 定义应用水印的方法,接受图像列表和可选的样本大小
    def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
        # 从 GitHub 复制的代码

        # 获取第一张图像的高度
        h = images[0].height
        # 获取第一张图像的宽度
        w = images[0].width

        # 如果未指定样本大小,则使用图像高度
        sample_size = sample_size or h

        # 计算宽高比系数
        coef = min(h / sample_size, w / sample_size)
        # 根据系数计算图像的新高度和宽度
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        # 定义 S1 和 S2,用于计算 K
        S1, S2 = 1024**2, img_w * img_h
        # 计算 K 值
        K = (S2 / S1) ** 0.5
        # 计算水印大小及其在图像中的位置
        wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)

        # 如果水印图像尚未创建
        if self.watermark_image_as_pil is None:
            # 将水印张量转换为 uint8 类型并转移到 CPU,转换为 NumPy 数组
            watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy()
            # 将 NumPy 数组转换为 RGBA 模式的 PIL 图像
            watermark_image = Image.fromarray(watermark_image, mode="RGBA")
            # 将 PIL 图像保存到实例变量
            self.watermark_image_as_pil = watermark_image

        # 调整水印图像大小
        wm_img = self.watermark_image_as_pil.resize(
            (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
        )

        # 遍历输入图像列表
        for pil_img in images:
            # 将水印图像粘贴到每张图像上,使用水印的 alpha 通道作为掩码
            pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])

        # 返回添加水印后的图像列表
        return images
posted @ 2024-10-22 12:35  绝不原创的飞龙  阅读(43)  评论(0)    收藏  举报