diffusers-源码解析-六十六-

diffusers 源码解析(六十六)

.\diffusers\utils\logging.py

# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,标明版权所有者和年份
# Copyright 2024 Optuna, Hugging Face
#
# 根据 Apache License 2.0 版本许可本文件的使用
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件在未遵守许可证的情况下不可使用
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证的副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 在适用的情况下,许可证下的软件以“原样”提供,不提供任何明示或暗示的担保
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取特定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
"""记录工具函数的模块。"""

# 导入 logging 模块以实现日志记录功能
import logging
# 导入 os 模块以进行操作系统交互
import os
# 导入 sys 模块以访问系统特定参数和功能
import sys
# 导入 threading 模块以实现线程支持
import threading
# 从 logging 模块导入不同的日志级别常量
from logging import (
    CRITICAL,  # NOQA
    DEBUG,  # NOQA
    ERROR,  # NOQA
    FATAL,  # NOQA
    INFO,  # NOQA
    NOTSET,  # NOQA
    WARN,  # NOQA
    WARNING,  # NOQA
)
# 导入 Dict 和 Optional 类型以进行类型注解
from typing import Dict, Optional

# 从 tqdm 库导入自动选择的进度条支持
from tqdm import auto as tqdm_lib

# 创建一个线程锁以确保线程安全
_lock = threading.Lock()
# 定义一个默认的日志处理程序,初始为 None
_default_handler: Optional[logging.Handler] = None

# 定义日志级别的字典,映射字符串到 logging 的级别
log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# 设置默认日志级别为 WARNING
_default_log_level = logging.WARNING

# 标志表示进度条是否处于活动状态
_tqdm_active = True

# 定义获取默认日志级别的函数
def _get_default_logging_level() -> int:
    """
    如果环境变量 DIFFUSERS_VERBOSITY 设置为有效选项,则返回该值作为新的默认级别。
    如果没有设置,则返回 `_default_log_level`。
    """
    # 获取环境变量 DIFFUSERS_VERBOSITY 的值
    env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
    # 如果环境变量存在
    if env_level_str:
        # 检查环境变量值是否在日志级别字典中
        if env_level_str in log_levels:
            return log_levels[env_level_str]
        else:
            # 如果值无效,记录警告信息
            logging.getLogger().warning(
                f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
                f"has to be one of: { ', '.join(log_levels.keys()) }"
            )
    # 返回默认日志级别
    return _default_log_level

# 定义获取库名称的函数
def _get_library_name() -> str:
    # 返回模块名称的第一个部分作为库名称
    return __name__.split(".")[0]

# 定义获取库根日志记录器的函数
def _get_library_root_logger() -> logging.Logger:
    # 返回库名称对应的日志记录器
    return logging.getLogger(_get_library_name())

# 定义配置库根日志记录器的函数
def _configure_library_root_logger() -> None:
    global _default_handler

    # 使用线程锁来确保线程安全
    with _lock:
        # 如果默认处理程序已存在,返回
        if _default_handler:
            # 该库已配置根日志记录器
            return
        # 创建一个流处理程序,输出到标准错误
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.

        # 检查 sys.stderr 是否存在
        if sys.stderr:  # only if sys.stderr exists, e.g. when not using pythonw in windows
            # 设置 flush 方法为 sys.stderr 的 flush 方法
            _default_handler.flush = sys.stderr.flush

        # 应用默认配置到库根日志记录器
        library_root_logger = _get_library_root_logger()
        # 添加默认处理程序到库根日志记录器
        library_root_logger.addHandler(_default_handler)
        # 设置库根日志记录器的日志级别
        library_root_logger.setLevel(_get_default_logging_level())
        # 禁用日志记录器的传播
        library_root_logger.propagate = False

# 定义重置库根日志记录器的函数
def _reset_library_root_logger() -> None:
    global _default_handler
    # 使用锁确保线程安全,防止竞争条件
        with _lock:
            # 如果没有默认处理器,则直接返回
            if not _default_handler:
                return
    
            # 获取库的根日志记录器
            library_root_logger = _get_library_root_logger()
            # 从根日志记录器中移除默认处理器
            library_root_logger.removeHandler(_default_handler)
            # 将根日志记录器的日志级别设置为 NOTSET,表示接受所有级别的日志
            library_root_logger.setLevel(logging.NOTSET)
            # 将默认处理器设置为 None,表示不再使用默认处理器
            _default_handler = None
# 获取日志级别字典
def get_log_levels_dict() -> Dict[str, int]:
    # 返回全局日志级别字典
    return log_levels


# 获取指定名称的日志记录器
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    返回具有指定名称的日志记录器。

    该函数不应直接访问,除非您正在编写自定义的 diffusers 模块。
    """

    # 如果未提供名称,则获取库名称
    if name is None:
        name = _get_library_name()

    # 配置库根日志记录器
    _configure_library_root_logger()
    # 返回指定名称的日志记录器
    return logging.getLogger(name)


# 获取当前日志级别
def get_verbosity() -> int:
    """
    返回 🤗 Diffusers 根日志记录器的当前级别作为 `int`。

    返回:
        `int`:
            日志级别整数,可以是以下之一:

            - `50`: `diffusers.logging.CRITICAL` 或 `diffusers.logging.FATAL`
            - `40`: `diffusers.logging.ERROR`
            - `30`: `diffusers.logging.WARNING` 或 `diffusers.logging.WARN`
            - `20`: `diffusers.logging.INFO`
            - `10`: `diffusers.logging.DEBUG`
    """

    # 配置库根日志记录器
    _configure_library_root_logger()
    # 返回根日志记录器的有效级别
    return _get_library_root_logger().getEffectiveLevel()


# 设置日志级别
def set_verbosity(verbosity: int) -> None:
    """
    设置 🤗 Diffusers 根日志记录器的详细程度。

    参数:
        verbosity (`int`):
            日志级别,可以是以下之一:

            - `diffusers.logging.CRITICAL` 或 `diffusers.logging.FATAL`
            - `diffusers.logging.ERROR`
            - `diffusers.logging.WARNING` 或 `diffusers.logging.WARN`
            - `diffusers.logging.INFO`
            - `diffusers.logging.DEBUG`
    """

    # 配置库根日志记录器
    _configure_library_root_logger()
    # 设置根日志记录器的级别
    _get_library_root_logger().setLevel(verbosity)


# 设置日志级别为 INFO
def set_verbosity_info() -> None:
    """将详细程度设置为 `INFO` 级别。"""
    # 调用设置详细程度的函数
    return set_verbosity(INFO)


# 设置日志级别为 WARNING
def set_verbosity_warning() -> None:
    """将详细程度设置为 `WARNING` 级别。"""
    # 调用设置详细程度的函数
    return set_verbosity(WARNING)


# 设置日志级别为 DEBUG
def set_verbosity_debug() -> None:
    """将详细程度设置为 `DEBUG` 级别。"""
    # 调用设置详细程度的函数
    return set_verbosity(DEBUG)


# 设置日志级别为 ERROR
def set_verbosity_error() -> None:
    """将详细程度设置为 `ERROR` 级别。"""
    # 调用设置详细程度的函数
    return set_verbosity(ERROR)


# 禁用默认处理程序
def disable_default_handler() -> None:
    """禁用 🤗 Diffusers 根日志记录器的默认处理程序。"""

    # 配置库根日志记录器
    _configure_library_root_logger()

    # 确保默认处理程序存在
    assert _default_handler is not None
    # 从根日志记录器中移除默认处理程序
    _get_library_root_logger().removeHandler(_default_handler)


# 启用默认处理程序
def enable_default_handler() -> None:
    """启用 🤗 Diffusers 根日志记录器的默认处理程序。"""

    # 配置库根日志记录器
    _configure_library_root_logger()

    # 确保默认处理程序存在
    assert _default_handler is not None
    # 将默认处理程序添加到根日志记录器
    _get_library_root_logger().addHandler(_default_handler)


# 添加处理程序到日志记录器
def add_handler(handler: logging.Handler) -> None:
    """将处理程序添加到 HuggingFace Diffusers 根日志记录器。"""

    # 配置库根日志记录器
    _configure_library_root_logger()

    # 确保处理程序存在
    assert handler is not None
    # 将处理程序添加到根日志记录器
    _get_library_root_logger().addHandler(handler)


# 从日志记录器移除处理程序
def remove_handler(handler: logging.Handler) -> None:
    """从 HuggingFace Diffusers 根日志记录器移除给定的处理程序。"""

    # 配置库根日志记录器
    _configure_library_root_logger()
    # 确保处理器不为空,并且在库根日志记录器的处理器列表中
    assert handler is not None and handler in _get_library_root_logger().handlers
    # 从库根日志记录器中移除指定的处理器
    _get_library_root_logger().removeHandler(handler)
# 定义一个函数,用于禁用库的日志输出传播
def disable_propagation() -> None:
    """
    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
    """
    # 配置库的根日志记录器
    _configure_library_root_logger()
    # 设置根日志记录器的传播属性为 False,禁用日志传播
    _get_library_root_logger().propagate = False


# 定义一个函数,用于启用库的日志输出传播
def enable_propagation() -> None:
    """
    Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
    double logging if the root logger has been configured.
    """
    # 配置库的根日志记录器
    _configure_library_root_logger()
    # 设置根日志记录器的传播属性为 True,启用日志传播
    _get_library_root_logger().propagate = True


# 定义一个函数,用于启用明确的日志格式
def enable_explicit_format() -> None:
    """
    Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
    ```py
    [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    ```
    All handlers currently bound to the root logger are affected by this method.
    """
    # 获取根日志记录器的所有处理器
    handlers = _get_library_root_logger().handlers

    # 遍历每个处理器,设置其格式化器
    for handler in handlers:
        # 创建一个新的格式化器
        formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
        # 将格式化器设置到处理器上
        handler.setFormatter(formatter)


# 定义一个函数,用于重置日志格式
def reset_format() -> None:
    """
    Resets the formatting for 🤗 Diffusers' loggers.

    All handlers currently bound to the root logger are affected by this method.
    """
    # 获取根日志记录器的所有处理器
    handlers = _get_library_root_logger().handlers

    # 遍历每个处理器,重置其格式化器
    for handler in handlers:
        # 将处理器的格式化器设置为 None,重置格式
        handler.setFormatter(None)


# 定义一个方法,用于发出警告信息
def warning_advice(self, *args, **kwargs) -> None:
    """
    This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
    warning will not be printed
    """
    # 检查环境变量是否设置为不发出建议警告
    no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
    # 如果设置了环境变量,则直接返回,不发出警告
    if no_advisory_warnings:
        return
    # 调用日志记录器的警告方法
    self.warning(*args, **kwargs)


# 将自定义的警告方法绑定到日志记录器
logging.Logger.warning_advice = warning_advice


# 定义一个空的 tqdm 类,用于替代真实的进度条
class EmptyTqdm:
    """Dummy tqdm which doesn't do anything."""

    # 初始化方法,接收可变参数
    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
        # 如果有参数,保存第一个参数为迭代器
        self._iterator = args[0] if args else None

    # 定义迭代器方法,返回迭代器
    def __iter__(self):
        return iter(self._iterator)

    # 定义属性访问方法,返回一个空函数
    def __getattr__(self, _):
        """Return empty function."""

        # 返回一个空函数
        def empty_fn(*args, **kwargs):  # pylint: disable=unused-argument
            return

        return empty_fn

    # 定义上下文管理器的进入方法
    def __enter__(self):
        return self

    # 定义上下文管理器的退出方法
    def __exit__(self, type_, value, traceback):
        return


# 定义一个自定义的 tqdm 类
class _tqdm_cls:
    # 定义调用方法
    def __call__(self, *args, **kwargs):
        # 检查 tqdm 是否处于激活状态
        if _tqdm_active:
            # 返回激活状态下的 tqdm 实例
            return tqdm_lib.tqdm(*args, **kwargs)
        else:
            # 返回空的 tqdm 实例
            return EmptyTqdm(*args, **kwargs)

    # 定义设置锁的方法
    def set_lock(self, *args, **kwargs):
        # 将锁设置为 None
        self._lock = None
        # 如果 tqdm 处于激活状态,设置锁
        if _tqdm_active:
            return tqdm_lib.tqdm.set_lock(*args, **kwargs)

    # 定义获取锁的方法
    def get_lock(self):
        # 如果 tqdm 处于激活状态,获取锁
        if _tqdm_active:
            return tqdm_lib.tqdm.get_lock()


# 创建一个 _tqdm_cls 的实例
tqdm = _tqdm_cls()


# 定义一个函数,检查进度条是否启用
def is_progress_bar_enabled() -> bool:
    """Return a boolean indicating whether tqdm progress bars are enabled."""
    global _tqdm_active
    # 返回进度条激活状态的布尔值
    return bool(_tqdm_active)
# 定义一个启用进度条的函数,不返回任何值
def enable_progress_bar() -> None:
    # 函数文档字符串,说明该函数的作用是启用 tqdm 进度条
    """Enable tqdm progress bar."""
    # 声明全局变量 _tqdm_active
    global _tqdm_active
    # 将全局变量 _tqdm_active 设置为 True,表示进度条处于启用状态
    _tqdm_active = True


# 定义一个禁用进度条的函数,不返回任何值
def disable_progress_bar() -> None:
    # 函数文档字符串,说明该函数的作用是禁用 tqdm 进度条
    """Disable tqdm progress bar."""
    # 声明全局变量 _tqdm_active
    global _tqdm_active
    # 将全局变量 _tqdm_active 设置为 False,表示进度条处于禁用状态
    _tqdm_active = False

{{ card_data }}

{{ model_description }}

Intended uses & limitations

How to use

# TODO: add an example code snippet for running this diffusion pipeline

Limitations and bias

[TODO: provide examples of latent issues and potential remediations]

Training details

[TODO: describe the data used to train the model]

.\diffusers\utils\outputs.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有规定,软件
# 按“现状”分发,不附带任何明示或暗示的担保或条件。
# 请参阅许可证以了解管理权限和
# 限制的具体语言。
"""
通用工具函数
"""

# 从有序字典模块导入 OrderedDict
from collections import OrderedDict
# 从数据类模块导入字段和是否为数据类的检查
from dataclasses import fields, is_dataclass
# 导入 Any 和 Tuple 类型
from typing import Any, Tuple

# 导入 NumPy 库
import numpy as np

# 从本地导入工具模块,检查 PyTorch 是否可用及其版本
from .import_utils import is_torch_available, is_torch_version


def is_tensor(x) -> bool:
    """
    测试 `x` 是否为 `torch.Tensor` 或 `np.ndarray`。
    """
    # 如果 PyTorch 可用
    if is_torch_available():
        # 导入 PyTorch 库
        import torch

        # 检查 x 是否为 torch.Tensor 类型
        if isinstance(x, torch.Tensor):
            return True

    # 检查 x 是否为 np.ndarray 类型
    return isinstance(x, np.ndarray)


class BaseOutput(OrderedDict):
    """
    所有模型输出的基类,作为数据类。具有一个 `__getitem__` 方法,允许通过整数或切片(像元组)或字符串(像字典)进行索引,并会忽略 `None` 属性。
    否则像常规 Python 字典一样工作。

    <提示 警告={true}>
    
    不能直接解包 [`BaseOutput`]。请先使用 [`~utils.BaseOutput.to_tuple`] 方法将其转换为元组。
    
    </提示>
    """

    def __init_subclass__(cls) -> None:
        """将子类注册为 pytree 节点。

        这对于在使用 `torch.nn.parallel.DistributedDataParallel` 和
        `static_graph=True` 时同步梯度是必要的,尤其是对于输出 `ModelOutput` 子类的模块。
        """
        # 如果 PyTorch 可用
        if is_torch_available():
            # 导入 PyTorch 的 pytree 工具
            import torch.utils._pytree

            # 检查 PyTorch 版本是否小于 2.2
            if is_torch_version("<", "2.2"):
                # 注册 pytree 节点,使用字典扁平化和解扁平化
                torch.utils._pytree._register_pytree_node(
                    cls,
                    torch.utils._pytree._dict_flatten,
                    lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
                )
            else:
                # 注册 pytree 节点,使用字典扁平化和解扁平化
                torch.utils._pytree.register_pytree_node(
                    cls,
                    torch.utils._pytree._dict_flatten,
                    lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
                )
    # 定义数据类的后处理初始化方法
        def __post_init__(self) -> None:
            # 获取当前数据类的所有字段
            class_fields = fields(self)
    
            # 安全性和一致性检查
            if not len(class_fields):
                # 如果没有字段,抛出错误
                raise ValueError(f"{self.__class__.__name__} has no fields.")
    
            # 获取第一个字段的值
            first_field = getattr(self, class_fields[0].name)
            # 检查除了第一个字段外,其他字段是否均为 None
            other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
    
            # 如果其他字段均为 None 且第一个字段为字典,进行赋值
            if other_fields_are_none and isinstance(first_field, dict):
                for key, value in first_field.items():
                    # 将字典内容赋值到当前对象
                    self[key] = value
            else:
                # 遍历所有字段并赋值非 None 的字段
                for field in class_fields:
                    v = getattr(self, field.name)
                    if v is not None:
                        # 将非 None 的字段值赋值到当前对象
                        self[field.name] = v
    
        # 定义删除项的方法
        def __delitem__(self, *args, **kwargs):
            # 不允许删除项,抛出异常
            raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
    
        # 定义设置默认值的方法
        def setdefault(self, *args, **kwargs):
            # 不允许设置默认值,抛出异常
            raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
    
        # 定义弹出项的方法
        def pop(self, *args, **kwargs):
            # 不允许弹出项,抛出异常
            raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
    
        # 定义更新项的方法
        def update(self, *args, **kwargs):
            # 不允许更新项,抛出异常
            raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
    
        # 定义获取项的方法
        def __getitem__(self, k: Any) -> Any:
            # 如果键是字符串
            if isinstance(k, str):
                # 将当前项转换为字典并返回对应值
                inner_dict = dict(self.items())
                return inner_dict[k]
            else:
                # 如果键不是字符串,返回对应的元组值
                return self.to_tuple()[k]
    
        # 定义设置属性的方法
        def __setattr__(self, name: Any, value: Any) -> None:
            # 如果属性名在键中且值不为 None
            if name in self.keys() and value is not None:
                # 不调用 self.__setitem__ 以避免递归错误
                super().__setitem__(name, value)
            # 设置属性值
            super().__setattr__(name, value)
    
        # 定义设置项的方法
        def __setitem__(self, key, value):
            # 将键值对设置到当前对象中
            super().__setitem__(key, value)
            # 不调用 self.__setattr__ 以避免递归错误
            super().__setattr__(key, value)
    
        # 定义序列化的方法
        def __reduce__(self):
            # 如果当前对象不是数据类
            if not is_dataclass(self):
                # 调用父类的序列化方法
                return super().__reduce__()
            # 获取可调用对象和参数
            callable, _args, *remaining = super().__reduce__()
            # 生成字段的元组
            args = tuple(getattr(self, field.name) for field in fields(self))
            # 返回可调用对象、参数及其他信息
            return callable, args, *remaining
    
        # 定义转换为元组的方法
        def to_tuple(self) -> Tuple[Any, ...]:
            """
            将当前对象转换为一个包含所有非 `None` 属性/键的元组。
            """
            # 返回包含所有键的值的元组
            return tuple(self[k] for k in self.keys())

.\diffusers\utils\peft_utils.py

# 版权声明,声明本文件的版权归 HuggingFace 团队所有
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 在 Apache 许可证 2.0(“许可证”)下授权;
# 除非遵循该许可证,否则不得使用本文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是“按原样”提供的,
# 不提供任何形式的保证或条件,无论是明示还是暗示的。
# 有关许可证所管辖的权限和限制的具体信息,请参阅许可证。
"""
PEFT 工具:与 peft 库相关的工具
"""

# 导入 collections 模块以便于使用集合相关的功能
import collections
# 导入 importlib 模块以支持动态导入
import importlib
# 从 typing 导入 Optional 类型以用于类型注释
from typing import Optional

# 导入 version 模块以处理版本相关的功能
from packaging import version

# 从当前包导入 utils,检查 peft 和 torch 库是否可用
from .import_utils import is_peft_available, is_torch_available

# 如果 torch 库可用,则导入 torch 模块
if is_torch_available():
    import torch

# 定义函数以递归地替换模型中的 LoraLayer 实例
def recurse_remove_peft_layers(model):
    r"""
    递归替换模型中所有 `LoraLayer` 的实例为相应的新层。
    """
    # 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 初始化一个标志,指示是否存在基础层模式
    has_base_layer_pattern = False
    # 遍历模型中的所有模块
    for module in model.modules():
        # 检查模块是否为 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 如果模块有名为 "base_layer" 的属性,则更新标志
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    # 如果存在基础层模式
    if has_base_layer_pattern:
        # 从 peft.utils 导入 _get_submodules 函数
        from peft.utils import _get_submodules

        # 获取所有不包含 "lora" 的模块名称列表
        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        # 遍历所有模块名称
        for key in key_list:
            try:
                # 获取当前模块的父级、目标和目标名称
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                # 如果发生属性错误,则继续下一个模块
                continue
            # 如果目标具有 "base_layer" 属性
            if hasattr(target, "base_layer"):
                # 用目标的基础层替换父模块中的目标
                setattr(parent, target_name, target.get_base_layer())
    else:
        # 处理与 PEFT <= 0.6.2 的向后兼容性
        # TODO: 一旦不再支持该 PEFT 版本,可以移除此代码
        from peft.tuners.lora import LoraLayer  # 导入 LoraLayer 模块

        # 遍历模型的所有子模块
        for name, module in model.named_children():
            # 如果子模块有子模块,则递归进入处理
            if len(list(module.children())) > 0:
                ## 复合模块,递归处理其内部的层
                recurse_remove_peft_layers(module)

            module_replaced = False  # 初始化标志,表示模块是否被替换

            # 检查当前模块是否为 LoraLayer 且为线性层
            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                # 创建新的线性层,保留输入和输出特征及偏置信息
                new_module = torch.nn.Linear(
                    module.in_features,  # 输入特征数量
                    module.out_features,  # 输出特征数量
                    bias=module.bias is not None,  # 是否使用偏置
                ).to(module.weight.device)  # 将新模块移动到当前模块权重的设备上
                new_module.weight = module.weight  # 复制权重
                if module.bias is not None:
                    new_module.bias = module.bias  # 复制偏置

                module_replaced = True  # 标记模块已被替换
            # 检查当前模块是否为 LoraLayer 且为卷积层
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                # 创建新的卷积层,保留卷积参数
                new_module = torch.nn.Conv2d(
                    module.in_channels,  # 输入通道数
                    module.out_channels,  # 输出通道数
                    module.kernel_size,  # 卷积核大小
                    module.stride,  # 步幅
                    module.padding,  # 填充
                    module.dilation,  # 膨胀
                    module.groups,  # 分组卷积
                ).to(module.weight.device)  # 将新模块移动到当前模块权重的设备上

                new_module.weight = module.weight  # 复制权重
                if module.bias is not None:
                    new_module.bias = module.bias  # 复制偏置

                module_replaced = True  # 标记模块已被替换

            # 如果模块被替换,则更新模型
            if module_replaced:
                setattr(model, name, new_module)  # 更新模型中的模块
                del module  # 删除旧模块

                # 如果可用,则清空 CUDA 缓存
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()  # 释放 CUDA 内存
    return model  # 返回更新后的模型
# 定义一个函数来调整模型的 LoRA 层的权重
def scale_lora_layers(model, weight):
    """
    调整模型的 LoRA 层的权重。

    参数:
        model (`torch.nn.Module`):
            需要调整的模型。
        weight (`float`):
            要分配给 LoRA 层的权重。
    """
    # 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 如果权重为 1.0,直接返回,不做任何调整
    if weight == 1.0:
        return

    # 遍历模型的所有模块
    for module in model.modules():
        # 检查当前模块是否是 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 调整当前 LoRA 层的权重
            module.scale_layer(weight)


# 定义一个函数来移除先前给定的 LoRA 层的权重
def unscale_lora_layers(model, weight: Optional[float] = None):
    """
    移除模型的 LoRA 层的权重。

    参数:
        model (`torch.nn.Module`):
            需要调整的模型。
        weight (`float`, *可选*):
            要分配给 LoRA 层的权重。如果未传入权重,将重新初始化 LoRA 层的权重。如果传入 0.0,将以正确的值重新初始化权重。
    """
    # 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 如果权重为 1.0,直接返回,不做任何调整
    if weight == 1.0:
        return

    # 遍历模型的所有模块
    for module in model.modules():
        # 检查当前模块是否是 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 如果传入了权重且权重不为 0
            if weight is not None and weight != 0:
                # 移除当前 LoRA 层的权重
                module.unscale_layer(weight)
            # 如果传入的权重为 0
            elif weight is not None and weight == 0:
                # 遍历当前模块的所有活动适配器
                for adapter_name in module.active_adapters:
                    # 如果权重为 0,则将权重重置为原始值
                    module.set_scale(adapter_name, 1.0)


# 定义一个函数来获取 PEFT 的参数
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
    # 初始化排名模式和 alpha 模式的字典
    rank_pattern = {}
    alpha_pattern = {}
    # 获取 rank_dict 的第一个值作为 lora_alpha
    r = lora_alpha = list(rank_dict.values())[0]

    # 如果 rank_dict 中的值不全相同
    if len(set(rank_dict.values())) > 1:
        # 获取出现次数最多的 rank
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # 对于排名与最常见排名不同的模块,将其添加到 rank_pattern 中
        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
        # 提取模块名称,去掉 ".lora_B." 后的部分,并保存到 rank_pattern 中
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
    # 检查网络 alpha 字典是否不为 None 且包含元素
    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
        # 检查网络 alpha 字典中是否有超过一个不同的值
        if len(set(network_alpha_dict.values())) > 1:
            # 获取出现次数最多的 alpha
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # 对于与出现次数最多的 alpha 不同的模块,将其添加到 `alpha_pattern`
            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
            # 如果是 UNet 模型
            if is_unet:
                # 处理模块名称,去掉 ".lora_A." 和 ".alpha"
                alpha_pattern = {
                    ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
                    for k, v in alpha_pattern.items()
                }
            else:
                # 处理模块名称,去掉 ".down." 后的部分
                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
        else:
            # 如果只有一个不同的 alpha,直接取出该值
            lora_alpha = set(network_alpha_dict.values()).pop()

    # 获取不包含 Diffusers 特定的层名称,去重
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
    # 检查 peft_state_dict 中是否有使用 "lora_magnitude_vector"
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)

    # 构建 lora 配置参数字典
    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
    }
    # 返回 lora 配置参数字典
    return lora_config_kwargs
# 获取模型中适配器的名称,根据 BaseTunerLayer 的数量返回
def get_adapter_name(model):
    # 从 PEFT 库中导入基础调优层
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 遍历模型的所有模块
    for module in model.modules():
        # 检查模块是否为 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 返回适配器的名称,格式为 "default_" 加上适配器数量
            return f"default_{len(module.r)}"
    # 如果没有找到适配器,返回默认名称 "default_0"
    return "default_0"


# 设置模型的适配层,可启用或禁用
def set_adapter_layers(model, enabled=True):
    # 从 PEFT 库中导入基础调优层
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 遍历模型的所有模块
    for module in model.modules():
        # 检查模块是否为 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 检查模块是否具备 enable_adapters 方法
            if hasattr(module, "enable_adapters"):
                # 调用 enable_adapters 方法启用或禁用适配器
                module.enable_adapters(enabled=enabled)
            else:
                # 通过禁用状态设置 disable_adapters 属性
                module.disable_adapters = not enabled


# 删除模型中的适配层
def delete_adapter_layers(model, adapter_name):
    # 从 PEFT 库中导入基础调优层
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 遍历模型的所有模块
    for module in model.modules():
        # 检查模块是否为 BaseTunerLayer 的实例
        if isinstance(module, BaseTunerLayer):
            # 检查模块是否具备 delete_adapter 方法
            if hasattr(module, "delete_adapter"):
                # 调用 delete_adapter 方法删除指定适配器
                module.delete_adapter(adapter_name)
            else:
                # 抛出错误,提示 PEFT 版本不兼容
                raise ValueError(
                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
                )

    # 检查模型是否已加载 PEFT 配置
    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
        # 从配置中删除适配器
        model.peft_config.pop(adapter_name, None)
        # 如果所有适配器都已删除,删除配置并重置标志
        if len(model.peft_config) == 0:
            del model.peft_config
            model._hf_peft_config_loaded = None


# 设置适配器的权重并激活它们
def set_weights_and_activate_adapters(model, adapter_names, weights):
    # 从 PEFT 库中导入基础调优层
    from peft.tuners.tuners_utils import BaseTunerLayer

    # 定义获取模块权重的内部函数
    def get_module_weight(weight_for_adapter, module_name):
        # 如果权重不是字典,直接返回该权重
        if not isinstance(weight_for_adapter, dict):
            return weight_for_adapter

        # 遍历权重字典,查找对应模块的权重
        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        # 分割模块名称为部分
        parts = module_name.split(".")
        # 生成关键字以获取块权重
        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        block_weight = weight_for_adapter.get(key, 1.0)

        return block_weight

    # 遍历每个适配器,使其激活并设置对应的缩放权重
    for adapter_name, weight in zip(adapter_names, weights):
        for module_name, module in model.named_modules():
            # 检查模块是否为 BaseTunerLayer 的实例
            if isinstance(module, BaseTunerLayer):
                # 检查模块是否具备 set_adapter 方法,以兼容旧版本
                if hasattr(module, "set_adapter"):
                    # 设置适配器名称
                    module.set_adapter(adapter_name)
                else:
                    # 设置当前激活的适配器名称
                    module.active_adapter = adapter_name
                # 设置适配器的缩放权重
                module.set_scale(adapter_name, get_module_weight(weight, module_name))

    # 设置多个激活的适配器
    # 遍历模型中的所有模块
        for module in model.modules():
            # 检查当前模块是否为 BaseTunerLayer 的实例
            if isinstance(module, BaseTunerLayer):
                # 为了与以前的 PEFT 版本保持向后兼容
                if hasattr(module, "set_adapter"):
                    # 如果模块具有 set_adapter 方法,则调用该方法并传入适配器名称
                    module.set_adapter(adapter_names)
                else:
                    # 如果没有 set_adapter 方法,则直接设置 active_adapter 属性
                    module.active_adapter = adapter_names
# 检查 PEFT 的版本是否兼容
def check_peft_version(min_version: str) -> None:
    # 文档字符串,说明该函数的作用和参数
    r"""
    Checks if the version of PEFT is compatible.

    Args:
        version (`str`):
            The version of PEFT to check against.
    """
    # 检查 PEFT 是否可用,若不可用则抛出异常
    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    # 获取当前 PEFT 版本并与最小版本进行比较
    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    # 若版本不兼容,则抛出异常并提示用户使用兼容版本
    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )

.\diffusers\utils\pil_utils.py

# 从类型提示模块导入 List
from typing import List

# 导入 PIL.Image 模块用于图像处理
import PIL.Image
# 导入 PIL.ImageOps 模块提供图像操作功能
import PIL.ImageOps
# 从 packaging 模块导入 version 用于版本比较
from packaging import version
# 从 PIL 导入 Image 类用于图像处理
from PIL import Image


# 检查 PIL 的版本是否大于或等于 9.1.0
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    # 定义图像插值方式的字典,使用新版的 Resampling 属性
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    # 定义图像插值方式的字典,使用旧版的插值常量
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }


# 将 torch 图像转换为 PIL 图像的函数
def pt_to_pil(images):
    """
    Convert a torch image to a PIL image.
    """
    # 将图像标准化到 [0, 1] 范围内
    images = (images / 2 + 0.5).clamp(0, 1)
    # 将图像转移到 CPU,并调整维度为 (batch, height, width, channels)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    # 调用 numpy_to_pil 函数将 numpy 图像转换为 PIL 图像
    images = numpy_to_pil(images)
    # 返回 PIL 图像
    return images


# 将 numpy 图像或图像批量转换为 PIL 图像的函数
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    # 如果图像是三维的,增加一个维度以表示批量
    if images.ndim == 3:
        images = images[None, ...]
    # 将图像值转换为 0-255 的整数,并取整
    images = (images * 255).round().astype("uint8")
    # 检查是否为单通道图像(灰度图像)
    if images.shape[-1] == 1:
        # 对于灰度图像,使用模式 "L" 创建 PIL 图像
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        # 对于彩色图像,直接从数组创建 PIL 图像
        pil_images = [Image.fromarray(image) for image in images]

    # 返回 PIL 图像列表
    return pil_images


# 创建图像网格的函数,便于可视化
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.
    """
    # 断言图像数量与网格大小一致
    assert len(images) == rows * cols

    # 如果指定了调整大小,则对每个图像进行调整
    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    # 获取每个图像的宽度和高度
    w, h = images[0].size
    # 创建一个新的 RGB 图像,用于放置网格
    grid = Image.new("RGB", size=(cols * w, rows * h))

    # 将每个图像放置到网格中
    for i, img in enumerate(images):
        # 计算每个图像的放置位置
        grid.paste(img, box=(i % cols * w, i // cols * h))
    # 返回生成的图像网格
    return grid

.\diffusers\utils\state_dict_utils.py

# 版权声明,指定版权信息及保留权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 许可声明,指定使用文件的条件和限制
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 许可地址,提供获取许可的链接
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 声明在适用情况下,软件按“原样”分发,且没有任何形式的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可以了解特定权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
"""
状态字典工具:用于轻松转换状态字典的实用方法
"""

# 导入枚举模块
import enum

# 从日志模块导入获取记录器的函数
from .logging import get_logger

# 创建一个记录器实例,用于当前模块的日志记录
logger = get_logger(__name__)

# 定义状态字典类型的枚举类
class StateDictType(enum.Enum):
    """
    用于转换状态字典时使用的模式。
    """

    # 指定不同的状态字典类型
    DIFFUSERS_OLD = "diffusers_old"
    KOHYA_SS = "kohya_ss"
    PEFT = "peft"
    DIFFUSERS = "diffusers"

# 定义 Unet 到 Diffusers 的映射,因为它使用不同的输出键与文本编码器
# 例如:to_q_lora -> q_proj / to_q
UNET_TO_DIFFUSERS = {
    # 映射 Unet 的上输出到 Diffusers 的对应输出
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
    ".to_k_lora.down": ".to_k.lora_A",
    ".to_k_lora.up": ".to_k.lora_B",
    ".to_v_lora.down": ".to_v.lora_A",
    ".to_v_lora.up": ".to_v.lora_B",
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}

# 定义 Diffusers 到 PEFT 的映射
DIFFUSERS_TO_PEFT = {
    # 映射 Diffusers 的层到 PEFT 的对应层
    ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
    ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
    ".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
    ".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
    ".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
    ".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
    ".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
    ".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
    "text_projection.lora.down.weight": "text_projection.lora_A.weight",
    "text_projection.lora.up.weight": "text_projection.lora_B.weight",
}

# 定义 Diffusers_old 到 PEFT 的映射
DIFFUSERS_OLD_TO_PEFT = {
    # 映射旧版本 Diffusers 的层到 PEFT 的对应层
    ".to_q_lora.up": ".q_proj.lora_B",
    ".to_q_lora.down": ".q_proj.lora_A",
    ".to_k_lora.up": ".k_proj.lora_B",
    ".to_k_lora.down": ".k_proj.lora_A",
    ".to_v_lora.up": ".v_proj.lora_B",
    ".to_v_lora.down": ".v_proj.lora_A",
    ".to_out_lora.up": ".out_proj.lora_B",
    ".to_out_lora.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}

# 定义 PEFT 到 Diffusers 的映射
PEFT_TO_DIFFUSERS = {
    # 映射 PEFT 的层到 Diffusers 的对应层
    ".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
    ".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
    ".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
    ".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
    # 将 lora_B 关联到 lora_linear_layer 的上层
        ".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
        # 将 lora_A 关联到 lora_linear_layer 的下层
        ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
        # 将 out_proj 的 lora_B 关联到 lora_linear_layer 的上层
        ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
        # 将 out_proj 的 lora_A 关联到 lora_linear_layer 的下层
        ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
        # 将 to_k 的 lora_A 关联到 lora 的下层
        "to_k.lora_A": "to_k.lora.down",
        # 将 to_k 的 lora_B 关联到 lora 的上层
        "to_k.lora_B": "to_k.lora.up",
        # 将 to_q 的 lora_A 关联到 lora 的下层
        "to_q.lora_A": "to_q.lora.down",
        # 将 to_q 的 lora_B 关联到 lora 的上层
        "to_q.lora_B": "to_q.lora.up",
        # 将 to_v 的 lora_A 关联到 lora 的下层
        "to_v.lora_A": "to_v.lora.down",
        # 将 to_v 的 lora_B 关联到 lora 的上层
        "to_v.lora_B": "to_v.lora.up",
        # 将 to_out.0 的 lora_A 关联到 lora 的下层
        "to_out.0.lora_A": "to_out.0.lora.down",
        # 将 to_out.0 的 lora_B 关联到 lora 的上层
        "to_out.0.lora_B": "to_out.0.lora.up",
# 结束前一个代码块
}

# 定义一个字典,将旧的扩散器键映射到新的扩散器键
DIFFUSERS_OLD_TO_DIFFUSERS = {
    # 映射旧的 Q 线性层上升键到新的 Q 线性层上升键
    ".to_q_lora.up": ".q_proj.lora_linear_layer.up",
    # 映射旧的 Q 线性层下降键到新的 Q 线性层下降键
    ".to_q_lora.down": ".q_proj.lora_linear_layer.down",
    # 映射旧的 K 线性层上升键到新的 K 线性层上升键
    ".to_k_lora.up": ".k_proj.lora_linear_layer.up",
    # 映射旧的 K 线性层下降键到新的 K 线性层下降键
    ".to_k_lora.down": ".k_proj.lora_linear_layer.down",
    # 映射旧的 V 线性层上升键到新的 V 线性层上升键
    ".to_v_lora.up": ".v_proj.lora_linear_layer.up",
    # 映射旧的 V 线性层下降键到新的 V 线性层下降键
    ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
    # 映射旧的输出层上升键到新的输出层上升键
    ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
    # 映射旧的输出层下降键到新的输出层下降键
    ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
    # 映射旧的 K 大小向量键到新的 K 大小向量键
    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
    # 映射旧的 V 大小向量键到新的 V 大小向量键
    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
    # 映射旧的 Q 大小向量键到新的 Q 大小向量键
    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
    # 映射旧的输出层大小向量键到新的输出层大小向量键
    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
}

# 定义一个字典,将 PEFT 格式映射到 KOHYA_SS 格式
PEFT_TO_KOHYA_SS = {
    # 映射 PEFT 格式中的 A 到 KOHYA_SS 格式中的下降
    "lora_A": "lora_down",
    # 映射 PEFT 格式中的 B 到 KOHYA_SS 格式中的上升
    "lora_B": "lora_up",
    # 这不是一个全面的字典,因为 KOHYA 格式需要替换键中的 `.` 为 `_`,
    # 添加前缀和添加 alpha 值
    # 检查 `convert_state_dict_to_kohya` 以了解更多
}

# 定义一个字典,将状态字典类型映射到相应的 DIFFUSERS 映射
PEFT_STATE_DICT_MAPPINGS = {
    # 映射旧扩散器类型到 PEFT
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
    # 映射新扩散器类型到 PEFT
    StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}

# 定义一个字典,将状态字典类型映射到相应的 DIFFUSERS 映射
DIFFUSERS_STATE_DICT_MAPPINGS = {
    # 映射旧扩散器类型到旧扩散器映射
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
    # 映射 PEFT 类型到 DIFFUSERS
    StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}

# 定义一个字典,将 PEFT 类型映射到 KOHYA 状态字典映射
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}

# 定义一个字典,指定总是要替换的键模式
KEYS_TO_ALWAYS_REPLACE = {
    # 将处理器的键模式替换为基本形式
    ".processor.": ".",
}

# 定义函数,转换状态字典
def convert_state_dict(state_dict, mapping):
    r"""
    简单地遍历状态字典并用 `mapping` 中的模式替换相应的值。

    参数:
        state_dict (`dict[str, torch.Tensor]`):
            要转换的状态字典。
        mapping (`dict[str, str]`):
            用于转换的映射,映射应为以下结构的字典:
                - 键: 要替换的模式
                - 值: 要替换成的模式

    返回:
        converted_state_dict (`dict`)
            转换后的状态字典。
    """
    # 初始化一个新的转换后的状态字典
    converted_state_dict = {}
    # 遍历输入的状态字典
    for k, v in state_dict.items():
        # 首先,过滤出我们总是想替换的键
        for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
            # 如果当前键中包含模式,则替换
            if pattern in k:
                new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
                k = k.replace(pattern, new_pattern)

        # 遍历映射中的模式
        for pattern in mapping.keys():
            # 如果当前键中包含模式,则替换
            if pattern in k:
                new_pattern = mapping[pattern]
                k = k.replace(pattern, new_pattern)
                break
        # 将转换后的键值对添加到新的字典中
        converted_state_dict[k] = v
    # 返回转换后的状态字典
    return converted_state_dict

# 定义函数,将状态字典转换为 PEFT 格式
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
    r"""
    将状态字典转换为 PEFT 格式,状态字典可以来自旧的扩散器格式(`OLD_DIFFUSERS`)或
    新的扩散器格式(`DIFFUSERS`)。该方法目前仅支持从扩散器旧/新格式到 PEFT 的转换。
    # 参数说明部分
    Args:
        state_dict (`dict[str, torch.Tensor]`):
            # 要转换的状态字典
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            # 状态字典的原始类型,如果未提供,方法将尝试自动推断
            The original type of the state dict, if not provided, the method will try to infer it automatically.
    """
    # 如果原始类型未提供
    if original_type is None:
        # 检查状态字典的键中是否包含“to_out_lora”,用于判断类型
        # 将旧的 diffusers 类型转换为 PEFT
        if any("to_out_lora" in k for k in state_dict.keys()):
            original_type = StateDictType.DIFFUSERS_OLD
        # 检查状态字典的键中是否包含“lora_linear_layer”
        elif any("lora_linear_layer" in k for k in state_dict.keys()):
            original_type = StateDictType.DIFFUSERS
        # 如果无法推断类型,则抛出错误
        else:
            raise ValueError("Could not automatically infer state dict type")

    # 检查推断的原始类型是否在支持的类型映射中
    if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
        # 如果不支持,则抛出错误
        raise ValueError(f"Original type {original_type} is not supported")

    # 根据原始类型获取对应的映射
    mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
    # 转换状态字典并返回结果
    return convert_state_dict(state_dict, mapping)
# 将状态字典转换为新的 diffusers 格式。状态字典可以来自旧的 diffusers 格式
# (`OLD_DIFFUSERS`)、PEFT 格式 (`PEFT`) 或新的 diffusers 格式 (`DIFFUSERS`)。
# 在最后一种情况下,该方法将返回状态字典本身。
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
    # 状态字典转换为新格式的文档字符串
    r"""
    Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
    (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
    return the state dict as is.

    The method only supports the conversion from diffusers old, PEFT to diffusers new for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.
    """
    # 从 kwargs 中获取适配器名称,如果不存在则默认为 None
    peft_adapter_name = kwargs.pop("adapter_name", None)
    # 如果适配器名称不为 None,前面添加一个点
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        # 否则适配器名称为空字符串
        peft_adapter_name = ""

    # 如果没有提供原始类型
    if original_type is None:
        # 检查状态字典的键是否包含 "to_out_lora",若有则设置原始类型为旧 diffusers
        if any("to_out_lora" in k for k in state_dict.keys()):
            original_type = StateDictType.DIFFUSERS_OLD
        # 检查键是否包含以适配器名称为前缀的权重
        elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT
        # 检查是否包含 "lora_linear_layer",如果有则不需要转换,直接返回状态字典
        elif any("lora_linear_layer" in k for k in state_dict.keys()):
            # nothing to do
            return state_dict
        # 如果未能推断出原始类型,则引发值错误
        else:
            raise ValueError("Could not automatically infer state dict type")

    # 检查原始类型是否在支持的状态字典映射中
    if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
        # 如果不支持,抛出值错误
        raise ValueError(f"Original type {original_type} is not supported")

    # 获取与原始类型相对应的映射
    mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
    # 使用映射转换状态字典并返回
    return convert_state_dict(state_dict, mapping)


# 将状态字典从 UNet 格式转换为 diffusers 格式,主要通过移除一些键来实现
def convert_unet_state_dict_to_peft(state_dict):
    # 状态字典转换文档字符串
    r"""
    Converts a state dict from UNet format to diffusers format - i.e. by removing some keys
    """
    # 定义 UNet 到 diffusers 的映射
    mapping = UNET_TO_DIFFUSERS
    # 使用映射转换状态字典并返回
    return convert_state_dict(state_dict, mapping)


# 尝试首先将状态字典转换为 PEFT 格式,如果没有检测到有效的 DIFFUSERS LoRA 的 "lora_linear_layer"
# 则尝试专门转换 UNet 状态字典
def convert_all_state_dict_to_peft(state_dict):
    # 状态字典转换为 PEFT 格式的文档字符串
    r"""
    Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
    `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
    """
    # 尝试转换状态字典为 PEFT 格式
    try:
        peft_dict = convert_state_dict_to_peft(state_dict)
    # 捕获异常并赋值给变量 e
        except Exception as e:
            # 检查异常信息是否为无法自动推断状态字典类型
            if str(e) == "Could not automatically infer state dict type":
                # 将 UNet 状态字典转换为 PEFT 字典
                peft_dict = convert_unet_state_dict_to_peft(state_dict)
            else:
                # 重新抛出未处理的异常
                raise
    
        # 检查 PEFT 字典中是否包含 "lora_A" 或 "lora_B" 的键
        if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()):
            # 如果没有,则抛出值错误异常
            raise ValueError("Your LoRA was not converted to PEFT")
    
        # 返回转换后的 PEFT 字典
        return peft_dict
# 定义一个将 PEFT 状态字典转换为 Kohya 格式的函数
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    将 `PEFT` 状态字典转换为可在 AUTOMATIC1111、ComfyUI、SD.Next、InvokeAI 等中使用的 `Kohya` 格式。
    该方法目前仅支持从 PEFT 到 Kohya 的转换。

    参数:
        state_dict (`dict[str, torch.Tensor]`):
            要转换的状态字典。
        original_type (`StateDictType`, *可选*):
            状态字典的原始类型,如果未提供,方法将尝试自动推断。
        kwargs (`dict`, *args*):
            传递给该方法的附加参数。

            - **adapter_name**: 例如,在 PEFT 的情况下,一些键会被适配器名称预先附加,
                因此需要特殊处理。默认情况下,PEFT 也会在
                `get_peft_model_state_dict` 方法中处理这一点:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                但我们在这里添加它以防我们不想依赖该方法。
    """
    # 尝试导入 torch 库
    try:
        import torch
    # 如果导入失败,记录错误并引发异常
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    # 从 kwargs 中弹出适配器名称,如果没有提供则为 None
    peft_adapter_name = kwargs.pop("adapter_name", None)
    # 如果提供了适配器名称,则在前面加上点
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    # 如果没有适配器名称,则设置为空字符串
    else:
        peft_adapter_name = ""

    # 如果未提供原始类型,则检查状态字典中是否包含特定键
    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    # 检查原始类型是否在支持的类型映射中
    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # 使用适当的映射调用 convert_state_dict 函数
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    # 创建一个空字典以存储最终的 Kohya 状态字典
    kohya_ss_state_dict = {}

    # 额外逻辑:在所有键中替换头部、alpha 参数的 `.` 为 `_`
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        # 替换特定的键名
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        elif "lora_magnitude_vector" in kohya_key:
            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")

        # 将所有键中的 `.` 替换为 `_`,保留最后两个 `.` 不变
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        # 移除适配器名称,Kohya 不需要这些名称
        kohya_key = kohya_key.replace(peft_adapter_name, "")
        # 将处理后的键及其权重存储到新的字典中
        kohya_ss_state_dict[kohya_key] = weight
        # 如果键名中包含 `lora_down`,则创建相应的 alpha 键
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            # 将 alpha 键的值设置为权重的长度
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
    # 返回保存的状态字典,通常用于恢复训练或推理的状态
    return kohya_ss_state_dict

.\diffusers\utils\testing_utils.py

# 导入所需的标准库和第三方库
import functools  # 提供高阶函数的工具
import importlib  # 处理动态导入模块
import inspect  # 获取对象的内部信息
import io  # 提供用于处理流的基本工具
import logging  # 提供记录日志的功能
import multiprocessing  # 处理多进程的支持
import os  # 提供与操作系统交互的功能
import random  # 生成随机数的工具
import re  # 处理正则表达式
import struct  # 处理 C 语言结构体的工具
import sys  # 处理 Python 解释器的变量和函数
import tempfile  # 创建临时文件的工具
import time  # 提供时间相关的功能
import unittest  # 提供单元测试的框架
import urllib.parse  # 处理 URL 的解析和构造
from contextlib import contextmanager  # 提供上下文管理器的功能
from io import BytesIO, StringIO  # 提供内存中的字节流和字符串流
from pathlib import Path  # 处理文件路径的工具
from typing import Callable, Dict, List, Optional, Union  # 提供类型注释的工具

import numpy as np  # 导入 NumPy 库
import PIL.Image  # 导入 PIL 图像处理库
import PIL.ImageOps  # 导入 PIL 图像操作功能
import requests  # 处理 HTTP 请求的库
from numpy.linalg import norm  # 导入计算向量范数的函数
from packaging import version  # 处理版本比较的工具

# 导入自定义工具模块中的所需功能
from .import_utils import (
    BACKENDS_MAPPING,  # 后端映射的字典
    is_compel_available,  # 检查 Compel 是否可用的函数
    is_flax_available,  # 检查 Flax 是否可用的函数
    is_note_seq_available,  # 检查 NoteSeq 是否可用的函数
    is_onnx_available,  # 检查 ONNX 是否可用的函数
    is_opencv_available,  # 检查 OpenCV 是否可用的函数
    is_peft_available,  # 检查 PEFT 是否可用的函数
    is_timm_available,  # 检查 TIMM 是否可用的函数
    is_torch_available,  # 检查 PyTorch 是否可用的函数
    is_torch_version,  # 检查 PyTorch 版本的函数
    is_torchsde_available,  # 检查 TorchSDE 是否可用的函数
    is_transformers_available,  # 检查 Transformers 是否可用的函数
)
from .logging import get_logger  # 导入自定义日志记录功能

# 创建全局随机数生成器
global_rng = random.Random()

# 获取当前模块的日志记录器
logger = get_logger(__name__)

# 检查是否满足 PEFT 的版本要求
_required_peft_version = is_peft_available() and version.parse(
    version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")  # 检查 PEFT 版本是否大于 0.5

# 检查是否满足 Transformers 的版本要求
_required_transformers_version = is_transformers_available() and version.parse(
    version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")  # 检查 Transformers 版本是否大于 4.33

# 根据版本要求确定是否使用 PEFT 后端
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version

# 如果 PyTorch 可用,执行以下代码
if is_torch_available():
    import torch  # 导入 PyTorch 库

    # 设置用于自定义加速器的后端环境变量
    if "DIFFUSERS_TEST_BACKEND" in os.environ:
        backend = os.environ["DIFFUSERS_TEST_BACKEND"]  # 获取环境变量中的后端名称
        try:
            _ = importlib.import_module(backend)  # 尝试导入指定的后端模块
        except ModuleNotFoundError as e:  # 捕获模块未找到异常
            raise ModuleNotFoundError(
                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
                    to enable a specified backend.):\n{e}"  # 提示用户导入失败
            ) from e  # 抛出原始异常

    # 如果指定了设备环境变量
    if "DIFFUSERS_TEST_DEVICE" in os.environ:
        torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]  # 获取指定的设备名称
        try:
            # 尝试创建设备以验证提供的设备是否有效
            _ = torch.device(torch_device)
        except RuntimeError as e:  # 捕获运行时异常
            raise RuntimeError(
                f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}"  # 提示用户设备未知
            ) from e  # 抛出原始异常
        logger.info(f"torch_device overrode to {torch_device}")  # 记录当前使用的设备名称
    else:
        # 默认情况下根据 CUDA 可用性选择设备
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        is_torch_higher_equal_than_1_12 = version.parse(
            version.parse(torch.__version__).base_version
        ) >= version.parse("1.12")  # 检查 PyTorch 版本是否大于等于 1.12

        if is_torch_higher_equal_than_1_12:  # 如果 PyTorch 版本符合条件
            # 某些版本的 PyTorch 1.12 未注册 mps 后端,需查看相关问题
            mps_backend_registered = hasattr(torch.backends, "mps")  # 检查是否注册了 mps 后端
            # 根据 mps 后端可用性选择设备
            torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
# 定义一个函数,用于检查两个 PyTorch 张量是否相近
def torch_all_close(a, b, *args, **kwargs):
    # 检查是否安装了 PyTorch,如果没有,则引发异常
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")
    # 检查两个张量是否相近,如果不相近,则断言失败,并显示最大差异
    if not torch.allclose(a, b, *args, **kwargs):
        assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
    # 如果相近,返回 True
    return True


# 定义一个函数,计算两个 NumPy 数组之间的余弦相似度距离
def numpy_cosine_similarity_distance(a, b):
    # 计算两个数组的余弦相似度
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    # 计算相似度的距离,距离越小越相似
    distance = 1.0 - similarity.mean()

    # 返回计算得到的距离
    return distance


# 定义一个函数,用于打印张量测试结果
def print_tensor_test(
    tensor,
    limit_to_slices=None,
    max_torch_print=None,
    filename="test_corrections.txt",
    expected_tensor_name="expected_slice",
):
    # 如果设置了最大打印数量,则更新 PyTorch 打印选项
    if max_torch_print:
        torch.set_printoptions(threshold=10_000)

    # 获取当前测试的名称
    test_name = os.environ.get("PYTEST_CURRENT_TEST")
    # 如果输入不是张量,则将其转换为张量
    if not torch.is_tensor(tensor):
        tensor = torch.from_numpy(tensor)
    # 如果设置了切片限制,则限制张量的切片
    if limit_to_slices:
        tensor = tensor[0, -3:, -3:, -1]

    # 将张量转换为字符串格式,并移除换行符
    tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
    # 将张量字符串格式化为 NumPy 数组形式
    output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
    # 分离测试名称为文件、类和函数名
    test_file, test_class, test_fn = test_name.split("::")
    test_fn = test_fn.split()[0]
    # 以附加模式打开文件并写入测试结果
    with open(filename, "a") as f:
        print("::".join([test_file, test_class, test_fn, output_str]), file=f)


# 定义一个函数,用于获取测试目录的路径
def get_tests_dir(append_path=None):
    """
    Args:
        append_path: optional path to append to the tests dir path
    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
        joined after the `tests` dir the former is provided.
    """
    # 获取调用该函数的文件的路径
    caller__file__ = inspect.stack()[1][1]
    # 获取调用文件所在目录的绝对路径
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))

    # 循环查找直到找到以 "tests" 结尾的目录
    while not tests_dir.endswith("tests"):
        tests_dir = os.path.dirname(tests_dir)

    # 如果提供了附加路径,则返回合并后的完整路径
    if append_path:
        return Path(tests_dir, append_path).as_posix()
    else:
        # 否则返回测试目录路径
        return tests_dir


# 从 PR 中提取的函数
# https://github.com/huggingface/accelerate/pull/1964
def str_to_bool(value) -> int:
    """
    Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
    `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
    """
    # 将输入值转换为小写以进行比较
    value = value.lower()
    # 检查输入值是否是“真”值,返回 1
    if value in ("y", "yes", "t", "true", "on", "1"):
        return 1
    # 检查输入值是否是“假”值,返回 0
    elif value in ("n", "no", "f", "false", "off", "0"):
        return 0
    else:
        # 如果输入值无效,则引发异常
        raise ValueError(f"invalid truth value {value}")


# 定义一个函数,从环境变量中解析布尔标志
def parse_flag_from_env(key, default=False):
    try:
        # 尝试从环境变量中获取值
        value = os.environ[key]
    except KeyError:
        # 如果未设置环境变量,则使用默认值
        _value = default
    else:
        # 如果没有设置 KEY,进入此分支,准备将值转换为布尔值
        # 尝试将字符串值转换为布尔值(True 或 False)
        try:
            _value = str_to_bool(value)
        except ValueError:
            # 如果转换失败,抛出更具体的错误信息,提示值必须是 'yes' 或 'no'
            raise ValueError(f"If set, {key} must be yes or no.")
    # 返回转换后的布尔值
    return _value
# 从环境变量中解析是否运行慢测试的标志,默认为 False
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
# 从环境变量中解析是否运行夜间测试的标志,默认为 False
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
# 从环境变量中解析是否运行编译测试的标志,默认为 False
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """创建一个随机的 float32 张量"""
    # 如果没有提供随机数生成器,则使用全局随机数生成器
    if rng is None:
        rng = global_rng

    # 初始化总维度为 1
    total_dims = 1
    # 计算张量的总元素数量
    for dim in shape:
        total_dims *= dim

    # 初始化一个空列表以存储随机值
    values = []
    # 根据总元素数量生成随机值
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    # 将生成的值转换为张量,并按照指定形状调整,返回连续的张量
    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


def slow(test_case):
    """
    装饰器标记一个测试为慢测试。

    慢测试默认被跳过。设置 RUN_SLOW 环境变量为真值以运行它们。
    """
    # 如果 _run_slow_tests 为真,则返回原测试,否则跳过测试并提示信息
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


def nightly(test_case):
    """
    装饰器标记一个每晚在 diffusers CI 中运行的测试。

    慢测试默认被跳过。设置 RUN_NIGHTLY 环境变量为真值以运行它们。
    """
    # 如果 _run_nightly_tests 为真,则返回原测试,否则跳过测试并提示信息
    return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)


def is_torch_compile(test_case):
    """
    装饰器标记一个在 diffusers CI 中运行的编译测试。

    编译测试默认被跳过。设置 RUN_COMPILE 环境变量为真值以运行它们。
    """
    # 如果 _run_compile_tests 为真,则返回原测试,否则跳过测试并提示信息
    return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)


def require_torch(test_case):
    """
    装饰器标记一个需要 PyTorch 的测试。未安装 PyTorch 时这些测试将被跳过。
    """
    # 如果 PyTorch 可用,则返回原测试,否则跳过测试并提示信息
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_torch_2(test_case):
    """
    装饰器标记一个需要 PyTorch 2 的测试。这些测试在未安装 PyTorch 2 时将被跳过。
    """
    # 检查 PyTorch 是否可用且版本大于等于 2.0.0,然后返回原测试,否则跳过测试
    return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
        test_case
    )


def require_torch_gpu(test_case):
    """装饰器标记一个需要 CUDA 和 PyTorch 的测试。"""
    # 检查 PyTorch 是否可用且当前设备为 CUDA,然后返回原测试,否则跳过测试
    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
        test_case
    )


# 这些装饰器用于特定加速器行为,而不是仅限于 GPU
def require_torch_accelerator(test_case):
    """装饰器标记一个需要加速器后端和 PyTorch 的测试。"""
    # 检查 PyTorch 是否可用且当前设备不是 CPU,然后返回原测试,否则跳过测试
    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
        test_case
    )


def require_torch_multi_gpu(test_case):
    """
    装饰器标记一个需要多 GPU 设置的测试(在 PyTorch 中)。这些测试在没有多个 GPU 的机器上将被跳过。
    若要仅运行 multi_gpu 测试,假设所有测试名称包含 multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
    """
    # 检查是否可以使用 PyTorch
        if not is_torch_available():
            # 如果不可用,则跳过测试,给出提示
            return unittest.skip("test requires PyTorch")(test_case)
    
        # 导入 PyTorch 库
        import torch
    
        # 跳过测试,除非 GPU 设备数量大于 1
        return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
# 装饰器,标记需要支持 FP16 数据类型的加速器的测试
def require_torch_accelerator_with_fp16(test_case):
    # 如果当前设备支持 FP16,则跳过此装饰器
    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
        test_case
    )


# 装饰器,标记需要支持 FP64 数据类型的加速器的测试
def require_torch_accelerator_with_fp64(test_case):
    # 如果当前设备支持 FP64,则跳过此装饰器
    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
        test_case
    )


# 装饰器,标记需要支持训练的加速器的测试
def require_torch_accelerator_with_training(test_case):
    # 如果当前设备可用且支持训练,则跳过此装饰器
    return unittest.skipUnless(
        is_torch_available() and backend_supports_training(torch_device),
        "test requires accelerator with training support",
    )(test_case)


# 装饰器,标记如果 torch_device 是 'mps' 则跳过测试
def skip_mps(test_case):
    # 如果当前设备不是 'mps',则跳过此装饰器
    return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)


# 装饰器,标记需要 JAX 和 Flax 的测试
def require_flax(test_case):
    # 如果 Flax 可用,则跳过此装饰器
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


# 装饰器,标记需要 compel 库的测试
def require_compel(test_case):
    # 如果 compel 可用,则跳过此装饰器
    return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)


# 装饰器,标记需要 onnxruntime 的测试
def require_onnxruntime(test_case):
    # 如果 onnxruntime 可用,则跳过此装饰器
    return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)


# 装饰器,标记需要 note_seq 的测试
def require_note_seq(test_case):
    # 如果 note_seq 可用,则跳过此装饰器
    return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)


# 装饰器,标记需要 torchsde 的测试
def require_torchsde(test_case):
    # 如果 torchsde 可用,则跳过此装饰器
    return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)


# 装饰器,标记需要 PEFT 后端的测试
def require_peft_backend(test_case):
    # 如果需要 PEFT 后端,则跳过此装饰器
    return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)


# 装饰器,标记需要 timm 库的测试
def require_timm(test_case):
    # 如果 timm 可用,则跳过此装饰器
    return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)


# 装饰器,标记需要特定版本 PEFT 的测试
def require_peft_version_greater(peft_version):
    # 装饰器标记一个测试,该测试要求特定版本的 PEFT 后端,需满足特定版本
        """
    
        # 定义装饰器函数,接受一个测试用例作为参数
        def decorator(test_case):
            # 检查 PEFT 是否可用,并且当前版本是否大于指定的 PEFT 版本
            correct_peft_version = is_peft_available() and version.parse(
                version.parse(importlib.metadata.version("peft")).base_version
            ) > version.parse(peft_version)
            # 如果满足版本要求,则跳过此测试,并提供相应的提示信息
            return unittest.skipUnless(
                correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
            )(test_case)
    
        # 返回装饰器函数
        return decorator
# 定义一个装饰器,要求加速版本大于给定版本
def require_accelerate_version_greater(accelerate_version):
    # 装饰器内部函数,用于装饰测试用例
    def decorator(test_case):
        # 检查 PEFT 是否可用,并解析当前加速库版本
        correct_accelerate_version = is_peft_available() and version.parse(
            version.parse(importlib.metadata.version("accelerate")).base_version
        ) > version.parse(accelerate_version)
        # 根据版本判断是否跳过测试用例
        return unittest.skipUnless(
            correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
        )(test_case)

    return decorator


# 定义一个装饰器,标记将在 PEFT 后端之后跳过的测试
def deprecate_after_peft_backend(test_case):
    """
    装饰器,标记将在 PEFT 后端之后跳过的测试
    """
    # 根据是否使用 PEFT 后端决定是否跳过测试
    return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)


# 获取当前 Python 版本
def get_python_version():
    # 获取系统的版本信息
    sys_info = sys.version_info
    # 提取主版本和次版本
    major, minor = sys_info.major, sys_info.minor
    # 返回主版本和次版本
    return major, minor


# 加载 NumPy 数组,支持 URL 和本地路径
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
    # 检查 arry 是否为字符串
    if isinstance(arry, str):
        if local_path is not None:
            # local_path 用于修正测试的图像路径
            return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
        # 检查字符串是否为 URL
        elif arry.startswith("http://") or arry.startswith("https://"):
            response = requests.get(arry)  # 发送请求获取内容
            response.raise_for_status()  # 检查请求是否成功
            arry = np.load(BytesIO(response.content))  # 从响应内容加载 NumPy 数组
        # 检查字符串是否为有效的文件路径
        elif os.path.isfile(arry):
            arry = np.load(arry)  # 从文件路径加载 NumPy 数组
        else:
            # 抛出路径或 URL 不正确的错误
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
            )
    # 检查 arry 是否为 NumPy 数组
    elif isinstance(arry, np.ndarray):
        pass  # 如果是 NumPy 数组则不做任何处理
    else:
        # 抛出格式不正确的错误
        raise ValueError(
            "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
            " ndarray."
        )

    return arry  # 返回处理后的数组


# 从给定 URL 加载 PyTorch 张量
def load_pt(url: str):
    response = requests.get(url)  # 发送请求获取内容
    response.raise_for_status()  # 检查请求是否成功
    arry = torch.load(BytesIO(response.content))  # 从响应内容加载 PyTorch 张量
    return arry  # 返回加载的张量


# 加载图像,支持字符串和 PIL.Image.Image 类型
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    将 `image` 加载为 PIL 图像。

    参数:
        image (`str` 或 `PIL.Image.Image`):
            要转换为 PIL 图像格式的图像。
    返回:
        `PIL.Image.Image`:
            一个 PIL 图像。
    """
    # 检查 image 是否为字符串
    if isinstance(image, str):
        # 检查字符串是否为 URL
        if image.startswith("http://") or image.startswith("https://"):
            image = PIL.Image.open(requests.get(image, stream=True).raw)  # 从 URL 加载图像
        # 检查字符串是否为有效的文件路径
        elif os.path.isfile(image):
            image = PIL.Image.open(image)  # 从文件路径加载图像
        else:
            # 抛出路径或 URL 不正确的错误
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    # 检查 image 是否为 PIL 图像
    elif isinstance(image, PIL.Image.Image):
        image = image  # 如果是 PIL 图像则不做任何处理
    # 如果不是正确的格式,则引发一个值错误
    else:
        # 提供详细的错误信息,说明接受的格式要求
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
        )
    # 对图像应用 EXIF 变换,调整其方向
    image = PIL.ImageOps.exif_transpose(image)
    # 将图像转换为 RGB 模式
    image = image.convert("RGB")
    # 返回处理后的图像
    return image
# 预处理输入图像以适应特定模型需求
def preprocess_image(image: PIL.Image, batch_size: int):
    # 获取图像的宽度和高度
    w, h = image.size
    # 调整宽度和高度为8的整数倍,以便于处理
    w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
    # 根据新的宽高调整图像大小,使用高质量重采样
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    # 将图像数据转换为numpy数组并归一化到[0, 1]区间
    image = np.array(image).astype(np.float32) / 255.0
    # 扩展图像维度并复制以形成batch
    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
    # 将numpy数组转换为PyTorch张量
    image = torch.from_numpy(image)
    # 将图像值从[0, 1]范围缩放到[-1, 1]范围
    return 2.0 * image - 1.0


# 将图像序列导出为GIF文件
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
    # 如果未指定输出路径,则创建临时GIF文件
    if output_gif_path is None:
        output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name

    # 保存第一帧,并附加后续帧
    image[0].save(
        output_gif_path,
        save_all=True,
        append_images=image[1:],
        optimize=False,
        duration=100,
        loop=0,
    )
    # 返回生成的GIF文件路径
    return output_gif_path


# 上下文管理器,用于缓冲写入文件
@contextmanager
def buffered_writer(raw_f):
    # 创建一个缓冲写入对象
    f = io.BufferedWriter(raw_f)
    # 暴露缓冲写入对象给上下文
    yield f
    # 刷新缓冲区,确保所有数据写入
    f.flush()


# 导出网格为PLY文件
def export_to_ply(mesh, output_ply_path: str = None):
    """
    写入一个网格的PLY文件。
    """
    # 如果未指定输出路径,则创建临时PLY文件
    if output_ply_path is None:
        output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name

    # 获取顶点坐标并转移到CPU和NumPy数组
    coords = mesh.verts.detach().cpu().numpy()
    # 获取面索引并转移到CPU和NumPy数组
    faces = mesh.faces.cpu().numpy()
    # 获取RGB颜色通道数据并堆叠为RGB格式
    rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)

    # 使用缓冲写入器打开PLY文件进行写入
    with buffered_writer(open(output_ply_path, "wb")) as f:
        # 写入PLY文件头
        f.write(b"ply\n")
        f.write(b"format binary_little_endian 1.0\n")
        f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
        f.write(b"property float x\n")
        f.write(b"property float y\n")
        f.write(b"property float z\n")
        # 如果存在RGB数据,写入颜色属性
        if rgb is not None:
            f.write(b"property uchar red\n")
            f.write(b"property uchar green\n")
            f.write(b"property uchar blue\n")
        # 如果存在面数据,写入面索引
        if faces is not None:
            f.write(bytes(f"element face {len(faces)}\n", "ascii"))
            f.write(b"property list uchar int vertex_index\n")
        # 写入文件头结束标志
        f.write(b"end_header\n")

        # 如果存在RGB数据,处理顶点数据
        if rgb is not None:
            rgb = (rgb * 255.499).round().astype(int)
            vertices = [
                (*coord, *rgb)
                for coord, rgb in zip(
                    coords.tolist(),
                    rgb.tolist(),
                )
            ]
            # 定义数据打包格式
            format = struct.Struct("<3f3B")
            # 写入每个顶点的数据
            for item in vertices:
                f.write(format.pack(*item))
        else:
            # 定义顶点数据打包格式
            format = struct.Struct("<3f")
            # 写入每个顶点坐标
            for vertex in coords.tolist():
                f.write(format.pack(*vertex))

        # 如果存在面数据,处理面索引数据
        if faces is not None:
            format = struct.Struct("<B3I")
            for tri in faces.tolist():
                f.write(format.pack(len(tri), *tri))

    # 返回生成的PLY文件路径
    return output_ply_path


# 导出网格为OBJ文件
def export_to_obj(mesh, output_obj_path: str = None):
    # 如果未指定输出路径,则创建临时OBJ文件
    if output_obj_path is None:
        output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name

    # 获取顶点坐标并转移到CPU和NumPy数组
    verts = mesh.verts.detach().cpu().numpy()
    # 获取面索引并转移到CPU和NumPy数组
    faces = mesh.faces.cpu().numpy()
    # 将 mesh 中的顶点颜色通道提取并转换为 NumPy 数组,堆叠成一个数组
    vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
    # 将顶点坐标和颜色组合成字符串列表,格式化为 OBJ 文件所需的顶点定义
    vertices = [
        "{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
    ]

    # 将面数据格式化为 OBJ 文件所需的面定义,面索引从 0 转为从 1 开始
    faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]

    # 合并顶点数据和面数据,形成完整的 OBJ 数据列表
    combined_data = ["v " + vertex for vertex in vertices] + faces

    # 打开指定路径的文件,以写入模式创建文件对象 f
    with open(output_obj_path, "w") as f:
        # 将合并的数据写入文件,每个条目之间用换行符分隔
        f.writelines("\n".join(combined_data))
# 导出视频帧为视频文件,返回输出视频的路径
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
    # 检查 OpenCV 是否可用
    if is_opencv_available():
        # 导入 OpenCV 库
        import cv2
    else:
        # 如果不可用,抛出导入错误
        raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
    # 如果未指定输出视频路径,则创建临时文件
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # 设置视频编码格式为 mp4
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # 获取第一帧的高度、宽度和通道数
    h, w, c = video_frames[0].shape
    # 创建视频写入对象
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
    # 遍历每一帧视频
    for i in range(len(video_frames)):
        # 将帧从 RGB 转换为 BGR 格式
        img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
        # 将转换后的帧写入视频
        video_writer.write(img)
    # 返回输出视频的路径
    return output_video_path


# 加载 Hugging Face 上的 NumPy 数组
def load_hf_numpy(path) -> np.ndarray:
    # 定义基础 URL
    base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"

    # 如果路径不是以 http 或 https 开头,则拼接基础 URL
    if not path.startswith("http://") and not path.startswith("https://"):
        path = os.path.join(base_url, urllib.parse.quote(path))

    # 加载并返回 NumPy 数组
    return load_numpy(path)


# --- pytest 配置函数 --- #

# 避免在测试中多次调用,确保只调用一次
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    该函数从 `conftest.py` 调用,使用 `pytest_addoption` 包装器。

    允许同时加载两个 `conftest.py` 文件,而不会因添加相同的 pytest 选项而失败。
    """
    option = "--make-reports"
    # 检查选项是否已注册
    if option not in pytest_opt_registered:
        # 添加选项到解析器
        parser.addoption(
            option,
            action="store",
            default=False,
            help="生成报告文件。此选项的值用作报告名称的前缀",
        )
        # 标记选项已注册
        pytest_opt_registered[option] = 1


def pytest_terminal_summary_main(tr, id):
    """
    在测试套件运行结束时生成多个报告 - 每个报告存储在当前目录的独立文件中。
    报告文件以测试套件名称为前缀。

    模拟 --duration 和 -rA pytest 参数。

    从 `conftest.py` 调用此函数。
    
    Args:
    - tr: 从 `conftest.py` 传递的 `terminalreporter`
    - id: 像 `tests` 或 `examples` 的唯一 ID,将并入最终报告文件名
    """
    from _pytest.config import create_terminal_writer

    # 如果没有提供 ID,默认为 "tests"
    if not len(id):
        id = "tests"

    # 获取配置和原始终端写入器
    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    # 保存原始报告字符设置
        orig_reportchars = tr.reportchars
    
        # 定义报告文件存放目录
        dir = "reports"
        # 创建目录,父目录可选,如果已存在则不报错
        Path(dir).mkdir(parents=True, exist_ok=True)
        # 生成报告文件的字典,包含不同报告类型的文件名
        report_files = {
            k: f"{dir}/{id}_{k}.txt"
            for k in [
                "durations",
                "errors",
                "failures_long",
                "failures_short",
                "failures_line",
                "passes",
                "stats",
                "summary_short",
                "warnings",
            ]
        }
    
        # 自定义持续时间报告
        # 注意:无需调用 pytest --durations=XX 获取此单独报告
        # 来源于 https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
        dlist = []  # 初始化持续时间列表
        # 遍历测试结果中的统计数据
        for replist in tr.stats.values():
            for rep in replist:
                # 如果报告有持续时间属性,则添加到列表中
                if hasattr(rep, "duration"):
                    dlist.append(rep)
        # 如果持续时间列表非空
        if dlist:
            # 按持续时间降序排序
            dlist.sort(key=lambda x: x.duration, reverse=True)
            # 打开文件以写入持续时间报告
            with open(report_files["durations"], "w") as f:
                durations_min = 0.05  # 最小持续时间(秒)
                # 写入标题
                f.write("slowest durations\n")
                # 遍历报告并写入文件
                for i, rep in enumerate(dlist):
                    # 如果报告的持续时间小于最小值,写入省略信息并退出循环
                    if rep.duration < durations_min:
                        f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
                        break
                    # 写入每条报告的持续时间、执行时机和节点ID
                    f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
    
        # 定义简短失败摘要的函数
        def summary_failures_short(tr):
            # 假设报告使用了长格式,截取最后一帧
            reports = tr.getreports("failed")
            # 如果没有失败报告,则直接返回
            if not reports:
                return
            # 写入分隔符和标题
            tr.write_sep("=", "FAILURES SHORT STACK")
            # 遍历失败报告
            for rep in reports:
                # 获取失败消息
                msg = tr._getfailureheadline(rep)
                # 写入分隔符和消息
                tr.write_sep("_", msg, red=True, bold=True)
                # 去掉可选的前导额外帧,只保留最后一帧
                longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
                tr._tw.line(longrepr)
                # 注意:不打印任何 rep.sections 以保持报告简短
    
        # 使用现成的报告函数,劫持文件句柄以记录到专用文件
        # 来源于 https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
        # 注意:某些 pytest 插件可能会干扰默认的 `terminalreporter`(例如,pytest-instafail)
    
        # 以行/短/长样式报告失败
        config.option.tbstyle = "auto"  # 完整的回溯
        # 打开文件以写入长失败报告
        with open(report_files["failures_long"], "w") as f:
            # 创建终端写入器,并劫持到文件
            tr._tw = create_terminal_writer(config, f)
            # 写入失败摘要
            tr.summary_failures()
    
        # config.option.tbstyle = "short" # 短回溯
        # 打开文件以写入短失败报告
        with open(report_files["failures_short"], "w") as f:
            # 创建终端写入器,并劫持到文件
            tr._tw = create_terminal_writer(config, f)
            # 写入简短失败摘要
            summary_failures_short(tr)
    
        # 将回溯样式设置为每个错误一行
        config.option.tbstyle = "line"  # 一行每个错误
        # 打开文件以写入行失败报告
        with open(report_files["failures_line"], "w") as f:
            # 创建终端写入器,并劫持到文件
            tr._tw = create_terminal_writer(config, f)
            # 写入失败摘要
            tr.summary_failures()
    # 打开指定的错误报告文件以写入模式
    with open(report_files["errors"], "w") as f:
        # 创建终端写入器并将其赋值给 tr._tw
        tr._tw = create_terminal_writer(config, f)
        # 输出错误摘要
        tr.summary_errors()

    # 打开指定的警告报告文件以写入模式
    with open(report_files["warnings"], "w") as f:
        # 创建终端写入器并将其赋值给 tr._tw
        tr._tw = create_terminal_writer(config, f)
        # 输出常规警告摘要
        tr.summary_warnings()  # normal warnings
        # 输出最终警告摘要
        tr.summary_warnings()  # final warnings

    # 设置报告字符,用于模拟 -rA 选项(在 summary_passes() 和 short_test_summary() 中使用)
    tr.reportchars = "wPpsxXEf"  
    # 打开指定的通过报告文件以写入模式
    with open(report_files["passes"], "w") as f:
        # 创建终端写入器并将其赋值给 tr._tw
        tr._tw = create_terminal_writer(config, f)
        # 输出通过的摘要
        tr.summary_passes()

    # 打开指定的短摘要报告文件以写入模式
    with open(report_files["summary_short"], "w") as f:
        # 创建终端写入器并将其赋值给 tr._tw
        tr._tw = create_terminal_writer(config, f)
        # 输出短测试摘要
        tr.short_test_summary()

    # 打开指定的统计报告文件以写入模式
    with open(report_files["stats"], "w") as f:
        # 创建终端写入器并将其赋值给 tr._tw
        tr._tw = create_terminal_writer(config, f)
        # 输出统计摘要
        tr.summary_stats()

    # 恢复原始设置:
    # 恢复终端写入器为原始写入器
    tr._tw = orig_writer
    # 恢复报告字符为原始字符
    tr.reportchars = orig_reportchars
    # 恢复配置的跟踪样式为原始样式
    config.option.tbstyle = orig_tbstyle
# 从 GitHub 复制的代码,装饰器用于处理不稳定的测试
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
    """
    装饰不稳定的测试,失败时会重试。

    参数:
        max_attempts (`int`, *可选*, 默认值为 5):
            重新尝试不稳定测试的最大尝试次数。
        wait_before_retry (`float`, *可选*):
            如果提供,重试测试前将等待该秒数。
        description (`str`, *可选*):
            描述情况的字符串(什么/在哪里/为什么不稳定,链接到 GitHub 问题/PR 评论、错误等)。
    """

    # 装饰器函数,接收待装饰的测试函数
    def decorator(test_func_ref):
        # 包装函数,处理重试逻辑
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            retry_count = 1  # 初始化重试计数

            # 在最大尝试次数内循环
            while retry_count < max_attempts:
                try:
                    # 尝试调用测试函数并返回结果
                    return test_func_ref(*args, **kwargs)

                except Exception as err:  # 捕获测试函数的异常
                    # 打印错误信息和当前重试次数
                    print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
                    if wait_before_retry is not None:  # 如果提供了等待时间
                        time.sleep(wait_before_retry)  # 等待指定的秒数
                    retry_count += 1  # 增加重试计数

            # 达到最大重试次数后再次尝试调用测试函数并返回结果
            return test_func_ref(*args, **kwargs)

        return wrapper  # 返回包装后的函数

    return decorator  # 返回装饰器


# 从 GitHub 复制的代码,用于在子进程中运行测试
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
    """
    在子进程中运行测试。这可以避免 (GPU) 内存问题。

    参数:
        test_case (`unittest.TestCase`):
            将运行 `target_func` 的测试用例。
        target_func (`Callable`):
            实现实际测试逻辑的函数。
        inputs (`dict`, *可选*, 默认值为 `None`):
            将通过(输入)队列传递给 `target_func` 的输入。
        timeout (`int`, *可选*, 默认值为 `None`):
            将传递给输入和输出队列的超时(以秒为单位)。如果未指定,将检查环境变量 `PYTEST_TIMEOUT`。如果仍为 `None`,其值将设置为 `600`。
    """
    if timeout is None:
        # 获取超时设置,如果环境变量未设置则默认为 600 秒
        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))

    start_methohd = "spawn"  # 定义子进程的启动方法为 "spawn"
    ctx = multiprocessing.get_context(start_methohd)  # 获取指定上下文的 multiprocessing

    input_queue = ctx.Queue(1)  # 创建一个输入队列,最大大小为 1
    output_queue = ctx.JoinableQueue(1)  # 创建一个可加入的输出队列,最大大小为 1

    # 我们不能将 `unittest.TestCase` 发送到子进程,否则会出现关于 pickle 的问题。
    input_queue.put(inputs, timeout=timeout)  # 将输入放入队列,指定超时时间

    # 创建并启动子进程,指定目标函数和参数
    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
    process.start()  # 启动子进程
    # 如果不能及时从子进程获取输出,则终止子进程:否则,悬挂的子进程会阻塞
    # 测试以正确的方式退出
        try:
            # 从输出队列中获取结果,设置超时
            results = output_queue.get(timeout=timeout)
            # 标记任务为已完成
            output_queue.task_done()
        except Exception as e:
            # 处理异常,终止进程
            process.terminate()
            # 记录测试失败的原因
            test_case.fail(e)
        # 等待进程结束,设置超时
        process.join(timeout=timeout)
    
        # 检查结果中的错误信息是否存在
        if results["error"] is not None:
            # 记录测试失败的具体错误信息
            test_case.fail(f'{results["error"]}')
# 定义一个上下文管理器类,用于捕获日志流
class CaptureLogger:
    """
    参数:
    上下文管理器,用于捕获 `logging` 流
        logger: 'logging` 日志对象
    返回:
        捕获的输出可以通过 `self.out` 获取
    示例:
    ```python
    >>> from diffusers import logging
    >>> from diffusers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out, msg + "\n"
    ```py
    """

    # 初始化 CaptureLogger 类,接受一个日志对象
    def __init__(self, logger):
        self.logger = logger  # 保存日志对象
        self.io = StringIO()  # 创建一个字符串流用于捕获日志
        self.sh = logging.StreamHandler(self.io)  # 创建流处理器以写入字符串流
        self.out = ""  # 初始化捕获的输出为空字符串

    # 进入上下文时添加日志处理器
    def __enter__(self):
        self.logger.addHandler(self.sh)  # 将处理器添加到日志对象
        return self  # 返回当前实例

    # 退出上下文时移除日志处理器并获取输出
    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)  # 移除处理器
        self.out = self.io.getvalue()  # 获取捕获的日志输出

    # 返回捕获的日志字符串的表示形式
    def __repr__(self):
        return f"captured: {self.out}\n"


# 启用全确定性以保证分布式训练中的可重现性
def enable_full_determinism():
    """
    帮助函数以保证分布式训练期间的可重现行为。请参见
    - https://pytorch.org/docs/stable/notes/randomness.html 以了解 PyTorch
    """
    # 启用 PyTorch 确定性模式。可能需要设置环境变量 'CUDA_LAUNCH_BLOCKING' 或 'CUBLAS_WORKSPACE_CONFIG'
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # 设置 CUDA 同步执行
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # 配置 CUBLAS 工作区
    torch.use_deterministic_algorithms(True)  # 启用确定性算法

    # 启用 CUDNN 确定性模式
    torch.backends.cudnn.deterministic = True  # 设置 CUDNN 为确定性
    torch.backends.cudnn.benchmark = False  # 禁用基准优化
    torch.backends.cuda.matmul.allow_tf32 = False  # 禁用 TF32 精度

# 禁用全确定性以恢复非确定性行为
def disable_full_determinism():
    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"  # 关闭 CUDA 同步执行
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""  # 清除 CUBLAS 工作区配置
    torch.use_deterministic_algorithms(False)  # 禁用确定性算法


# 检查给定设备上是否支持 FP16
def _is_torch_fp16_available(device):
    if not is_torch_available():  # 如果 PyTorch 不可用,返回 False
        return False

    import torch  # 导入 PyTorch

    device = torch.device(device)  # 将设备字符串转换为设备对象

    try:
        x = torch.zeros((2, 2), dtype=torch.float16).to(device)  # 创建 FP16 张量并移动到指定设备
        _ = torch.mul(x, x)  # 进行乘法运算以检查支持情况
        return True  # 如果没有异常,返回 True

    except Exception as e:  # 捕获异常
        if device.type == "cuda":  # 如果设备类型为 cuda
            raise ValueError(
                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
            )  # 抛出错误,提示 cuda 安装问题

        return False  # 其他设备返回 False


# 检查给定设备上是否支持 FP64
def _is_torch_fp64_available(device):
    if not is_torch_available():  # 如果 PyTorch 不可用,返回 False
        return False

    import torch  # 导入 PyTorch

    device = torch.device(device)  # 将设备字符串转换为设备对象

    try:
        x = torch.zeros((2, 2), dtype=torch.float64).to(device)  # 创建 FP64 张量并移动到指定设备
        _ = torch.mul(x, x)  # 进行乘法运算以检查支持情况
        return True  # 如果没有异常,返回 True
    # 捕获异常并处理
        except Exception as e:
            # 检查设备类型是否为 'cuda'
            if device.type == "cuda":
                # 引发值错误,提示 'cuda' 应该支持 'fp64',但似乎未正确安装
                raise ValueError(
                    f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
                )
    
            # 返回 False,表示操作失败
            return False
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
if is_torch_available():  # 检查 PyTorch 是否可用
    # Behaviour flags
    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}  # 定义支持训练的后端标志

    # Function definitions
    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}  # 定义后端清空缓存的函数
    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}  # 定义获取设备数量的函数
    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}  # 定义设置随机种子的函数


# This dispatches a defined function according to the accelerator from the function definitions.
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):  # 根据设备和函数表调度相应的函数
    if device not in dispatch_table:  # 如果设备不在调度表中
        return dispatch_table["default"](*args, **kwargs)  # 返回默认函数的结果

    fn = dispatch_table[device]  # 获取对应设备的函数

    # Some device agnostic functions return values. Need to guard against 'None' instead at
    # user level
    if fn is None:  # 如果函数为 None
        return None  # 返回 None

    return fn(*args, **kwargs)  # 调用并返回函数的结果


# These are callables which automatically dispatch the function specific to the accelerator
def backend_manual_seed(device: str, seed: int):  # 设置特定设备的随机种子
    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)  # 调用调度函数


def backend_empty_cache(device: str):  # 清空特定设备的缓存
    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)  # 调用调度函数


def backend_device_count(device: str):  # 获取特定设备的数量
    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)  # 调用调度函数


# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):  # 检查特定设备是否支持训练
    if not is_torch_available():  # 如果 PyTorch 不可用
        return False  # 返回 False

    if device not in BACKEND_SUPPORTS_TRAINING:  # 如果设备不在支持训练的字典中
        device = "default"  # 使用默认设备

    return BACKEND_SUPPORTS_TRAINING[device]  # 返回该设备的支持训练标志


# Guard for when Torch is not available
if is_torch_available():  # 检查 PyTorch 是否可用
    # Update device function dict mapping
    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):  # 更新设备函数字典映射
        try:
            # Try to import the function directly
            spec_fn = getattr(device_spec_module, attribute_name)  # 尝试从模块中获取指定属性
            device_fn_dict[torch_device] = spec_fn  # 更新设备函数字典
        except AttributeError as e:  # 捕获属性错误
            # If the function doesn't exist, and there is no default, throw an error
            if "default" not in device_fn_dict:  # 如果字典中没有默认函数
                raise AttributeError(  # 抛出属性错误
                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                ) from e  # 追踪原始错误
    # 检查环境变量中是否存在特定的设备规格
        if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
            # 获取设备规格文件的路径
            device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
            # 检查指定的设备规格文件路径是否是一个文件
            if not Path(device_spec_path).is_file():
                # 如果文件不存在,抛出值错误并给出提示
                raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
    
            try:
                # 从文件路径中提取模块名称(去掉.py后缀)
                import_name = device_spec_path[: device_spec_path.index(".py")]
            except ValueError as e:
                # 如果路径中没有找到.py,抛出值错误
                raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
    
            # 动态导入设备规格模块
            device_spec_module = importlib.import_module(import_name)
    
            try:
                # 从模块中获取设备名称
                device_name = device_spec_module.DEVICE_NAME
            except AttributeError:
                # 如果模块中没有DEVICE_NAME,抛出属性错误
                raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
    
            # 检查环境变量中是否存在另一设备名称,并与当前设备名称进行比较
            if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
                # 如果不匹配,构建错误信息并抛出值错误
                msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
                msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
                raise ValueError(msg)
    
            # 将torch_device设置为提取的设备名称
            torch_device = device_name
    
            # 对每个`BACKEND_*`字典添加一条条目
            update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
            update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
            update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
            update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
posted @ 2024-10-22 12:33  绝不原创的飞龙  阅读(156)  评论(0)    收藏  举报