diffusers-源码解析-六十六-
diffusers 源码解析(六十六)
.\diffusers\utils\logging.py
# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,标明版权所有者和年份
# Copyright 2024 Optuna, Hugging Face
#
# 根据 Apache License 2.0 版本许可本文件的使用
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件在未遵守许可证的情况下不可使用
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证的副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 在适用的情况下,许可证下的软件以“原样”提供,不提供任何明示或暗示的担保
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取特定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
"""记录工具函数的模块。"""
# 导入 logging 模块以实现日志记录功能
import logging
# 导入 os 模块以进行操作系统交互
import os
# 导入 sys 模块以访问系统特定参数和功能
import sys
# 导入 threading 模块以实现线程支持
import threading
# 从 logging 模块导入不同的日志级别常量
from logging import (
CRITICAL, # NOQA
DEBUG, # NOQA
ERROR, # NOQA
FATAL, # NOQA
INFO, # NOQA
NOTSET, # NOQA
WARN, # NOQA
WARNING, # NOQA
)
# 导入 Dict 和 Optional 类型以进行类型注解
from typing import Dict, Optional
# 从 tqdm 库导入自动选择的进度条支持
from tqdm import auto as tqdm_lib
# 创建一个线程锁以确保线程安全
_lock = threading.Lock()
# 定义一个默认的日志处理程序,初始为 None
_default_handler: Optional[logging.Handler] = None
# 定义日志级别的字典,映射字符串到 logging 的级别
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
# 设置默认日志级别为 WARNING
_default_log_level = logging.WARNING
# 标志表示进度条是否处于活动状态
_tqdm_active = True
# 定义获取默认日志级别的函数
def _get_default_logging_level() -> int:
"""
如果环境变量 DIFFUSERS_VERBOSITY 设置为有效选项,则返回该值作为新的默认级别。
如果没有设置,则返回 `_default_log_level`。
"""
# 获取环境变量 DIFFUSERS_VERBOSITY 的值
env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
# 如果环境变量存在
if env_level_str:
# 检查环境变量值是否在日志级别字典中
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
# 如果值无效,记录警告信息
logging.getLogger().warning(
f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
# 返回默认日志级别
return _default_log_level
# 定义获取库名称的函数
def _get_library_name() -> str:
# 返回模块名称的第一个部分作为库名称
return __name__.split(".")[0]
# 定义获取库根日志记录器的函数
def _get_library_root_logger() -> logging.Logger:
# 返回库名称对应的日志记录器
return logging.getLogger(_get_library_name())
# 定义配置库根日志记录器的函数
def _configure_library_root_logger() -> None:
global _default_handler
# 使用线程锁来确保线程安全
with _lock:
# 如果默认处理程序已存在,返回
if _default_handler:
# 该库已配置根日志记录器
return
# 创建一个流处理程序,输出到标准错误
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
# 检查 sys.stderr 是否存在
if sys.stderr: # only if sys.stderr exists, e.g. when not using pythonw in windows
# 设置 flush 方法为 sys.stderr 的 flush 方法
_default_handler.flush = sys.stderr.flush
# 应用默认配置到库根日志记录器
library_root_logger = _get_library_root_logger()
# 添加默认处理程序到库根日志记录器
library_root_logger.addHandler(_default_handler)
# 设置库根日志记录器的日志级别
library_root_logger.setLevel(_get_default_logging_level())
# 禁用日志记录器的传播
library_root_logger.propagate = False
# 定义重置库根日志记录器的函数
def _reset_library_root_logger() -> None:
global _default_handler
# 使用锁确保线程安全,防止竞争条件
with _lock:
# 如果没有默认处理器,则直接返回
if not _default_handler:
return
# 获取库的根日志记录器
library_root_logger = _get_library_root_logger()
# 从根日志记录器中移除默认处理器
library_root_logger.removeHandler(_default_handler)
# 将根日志记录器的日志级别设置为 NOTSET,表示接受所有级别的日志
library_root_logger.setLevel(logging.NOTSET)
# 将默认处理器设置为 None,表示不再使用默认处理器
_default_handler = None
# 获取日志级别字典
def get_log_levels_dict() -> Dict[str, int]:
# 返回全局日志级别字典
return log_levels
# 获取指定名称的日志记录器
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
返回具有指定名称的日志记录器。
该函数不应直接访问,除非您正在编写自定义的 diffusers 模块。
"""
# 如果未提供名称,则获取库名称
if name is None:
name = _get_library_name()
# 配置库根日志记录器
_configure_library_root_logger()
# 返回指定名称的日志记录器
return logging.getLogger(name)
# 获取当前日志级别
def get_verbosity() -> int:
"""
返回 🤗 Diffusers 根日志记录器的当前级别作为 `int`。
返回:
`int`:
日志级别整数,可以是以下之一:
- `50`: `diffusers.logging.CRITICAL` 或 `diffusers.logging.FATAL`
- `40`: `diffusers.logging.ERROR`
- `30`: `diffusers.logging.WARNING` 或 `diffusers.logging.WARN`
- `20`: `diffusers.logging.INFO`
- `10`: `diffusers.logging.DEBUG`
"""
# 配置库根日志记录器
_configure_library_root_logger()
# 返回根日志记录器的有效级别
return _get_library_root_logger().getEffectiveLevel()
# 设置日志级别
def set_verbosity(verbosity: int) -> None:
"""
设置 🤗 Diffusers 根日志记录器的详细程度。
参数:
verbosity (`int`):
日志级别,可以是以下之一:
- `diffusers.logging.CRITICAL` 或 `diffusers.logging.FATAL`
- `diffusers.logging.ERROR`
- `diffusers.logging.WARNING` 或 `diffusers.logging.WARN`
- `diffusers.logging.INFO`
- `diffusers.logging.DEBUG`
"""
# 配置库根日志记录器
_configure_library_root_logger()
# 设置根日志记录器的级别
_get_library_root_logger().setLevel(verbosity)
# 设置日志级别为 INFO
def set_verbosity_info() -> None:
"""将详细程度设置为 `INFO` 级别。"""
# 调用设置详细程度的函数
return set_verbosity(INFO)
# 设置日志级别为 WARNING
def set_verbosity_warning() -> None:
"""将详细程度设置为 `WARNING` 级别。"""
# 调用设置详细程度的函数
return set_verbosity(WARNING)
# 设置日志级别为 DEBUG
def set_verbosity_debug() -> None:
"""将详细程度设置为 `DEBUG` 级别。"""
# 调用设置详细程度的函数
return set_verbosity(DEBUG)
# 设置日志级别为 ERROR
def set_verbosity_error() -> None:
"""将详细程度设置为 `ERROR` 级别。"""
# 调用设置详细程度的函数
return set_verbosity(ERROR)
# 禁用默认处理程序
def disable_default_handler() -> None:
"""禁用 🤗 Diffusers 根日志记录器的默认处理程序。"""
# 配置库根日志记录器
_configure_library_root_logger()
# 确保默认处理程序存在
assert _default_handler is not None
# 从根日志记录器中移除默认处理程序
_get_library_root_logger().removeHandler(_default_handler)
# 启用默认处理程序
def enable_default_handler() -> None:
"""启用 🤗 Diffusers 根日志记录器的默认处理程序。"""
# 配置库根日志记录器
_configure_library_root_logger()
# 确保默认处理程序存在
assert _default_handler is not None
# 将默认处理程序添加到根日志记录器
_get_library_root_logger().addHandler(_default_handler)
# 添加处理程序到日志记录器
def add_handler(handler: logging.Handler) -> None:
"""将处理程序添加到 HuggingFace Diffusers 根日志记录器。"""
# 配置库根日志记录器
_configure_library_root_logger()
# 确保处理程序存在
assert handler is not None
# 将处理程序添加到根日志记录器
_get_library_root_logger().addHandler(handler)
# 从日志记录器移除处理程序
def remove_handler(handler: logging.Handler) -> None:
"""从 HuggingFace Diffusers 根日志记录器移除给定的处理程序。"""
# 配置库根日志记录器
_configure_library_root_logger()
# 确保处理器不为空,并且在库根日志记录器的处理器列表中
assert handler is not None and handler in _get_library_root_logger().handlers
# 从库根日志记录器中移除指定的处理器
_get_library_root_logger().removeHandler(handler)
# 定义一个函数,用于禁用库的日志输出传播
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
# 配置库的根日志记录器
_configure_library_root_logger()
# 设置根日志记录器的传播属性为 False,禁用日志传播
_get_library_root_logger().propagate = False
# 定义一个函数,用于启用库的日志输出传播
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
double logging if the root logger has been configured.
"""
# 配置库的根日志记录器
_configure_library_root_logger()
# 设置根日志记录器的传播属性为 True,启用日志传播
_get_library_root_logger().propagate = True
# 定义一个函数,用于启用明确的日志格式
def enable_explicit_format() -> None:
"""
Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
```py
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
# 获取根日志记录器的所有处理器
handlers = _get_library_root_logger().handlers
# 遍历每个处理器,设置其格式化器
for handler in handlers:
# 创建一个新的格式化器
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
# 将格式化器设置到处理器上
handler.setFormatter(formatter)
# 定义一个函数,用于重置日志格式
def reset_format() -> None:
"""
Resets the formatting for 🤗 Diffusers' loggers.
All handlers currently bound to the root logger are affected by this method.
"""
# 获取根日志记录器的所有处理器
handlers = _get_library_root_logger().handlers
# 遍历每个处理器,重置其格式化器
for handler in handlers:
# 将处理器的格式化器设置为 None,重置格式
handler.setFormatter(None)
# 定义一个方法,用于发出警告信息
def warning_advice(self, *args, **kwargs) -> None:
"""
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
# 检查环境变量是否设置为不发出建议警告
no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
# 如果设置了环境变量,则直接返回,不发出警告
if no_advisory_warnings:
return
# 调用日志记录器的警告方法
self.warning(*args, **kwargs)
# 将自定义的警告方法绑定到日志记录器
logging.Logger.warning_advice = warning_advice
# 定义一个空的 tqdm 类,用于替代真实的进度条
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""
# 初始化方法,接收可变参数
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
# 如果有参数,保存第一个参数为迭代器
self._iterator = args[0] if args else None
# 定义迭代器方法,返回迭代器
def __iter__(self):
return iter(self._iterator)
# 定义属性访问方法,返回一个空函数
def __getattr__(self, _):
"""Return empty function."""
# 返回一个空函数
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
return
return empty_fn
# 定义上下文管理器的进入方法
def __enter__(self):
return self
# 定义上下文管理器的退出方法
def __exit__(self, type_, value, traceback):
return
# 定义一个自定义的 tqdm 类
class _tqdm_cls:
# 定义调用方法
def __call__(self, *args, **kwargs):
# 检查 tqdm 是否处于激活状态
if _tqdm_active:
# 返回激活状态下的 tqdm 实例
return tqdm_lib.tqdm(*args, **kwargs)
else:
# 返回空的 tqdm 实例
return EmptyTqdm(*args, **kwargs)
# 定义设置锁的方法
def set_lock(self, *args, **kwargs):
# 将锁设置为 None
self._lock = None
# 如果 tqdm 处于激活状态,设置锁
if _tqdm_active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
# 定义获取锁的方法
def get_lock(self):
# 如果 tqdm 处于激活状态,获取锁
if _tqdm_active:
return tqdm_lib.tqdm.get_lock()
# 创建一个 _tqdm_cls 的实例
tqdm = _tqdm_cls()
# 定义一个函数,检查进度条是否启用
def is_progress_bar_enabled() -> bool:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global _tqdm_active
# 返回进度条激活状态的布尔值
return bool(_tqdm_active)
# 定义一个启用进度条的函数,不返回任何值
def enable_progress_bar() -> None:
# 函数文档字符串,说明该函数的作用是启用 tqdm 进度条
"""Enable tqdm progress bar."""
# 声明全局变量 _tqdm_active
global _tqdm_active
# 将全局变量 _tqdm_active 设置为 True,表示进度条处于启用状态
_tqdm_active = True
# 定义一个禁用进度条的函数,不返回任何值
def disable_progress_bar() -> None:
# 函数文档字符串,说明该函数的作用是禁用 tqdm 进度条
"""Disable tqdm progress bar."""
# 声明全局变量 _tqdm_active
global _tqdm_active
# 将全局变量 _tqdm_active 设置为 False,表示进度条处于禁用状态
_tqdm_active = False
{{ card_data }}
{{ model_description }}
Intended uses & limitations
How to use
# TODO: add an example code snippet for running this diffusion pipeline
Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
Training details
[TODO: describe the data used to train the model]
.\diffusers\utils\outputs.py
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有规定,软件
# 按“现状”分发,不附带任何明示或暗示的担保或条件。
# 请参阅许可证以了解管理权限和
# 限制的具体语言。
"""
通用工具函数
"""
# 从有序字典模块导入 OrderedDict
from collections import OrderedDict
# 从数据类模块导入字段和是否为数据类的检查
from dataclasses import fields, is_dataclass
# 导入 Any 和 Tuple 类型
from typing import Any, Tuple
# 导入 NumPy 库
import numpy as np
# 从本地导入工具模块,检查 PyTorch 是否可用及其版本
from .import_utils import is_torch_available, is_torch_version
def is_tensor(x) -> bool:
"""
测试 `x` 是否为 `torch.Tensor` 或 `np.ndarray`。
"""
# 如果 PyTorch 可用
if is_torch_available():
# 导入 PyTorch 库
import torch
# 检查 x 是否为 torch.Tensor 类型
if isinstance(x, torch.Tensor):
return True
# 检查 x 是否为 np.ndarray 类型
return isinstance(x, np.ndarray)
class BaseOutput(OrderedDict):
"""
所有模型输出的基类,作为数据类。具有一个 `__getitem__` 方法,允许通过整数或切片(像元组)或字符串(像字典)进行索引,并会忽略 `None` 属性。
否则像常规 Python 字典一样工作。
<提示 警告={true}>
不能直接解包 [`BaseOutput`]。请先使用 [`~utils.BaseOutput.to_tuple`] 方法将其转换为元组。
</提示>
"""
def __init_subclass__(cls) -> None:
"""将子类注册为 pytree 节点。
这对于在使用 `torch.nn.parallel.DistributedDataParallel` 和
`static_graph=True` 时同步梯度是必要的,尤其是对于输出 `ModelOutput` 子类的模块。
"""
# 如果 PyTorch 可用
if is_torch_available():
# 导入 PyTorch 的 pytree 工具
import torch.utils._pytree
# 检查 PyTorch 版本是否小于 2.2
if is_torch_version("<", "2.2"):
# 注册 pytree 节点,使用字典扁平化和解扁平化
torch.utils._pytree._register_pytree_node(
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
else:
# 注册 pytree 节点,使用字典扁平化和解扁平化
torch.utils._pytree.register_pytree_node(
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
# 定义数据类的后处理初始化方法
def __post_init__(self) -> None:
# 获取当前数据类的所有字段
class_fields = fields(self)
# 安全性和一致性检查
if not len(class_fields):
# 如果没有字段,抛出错误
raise ValueError(f"{self.__class__.__name__} has no fields.")
# 获取第一个字段的值
first_field = getattr(self, class_fields[0].name)
# 检查除了第一个字段外,其他字段是否均为 None
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
# 如果其他字段均为 None 且第一个字段为字典,进行赋值
if other_fields_are_none and isinstance(first_field, dict):
for key, value in first_field.items():
# 将字典内容赋值到当前对象
self[key] = value
else:
# 遍历所有字段并赋值非 None 的字段
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
# 将非 None 的字段值赋值到当前对象
self[field.name] = v
# 定义删除项的方法
def __delitem__(self, *args, **kwargs):
# 不允许删除项,抛出异常
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
# 定义设置默认值的方法
def setdefault(self, *args, **kwargs):
# 不允许设置默认值,抛出异常
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
# 定义弹出项的方法
def pop(self, *args, **kwargs):
# 不允许弹出项,抛出异常
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
# 定义更新项的方法
def update(self, *args, **kwargs):
# 不允许更新项,抛出异常
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
# 定义获取项的方法
def __getitem__(self, k: Any) -> Any:
# 如果键是字符串
if isinstance(k, str):
# 将当前项转换为字典并返回对应值
inner_dict = dict(self.items())
return inner_dict[k]
else:
# 如果键不是字符串,返回对应的元组值
return self.to_tuple()[k]
# 定义设置属性的方法
def __setattr__(self, name: Any, value: Any) -> None:
# 如果属性名在键中且值不为 None
if name in self.keys() and value is not None:
# 不调用 self.__setitem__ 以避免递归错误
super().__setitem__(name, value)
# 设置属性值
super().__setattr__(name, value)
# 定义设置项的方法
def __setitem__(self, key, value):
# 将键值对设置到当前对象中
super().__setitem__(key, value)
# 不调用 self.__setattr__ 以避免递归错误
super().__setattr__(key, value)
# 定义序列化的方法
def __reduce__(self):
# 如果当前对象不是数据类
if not is_dataclass(self):
# 调用父类的序列化方法
return super().__reduce__()
# 获取可调用对象和参数
callable, _args, *remaining = super().__reduce__()
# 生成字段的元组
args = tuple(getattr(self, field.name) for field in fields(self))
# 返回可调用对象、参数及其他信息
return callable, args, *remaining
# 定义转换为元组的方法
def to_tuple(self) -> Tuple[Any, ...]:
"""
将当前对象转换为一个包含所有非 `None` 属性/键的元组。
"""
# 返回包含所有键的值的元组
return tuple(self[k] for k in self.keys())
.\diffusers\utils\peft_utils.py
# 版权声明,声明本文件的版权归 HuggingFace 团队所有
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 在 Apache 许可证 2.0(“许可证”)下授权;
# 除非遵循该许可证,否则不得使用本文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是“按原样”提供的,
# 不提供任何形式的保证或条件,无论是明示还是暗示的。
# 有关许可证所管辖的权限和限制的具体信息,请参阅许可证。
"""
PEFT 工具:与 peft 库相关的工具
"""
# 导入 collections 模块以便于使用集合相关的功能
import collections
# 导入 importlib 模块以支持动态导入
import importlib
# 从 typing 导入 Optional 类型以用于类型注释
from typing import Optional
# 导入 version 模块以处理版本相关的功能
from packaging import version
# 从当前包导入 utils,检查 peft 和 torch 库是否可用
from .import_utils import is_peft_available, is_torch_available
# 如果 torch 库可用,则导入 torch 模块
if is_torch_available():
import torch
# 定义函数以递归地替换模型中的 LoraLayer 实例
def recurse_remove_peft_layers(model):
r"""
递归替换模型中所有 `LoraLayer` 的实例为相应的新层。
"""
# 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
from peft.tuners.tuners_utils import BaseTunerLayer
# 初始化一个标志,指示是否存在基础层模式
has_base_layer_pattern = False
# 遍历模型中的所有模块
for module in model.modules():
# 检查模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 如果模块有名为 "base_layer" 的属性,则更新标志
has_base_layer_pattern = hasattr(module, "base_layer")
break
# 如果存在基础层模式
if has_base_layer_pattern:
# 从 peft.utils 导入 _get_submodules 函数
from peft.utils import _get_submodules
# 获取所有不包含 "lora" 的模块名称列表
key_list = [key for key, _ in model.named_modules() if "lora" not in key]
# 遍历所有模块名称
for key in key_list:
try:
# 获取当前模块的父级、目标和目标名称
parent, target, target_name = _get_submodules(model, key)
except AttributeError:
# 如果发生属性错误,则继续下一个模块
continue
# 如果目标具有 "base_layer" 属性
if hasattr(target, "base_layer"):
# 用目标的基础层替换父模块中的目标
setattr(parent, target_name, target.get_base_layer())
else:
# 处理与 PEFT <= 0.6.2 的向后兼容性
# TODO: 一旦不再支持该 PEFT 版本,可以移除此代码
from peft.tuners.lora import LoraLayer # 导入 LoraLayer 模块
# 遍历模型的所有子模块
for name, module in model.named_children():
# 如果子模块有子模块,则递归进入处理
if len(list(module.children())) > 0:
## 复合模块,递归处理其内部的层
recurse_remove_peft_layers(module)
module_replaced = False # 初始化标志,表示模块是否被替换
# 检查当前模块是否为 LoraLayer 且为线性层
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
# 创建新的线性层,保留输入和输出特征及偏置信息
new_module = torch.nn.Linear(
module.in_features, # 输入特征数量
module.out_features, # 输出特征数量
bias=module.bias is not None, # 是否使用偏置
).to(module.weight.device) # 将新模块移动到当前模块权重的设备上
new_module.weight = module.weight # 复制权重
if module.bias is not None:
new_module.bias = module.bias # 复制偏置
module_replaced = True # 标记模块已被替换
# 检查当前模块是否为 LoraLayer 且为卷积层
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
# 创建新的卷积层,保留卷积参数
new_module = torch.nn.Conv2d(
module.in_channels, # 输入通道数
module.out_channels, # 输出通道数
module.kernel_size, # 卷积核大小
module.stride, # 步幅
module.padding, # 填充
module.dilation, # 膨胀
module.groups, # 分组卷积
).to(module.weight.device) # 将新模块移动到当前模块权重的设备上
new_module.weight = module.weight # 复制权重
if module.bias is not None:
new_module.bias = module.bias # 复制偏置
module_replaced = True # 标记模块已被替换
# 如果模块被替换,则更新模型
if module_replaced:
setattr(model, name, new_module) # 更新模型中的模块
del module # 删除旧模块
# 如果可用,则清空 CUDA 缓存
if torch.cuda.is_available():
torch.cuda.empty_cache() # 释放 CUDA 内存
return model # 返回更新后的模型
# 定义一个函数来调整模型的 LoRA 层的权重
def scale_lora_layers(model, weight):
"""
调整模型的 LoRA 层的权重。
参数:
model (`torch.nn.Module`):
需要调整的模型。
weight (`float`):
要分配给 LoRA 层的权重。
"""
# 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
from peft.tuners.tuners_utils import BaseTunerLayer
# 如果权重为 1.0,直接返回,不做任何调整
if weight == 1.0:
return
# 遍历模型的所有模块
for module in model.modules():
# 检查当前模块是否是 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 调整当前 LoRA 层的权重
module.scale_layer(weight)
# 定义一个函数来移除先前给定的 LoRA 层的权重
def unscale_lora_layers(model, weight: Optional[float] = None):
"""
移除模型的 LoRA 层的权重。
参数:
model (`torch.nn.Module`):
需要调整的模型。
weight (`float`, *可选*):
要分配给 LoRA 层的权重。如果未传入权重,将重新初始化 LoRA 层的权重。如果传入 0.0,将以正确的值重新初始化权重。
"""
# 从 peft.tuners.tuners_utils 导入 BaseTunerLayer 类
from peft.tuners.tuners_utils import BaseTunerLayer
# 如果权重为 1.0,直接返回,不做任何调整
if weight == 1.0:
return
# 遍历模型的所有模块
for module in model.modules():
# 检查当前模块是否是 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 如果传入了权重且权重不为 0
if weight is not None and weight != 0:
# 移除当前 LoRA 层的权重
module.unscale_layer(weight)
# 如果传入的权重为 0
elif weight is not None and weight == 0:
# 遍历当前模块的所有活动适配器
for adapter_name in module.active_adapters:
# 如果权重为 0,则将权重重置为原始值
module.set_scale(adapter_name, 1.0)
# 定义一个函数来获取 PEFT 的参数
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
# 初始化排名模式和 alpha 模式的字典
rank_pattern = {}
alpha_pattern = {}
# 获取 rank_dict 的第一个值作为 lora_alpha
r = lora_alpha = list(rank_dict.values())[0]
# 如果 rank_dict 中的值不全相同
if len(set(rank_dict.values())) > 1:
# 获取出现次数最多的 rank
r = collections.Counter(rank_dict.values()).most_common()[0][0]
# 对于排名与最常见排名不同的模块,将其添加到 rank_pattern 中
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
# 提取模块名称,去掉 ".lora_B." 后的部分,并保存到 rank_pattern 中
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
# 检查网络 alpha 字典是否不为 None 且包含元素
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
# 检查网络 alpha 字典中是否有超过一个不同的值
if len(set(network_alpha_dict.values())) > 1:
# 获取出现次数最多的 alpha
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# 对于与出现次数最多的 alpha 不同的模块,将其添加到 `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
# 如果是 UNet 模型
if is_unet:
# 处理模块名称,去掉 ".lora_A." 和 ".alpha"
alpha_pattern = {
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
for k, v in alpha_pattern.items()
}
else:
# 处理模块名称,去掉 ".down." 后的部分
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
else:
# 如果只有一个不同的 alpha,直接取出该值
lora_alpha = set(network_alpha_dict.values()).pop()
# 获取不包含 Diffusers 特定的层名称,去重
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
# 检查 peft_state_dict 中是否有使用 "lora_magnitude_vector"
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# 构建 lora 配置参数字典
lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
"use_dora": use_dora,
}
# 返回 lora 配置参数字典
return lora_config_kwargs
# 获取模型中适配器的名称,根据 BaseTunerLayer 的数量返回
def get_adapter_name(model):
# 从 PEFT 库中导入基础调优层
from peft.tuners.tuners_utils import BaseTunerLayer
# 遍历模型的所有模块
for module in model.modules():
# 检查模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 返回适配器的名称,格式为 "default_" 加上适配器数量
return f"default_{len(module.r)}"
# 如果没有找到适配器,返回默认名称 "default_0"
return "default_0"
# 设置模型的适配层,可启用或禁用
def set_adapter_layers(model, enabled=True):
# 从 PEFT 库中导入基础调优层
from peft.tuners.tuners_utils import BaseTunerLayer
# 遍历模型的所有模块
for module in model.modules():
# 检查模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 检查模块是否具备 enable_adapters 方法
if hasattr(module, "enable_adapters"):
# 调用 enable_adapters 方法启用或禁用适配器
module.enable_adapters(enabled=enabled)
else:
# 通过禁用状态设置 disable_adapters 属性
module.disable_adapters = not enabled
# 删除模型中的适配层
def delete_adapter_layers(model, adapter_name):
# 从 PEFT 库中导入基础调优层
from peft.tuners.tuners_utils import BaseTunerLayer
# 遍历模型的所有模块
for module in model.modules():
# 检查模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 检查模块是否具备 delete_adapter 方法
if hasattr(module, "delete_adapter"):
# 调用 delete_adapter 方法删除指定适配器
module.delete_adapter(adapter_name)
else:
# 抛出错误,提示 PEFT 版本不兼容
raise ValueError(
"The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
)
# 检查模型是否已加载 PEFT 配置
if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
# 从配置中删除适配器
model.peft_config.pop(adapter_name, None)
# 如果所有适配器都已删除,删除配置并重置标志
if len(model.peft_config) == 0:
del model.peft_config
model._hf_peft_config_loaded = None
# 设置适配器的权重并激活它们
def set_weights_and_activate_adapters(model, adapter_names, weights):
# 从 PEFT 库中导入基础调优层
from peft.tuners.tuners_utils import BaseTunerLayer
# 定义获取模块权重的内部函数
def get_module_weight(weight_for_adapter, module_name):
# 如果权重不是字典,直接返回该权重
if not isinstance(weight_for_adapter, dict):
return weight_for_adapter
# 遍历权重字典,查找对应模块的权重
for layer_name, weight_ in weight_for_adapter.items():
if layer_name in module_name:
return weight_
# 分割模块名称为部分
parts = module_name.split(".")
# 生成关键字以获取块权重
key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
block_weight = weight_for_adapter.get(key, 1.0)
return block_weight
# 遍历每个适配器,使其激活并设置对应的缩放权重
for adapter_name, weight in zip(adapter_names, weights):
for module_name, module in model.named_modules():
# 检查模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 检查模块是否具备 set_adapter 方法,以兼容旧版本
if hasattr(module, "set_adapter"):
# 设置适配器名称
module.set_adapter(adapter_name)
else:
# 设置当前激活的适配器名称
module.active_adapter = adapter_name
# 设置适配器的缩放权重
module.set_scale(adapter_name, get_module_weight(weight, module_name))
# 设置多个激活的适配器
# 遍历模型中的所有模块
for module in model.modules():
# 检查当前模块是否为 BaseTunerLayer 的实例
if isinstance(module, BaseTunerLayer):
# 为了与以前的 PEFT 版本保持向后兼容
if hasattr(module, "set_adapter"):
# 如果模块具有 set_adapter 方法,则调用该方法并传入适配器名称
module.set_adapter(adapter_names)
else:
# 如果没有 set_adapter 方法,则直接设置 active_adapter 属性
module.active_adapter = adapter_names
# 检查 PEFT 的版本是否兼容
def check_peft_version(min_version: str) -> None:
# 文档字符串,说明该函数的作用和参数
r"""
Checks if the version of PEFT is compatible.
Args:
version (`str`):
The version of PEFT to check against.
"""
# 检查 PEFT 是否可用,若不可用则抛出异常
if not is_peft_available():
raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
# 获取当前 PEFT 版本并与最小版本进行比较
is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)
# 若版本不兼容,则抛出异常并提示用户使用兼容版本
if not is_peft_version_compatible:
raise ValueError(
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)
.\diffusers\utils\pil_utils.py
# 从类型提示模块导入 List
from typing import List
# 导入 PIL.Image 模块用于图像处理
import PIL.Image
# 导入 PIL.ImageOps 模块提供图像操作功能
import PIL.ImageOps
# 从 packaging 模块导入 version 用于版本比较
from packaging import version
# 从 PIL 导入 Image 类用于图像处理
from PIL import Image
# 检查 PIL 的版本是否大于或等于 9.1.0
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
# 定义图像插值方式的字典,使用新版的 Resampling 属性
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
# 定义图像插值方式的字典,使用旧版的插值常量
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# 将 torch 图像转换为 PIL 图像的函数
def pt_to_pil(images):
"""
Convert a torch image to a PIL image.
"""
# 将图像标准化到 [0, 1] 范围内
images = (images / 2 + 0.5).clamp(0, 1)
# 将图像转移到 CPU,并调整维度为 (batch, height, width, channels)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
# 调用 numpy_to_pil 函数将 numpy 图像转换为 PIL 图像
images = numpy_to_pil(images)
# 返回 PIL 图像
return images
# 将 numpy 图像或图像批量转换为 PIL 图像的函数
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
# 如果图像是三维的,增加一个维度以表示批量
if images.ndim == 3:
images = images[None, ...]
# 将图像值转换为 0-255 的整数,并取整
images = (images * 255).round().astype("uint8")
# 检查是否为单通道图像(灰度图像)
if images.shape[-1] == 1:
# 对于灰度图像,使用模式 "L" 创建 PIL 图像
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
# 对于彩色图像,直接从数组创建 PIL 图像
pil_images = [Image.fromarray(image) for image in images]
# 返回 PIL 图像列表
return pil_images
# 创建图像网格的函数,便于可视化
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
"""
Prepares a single grid of images. Useful for visualization purposes.
"""
# 断言图像数量与网格大小一致
assert len(images) == rows * cols
# 如果指定了调整大小,则对每个图像进行调整
if resize is not None:
images = [img.resize((resize, resize)) for img in images]
# 获取每个图像的宽度和高度
w, h = images[0].size
# 创建一个新的 RGB 图像,用于放置网格
grid = Image.new("RGB", size=(cols * w, rows * h))
# 将每个图像放置到网格中
for i, img in enumerate(images):
# 计算每个图像的放置位置
grid.paste(img, box=(i % cols * w, i // cols * h))
# 返回生成的图像网格
return grid
.\diffusers\utils\state_dict_utils.py
# 版权声明,指定版权信息及保留权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 许可声明,指定使用文件的条件和限制
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 许可地址,提供获取许可的链接
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 声明在适用情况下,软件按“原样”分发,且没有任何形式的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可以了解特定权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
"""
状态字典工具:用于轻松转换状态字典的实用方法
"""
# 导入枚举模块
import enum
# 从日志模块导入获取记录器的函数
from .logging import get_logger
# 创建一个记录器实例,用于当前模块的日志记录
logger = get_logger(__name__)
# 定义状态字典类型的枚举类
class StateDictType(enum.Enum):
"""
用于转换状态字典时使用的模式。
"""
# 指定不同的状态字典类型
DIFFUSERS_OLD = "diffusers_old"
KOHYA_SS = "kohya_ss"
PEFT = "peft"
DIFFUSERS = "diffusers"
# 定义 Unet 到 Diffusers 的映射,因为它使用不同的输出键与文本编码器
# 例如:to_q_lora -> q_proj / to_q
UNET_TO_DIFFUSERS = {
# 映射 Unet 的上输出到 Diffusers 的对应输出
".to_out_lora.up": ".to_out.0.lora_B",
".to_out_lora.down": ".to_out.0.lora_A",
".to_q_lora.down": ".to_q.lora_A",
".to_q_lora.up": ".to_q.lora_B",
".to_k_lora.down": ".to_k.lora_A",
".to_k_lora.up": ".to_k.lora_B",
".to_v_lora.down": ".to_v.lora_A",
".to_v_lora.up": ".to_v.lora_B",
".lora.up": ".lora_B",
".lora.down": ".lora_A",
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}
# 定义 Diffusers 到 PEFT 的映射
DIFFUSERS_TO_PEFT = {
# 映射 Diffusers 的层到 PEFT 的对应层
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A",
"text_projection.lora.down.weight": "text_projection.lora_A.weight",
"text_projection.lora.up.weight": "text_projection.lora_B.weight",
}
# 定义 Diffusers_old 到 PEFT 的映射
DIFFUSERS_OLD_TO_PEFT = {
# 映射旧版本 Diffusers 的层到 PEFT 的对应层
".to_q_lora.up": ".q_proj.lora_B",
".to_q_lora.down": ".q_proj.lora_A",
".to_k_lora.up": ".k_proj.lora_B",
".to_k_lora.down": ".k_proj.lora_A",
".to_v_lora.up": ".v_proj.lora_B",
".to_v_lora.down": ".v_proj.lora_A",
".to_out_lora.up": ".out_proj.lora_B",
".to_out_lora.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A",
}
# 定义 PEFT 到 Diffusers 的映射
PEFT_TO_DIFFUSERS = {
# 映射 PEFT 的层到 Diffusers 的对应层
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
# 将 lora_B 关联到 lora_linear_layer 的上层
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
# 将 lora_A 关联到 lora_linear_layer 的下层
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
# 将 out_proj 的 lora_B 关联到 lora_linear_layer 的上层
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
# 将 out_proj 的 lora_A 关联到 lora_linear_layer 的下层
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
# 将 to_k 的 lora_A 关联到 lora 的下层
"to_k.lora_A": "to_k.lora.down",
# 将 to_k 的 lora_B 关联到 lora 的上层
"to_k.lora_B": "to_k.lora.up",
# 将 to_q 的 lora_A 关联到 lora 的下层
"to_q.lora_A": "to_q.lora.down",
# 将 to_q 的 lora_B 关联到 lora 的上层
"to_q.lora_B": "to_q.lora.up",
# 将 to_v 的 lora_A 关联到 lora 的下层
"to_v.lora_A": "to_v.lora.down",
# 将 to_v 的 lora_B 关联到 lora 的上层
"to_v.lora_B": "to_v.lora.up",
# 将 to_out.0 的 lora_A 关联到 lora 的下层
"to_out.0.lora_A": "to_out.0.lora.down",
# 将 to_out.0 的 lora_B 关联到 lora 的上层
"to_out.0.lora_B": "to_out.0.lora.up",
# 结束前一个代码块
}
# 定义一个字典,将旧的扩散器键映射到新的扩散器键
DIFFUSERS_OLD_TO_DIFFUSERS = {
# 映射旧的 Q 线性层上升键到新的 Q 线性层上升键
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
# 映射旧的 Q 线性层下降键到新的 Q 线性层下降键
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
# 映射旧的 K 线性层上升键到新的 K 线性层上升键
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
# 映射旧的 K 线性层下降键到新的 K 线性层下降键
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
# 映射旧的 V 线性层上升键到新的 V 线性层上升键
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
# 映射旧的 V 线性层下降键到新的 V 线性层下降键
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
# 映射旧的输出层上升键到新的输出层上升键
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
# 映射旧的输出层下降键到新的输出层下降键
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
# 映射旧的 K 大小向量键到新的 K 大小向量键
".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
# 映射旧的 V 大小向量键到新的 V 大小向量键
".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
# 映射旧的 Q 大小向量键到新的 Q 大小向量键
".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
# 映射旧的输出层大小向量键到新的输出层大小向量键
".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
}
# 定义一个字典,将 PEFT 格式映射到 KOHYA_SS 格式
PEFT_TO_KOHYA_SS = {
# 映射 PEFT 格式中的 A 到 KOHYA_SS 格式中的下降
"lora_A": "lora_down",
# 映射 PEFT 格式中的 B 到 KOHYA_SS 格式中的上升
"lora_B": "lora_up",
# 这不是一个全面的字典,因为 KOHYA 格式需要替换键中的 `.` 为 `_`,
# 添加前缀和添加 alpha 值
# 检查 `convert_state_dict_to_kohya` 以了解更多
}
# 定义一个字典,将状态字典类型映射到相应的 DIFFUSERS 映射
PEFT_STATE_DICT_MAPPINGS = {
# 映射旧扩散器类型到 PEFT
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
# 映射新扩散器类型到 PEFT
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}
# 定义一个字典,将状态字典类型映射到相应的 DIFFUSERS 映射
DIFFUSERS_STATE_DICT_MAPPINGS = {
# 映射旧扩散器类型到旧扩散器映射
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
# 映射 PEFT 类型到 DIFFUSERS
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
# 定义一个字典,将 PEFT 类型映射到 KOHYA 状态字典映射
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
# 定义一个字典,指定总是要替换的键模式
KEYS_TO_ALWAYS_REPLACE = {
# 将处理器的键模式替换为基本形式
".processor.": ".",
}
# 定义函数,转换状态字典
def convert_state_dict(state_dict, mapping):
r"""
简单地遍历状态字典并用 `mapping` 中的模式替换相应的值。
参数:
state_dict (`dict[str, torch.Tensor]`):
要转换的状态字典。
mapping (`dict[str, str]`):
用于转换的映射,映射应为以下结构的字典:
- 键: 要替换的模式
- 值: 要替换成的模式
返回:
converted_state_dict (`dict`)
转换后的状态字典。
"""
# 初始化一个新的转换后的状态字典
converted_state_dict = {}
# 遍历输入的状态字典
for k, v in state_dict.items():
# 首先,过滤出我们总是想替换的键
for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
# 如果当前键中包含模式,则替换
if pattern in k:
new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
k = k.replace(pattern, new_pattern)
# 遍历映射中的模式
for pattern in mapping.keys():
# 如果当前键中包含模式,则替换
if pattern in k:
new_pattern = mapping[pattern]
k = k.replace(pattern, new_pattern)
break
# 将转换后的键值对添加到新的字典中
converted_state_dict[k] = v
# 返回转换后的状态字典
return converted_state_dict
# 定义函数,将状态字典转换为 PEFT 格式
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
r"""
将状态字典转换为 PEFT 格式,状态字典可以来自旧的扩散器格式(`OLD_DIFFUSERS`)或
新的扩散器格式(`DIFFUSERS`)。该方法目前仅支持从扩散器旧/新格式到 PEFT 的转换。
# 参数说明部分
Args:
state_dict (`dict[str, torch.Tensor]`):
# 要转换的状态字典
The state dict to convert.
original_type (`StateDictType`, *optional*):
# 状态字典的原始类型,如果未提供,方法将尝试自动推断
The original type of the state dict, if not provided, the method will try to infer it automatically.
"""
# 如果原始类型未提供
if original_type is None:
# 检查状态字典的键中是否包含“to_out_lora”,用于判断类型
# 将旧的 diffusers 类型转换为 PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
# 检查状态字典的键中是否包含“lora_linear_layer”
elif any("lora_linear_layer" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS
# 如果无法推断类型,则抛出错误
else:
raise ValueError("Could not automatically infer state dict type")
# 检查推断的原始类型是否在支持的类型映射中
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
# 如果不支持,则抛出错误
raise ValueError(f"Original type {original_type} is not supported")
# 根据原始类型获取对应的映射
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
# 转换状态字典并返回结果
return convert_state_dict(state_dict, mapping)
# 将状态字典转换为新的 diffusers 格式。状态字典可以来自旧的 diffusers 格式
# (`OLD_DIFFUSERS`)、PEFT 格式 (`PEFT`) 或新的 diffusers 格式 (`DIFFUSERS`)。
# 在最后一种情况下,该方法将返回状态字典本身。
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
# 状态字典转换为新格式的文档字符串
r"""
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
return the state dict as is.
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
# 从 kwargs 中获取适配器名称,如果不存在则默认为 None
peft_adapter_name = kwargs.pop("adapter_name", None)
# 如果适配器名称不为 None,前面添加一个点
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
else:
# 否则适配器名称为空字符串
peft_adapter_name = ""
# 如果没有提供原始类型
if original_type is None:
# 检查状态字典的键是否包含 "to_out_lora",若有则设置原始类型为旧 diffusers
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
# 检查键是否包含以适配器名称为前缀的权重
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
# 检查是否包含 "lora_linear_layer",如果有则不需要转换,直接返回状态字典
elif any("lora_linear_layer" in k for k in state_dict.keys()):
# nothing to do
return state_dict
# 如果未能推断出原始类型,则引发值错误
else:
raise ValueError("Could not automatically infer state dict type")
# 检查原始类型是否在支持的状态字典映射中
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
# 如果不支持,抛出值错误
raise ValueError(f"Original type {original_type} is not supported")
# 获取与原始类型相对应的映射
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
# 使用映射转换状态字典并返回
return convert_state_dict(state_dict, mapping)
# 将状态字典从 UNet 格式转换为 diffusers 格式,主要通过移除一些键来实现
def convert_unet_state_dict_to_peft(state_dict):
# 状态字典转换文档字符串
r"""
Converts a state dict from UNet format to diffusers format - i.e. by removing some keys
"""
# 定义 UNet 到 diffusers 的映射
mapping = UNET_TO_DIFFUSERS
# 使用映射转换状态字典并返回
return convert_state_dict(state_dict, mapping)
# 尝试首先将状态字典转换为 PEFT 格式,如果没有检测到有效的 DIFFUSERS LoRA 的 "lora_linear_layer"
# 则尝试专门转换 UNet 状态字典
def convert_all_state_dict_to_peft(state_dict):
# 状态字典转换为 PEFT 格式的文档字符串
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
`DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
"""
# 尝试转换状态字典为 PEFT 格式
try:
peft_dict = convert_state_dict_to_peft(state_dict)
# 捕获异常并赋值给变量 e
except Exception as e:
# 检查异常信息是否为无法自动推断状态字典类型
if str(e) == "Could not automatically infer state dict type":
# 将 UNet 状态字典转换为 PEFT 字典
peft_dict = convert_unet_state_dict_to_peft(state_dict)
else:
# 重新抛出未处理的异常
raise
# 检查 PEFT 字典中是否包含 "lora_A" 或 "lora_B" 的键
if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()):
# 如果没有,则抛出值错误异常
raise ValueError("Your LoRA was not converted to PEFT")
# 返回转换后的 PEFT 字典
return peft_dict
# 定义一个将 PEFT 状态字典转换为 Kohya 格式的函数
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
r"""
将 `PEFT` 状态字典转换为可在 AUTOMATIC1111、ComfyUI、SD.Next、InvokeAI 等中使用的 `Kohya` 格式。
该方法目前仅支持从 PEFT 到 Kohya 的转换。
参数:
state_dict (`dict[str, torch.Tensor]`):
要转换的状态字典。
original_type (`StateDictType`, *可选*):
状态字典的原始类型,如果未提供,方法将尝试自动推断。
kwargs (`dict`, *args*):
传递给该方法的附加参数。
- **adapter_name**: 例如,在 PEFT 的情况下,一些键会被适配器名称预先附加,
因此需要特殊处理。默认情况下,PEFT 也会在
`get_peft_model_state_dict` 方法中处理这一点:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
但我们在这里添加它以防我们不想依赖该方法。
"""
# 尝试导入 torch 库
try:
import torch
# 如果导入失败,记录错误并引发异常
except ImportError:
logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
raise
# 从 kwargs 中弹出适配器名称,如果没有提供则为 None
peft_adapter_name = kwargs.pop("adapter_name", None)
# 如果提供了适配器名称,则在前面加上点
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
# 如果没有适配器名称,则设置为空字符串
else:
peft_adapter_name = ""
# 如果未提供原始类型,则检查状态字典中是否包含特定键
if original_type is None:
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
# 检查原始类型是否在支持的类型映射中
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
# 使用适当的映射调用 convert_state_dict 函数
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
# 创建一个空字典以存储最终的 Kohya 状态字典
kohya_ss_state_dict = {}
# 额外逻辑:在所有键中替换头部、alpha 参数的 `.` 为 `_`
for kohya_key, weight in kohya_ss_partial_state_dict.items():
# 替换特定的键名
if "text_encoder_2." in kohya_key:
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
elif "text_encoder." in kohya_key:
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
elif "unet" in kohya_key:
kohya_key = kohya_key.replace("unet", "lora_unet")
elif "lora_magnitude_vector" in kohya_key:
kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
# 将所有键中的 `.` 替换为 `_`,保留最后两个 `.` 不变
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
# 移除适配器名称,Kohya 不需要这些名称
kohya_key = kohya_key.replace(peft_adapter_name, "")
# 将处理后的键及其权重存储到新的字典中
kohya_ss_state_dict[kohya_key] = weight
# 如果键名中包含 `lora_down`,则创建相应的 alpha 键
if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
# 将 alpha 键的值设置为权重的长度
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
# 返回保存的状态字典,通常用于恢复训练或推理的状态
return kohya_ss_state_dict
.\diffusers\utils\testing_utils.py
# 导入所需的标准库和第三方库
import functools # 提供高阶函数的工具
import importlib # 处理动态导入模块
import inspect # 获取对象的内部信息
import io # 提供用于处理流的基本工具
import logging # 提供记录日志的功能
import multiprocessing # 处理多进程的支持
import os # 提供与操作系统交互的功能
import random # 生成随机数的工具
import re # 处理正则表达式
import struct # 处理 C 语言结构体的工具
import sys # 处理 Python 解释器的变量和函数
import tempfile # 创建临时文件的工具
import time # 提供时间相关的功能
import unittest # 提供单元测试的框架
import urllib.parse # 处理 URL 的解析和构造
from contextlib import contextmanager # 提供上下文管理器的功能
from io import BytesIO, StringIO # 提供内存中的字节流和字符串流
from pathlib import Path # 处理文件路径的工具
from typing import Callable, Dict, List, Optional, Union # 提供类型注释的工具
import numpy as np # 导入 NumPy 库
import PIL.Image # 导入 PIL 图像处理库
import PIL.ImageOps # 导入 PIL 图像操作功能
import requests # 处理 HTTP 请求的库
from numpy.linalg import norm # 导入计算向量范数的函数
from packaging import version # 处理版本比较的工具
# 导入自定义工具模块中的所需功能
from .import_utils import (
BACKENDS_MAPPING, # 后端映射的字典
is_compel_available, # 检查 Compel 是否可用的函数
is_flax_available, # 检查 Flax 是否可用的函数
is_note_seq_available, # 检查 NoteSeq 是否可用的函数
is_onnx_available, # 检查 ONNX 是否可用的函数
is_opencv_available, # 检查 OpenCV 是否可用的函数
is_peft_available, # 检查 PEFT 是否可用的函数
is_timm_available, # 检查 TIMM 是否可用的函数
is_torch_available, # 检查 PyTorch 是否可用的函数
is_torch_version, # 检查 PyTorch 版本的函数
is_torchsde_available, # 检查 TorchSDE 是否可用的函数
is_transformers_available, # 检查 Transformers 是否可用的函数
)
from .logging import get_logger # 导入自定义日志记录功能
# 创建全局随机数生成器
global_rng = random.Random()
# 获取当前模块的日志记录器
logger = get_logger(__name__)
# 检查是否满足 PEFT 的版本要求
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5") # 检查 PEFT 版本是否大于 0.5
# 检查是否满足 Transformers 的版本要求
_required_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33") # 检查 Transformers 版本是否大于 4.33
# 根据版本要求确定是否使用 PEFT 后端
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
# 如果 PyTorch 可用,执行以下代码
if is_torch_available():
import torch # 导入 PyTorch 库
# 设置用于自定义加速器的后端环境变量
if "DIFFUSERS_TEST_BACKEND" in os.environ:
backend = os.environ["DIFFUSERS_TEST_BACKEND"] # 获取环境变量中的后端名称
try:
_ = importlib.import_module(backend) # 尝试导入指定的后端模块
except ModuleNotFoundError as e: # 捕获模块未找到异常
raise ModuleNotFoundError(
f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
to enable a specified backend.):\n{e}" # 提示用户导入失败
) from e # 抛出原始异常
# 如果指定了设备环境变量
if "DIFFUSERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] # 获取指定的设备名称
try:
# 尝试创建设备以验证提供的设备是否有效
_ = torch.device(torch_device)
except RuntimeError as e: # 捕获运行时异常
raise RuntimeError(
f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}" # 提示用户设备未知
) from e # 抛出原始异常
logger.info(f"torch_device overrode to {torch_device}") # 记录当前使用的设备名称
else:
# 默认情况下根据 CUDA 可用性选择设备
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
version.parse(torch.__version__).base_version
) >= version.parse("1.12") # 检查 PyTorch 版本是否大于等于 1.12
if is_torch_higher_equal_than_1_12: # 如果 PyTorch 版本符合条件
# 某些版本的 PyTorch 1.12 未注册 mps 后端,需查看相关问题
mps_backend_registered = hasattr(torch.backends, "mps") # 检查是否注册了 mps 后端
# 根据 mps 后端可用性选择设备
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
# 定义一个函数,用于检查两个 PyTorch 张量是否相近
def torch_all_close(a, b, *args, **kwargs):
# 检查是否安装了 PyTorch,如果没有,则引发异常
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
# 检查两个张量是否相近,如果不相近,则断言失败,并显示最大差异
if not torch.allclose(a, b, *args, **kwargs):
assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
# 如果相近,返回 True
return True
# 定义一个函数,计算两个 NumPy 数组之间的余弦相似度距离
def numpy_cosine_similarity_distance(a, b):
# 计算两个数组的余弦相似度
similarity = np.dot(a, b) / (norm(a) * norm(b))
# 计算相似度的距离,距离越小越相似
distance = 1.0 - similarity.mean()
# 返回计算得到的距离
return distance
# 定义一个函数,用于打印张量测试结果
def print_tensor_test(
tensor,
limit_to_slices=None,
max_torch_print=None,
filename="test_corrections.txt",
expected_tensor_name="expected_slice",
):
# 如果设置了最大打印数量,则更新 PyTorch 打印选项
if max_torch_print:
torch.set_printoptions(threshold=10_000)
# 获取当前测试的名称
test_name = os.environ.get("PYTEST_CURRENT_TEST")
# 如果输入不是张量,则将其转换为张量
if not torch.is_tensor(tensor):
tensor = torch.from_numpy(tensor)
# 如果设置了切片限制,则限制张量的切片
if limit_to_slices:
tensor = tensor[0, -3:, -3:, -1]
# 将张量转换为字符串格式,并移除换行符
tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
# 将张量字符串格式化为 NumPy 数组形式
output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
# 分离测试名称为文件、类和函数名
test_file, test_class, test_fn = test_name.split("::")
test_fn = test_fn.split()[0]
# 以附加模式打开文件并写入测试结果
with open(filename, "a") as f:
print("::".join([test_file, test_class, test_fn, output_str]), file=f)
# 定义一个函数,用于获取测试目录的路径
def get_tests_dir(append_path=None):
"""
Args:
append_path: optional path to append to the tests dir path
Return:
The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
joined after the `tests` dir the former is provided.
"""
# 获取调用该函数的文件的路径
caller__file__ = inspect.stack()[1][1]
# 获取调用文件所在目录的绝对路径
tests_dir = os.path.abspath(os.path.dirname(caller__file__))
# 循环查找直到找到以 "tests" 结尾的目录
while not tests_dir.endswith("tests"):
tests_dir = os.path.dirname(tests_dir)
# 如果提供了附加路径,则返回合并后的完整路径
if append_path:
return Path(tests_dir, append_path).as_posix()
else:
# 否则返回测试目录路径
return tests_dir
# 从 PR 中提取的函数
# https://github.com/huggingface/accelerate/pull/1964
def str_to_bool(value) -> int:
"""
Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
`on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
"""
# 将输入值转换为小写以进行比较
value = value.lower()
# 检查输入值是否是“真”值,返回 1
if value in ("y", "yes", "t", "true", "on", "1"):
return 1
# 检查输入值是否是“假”值,返回 0
elif value in ("n", "no", "f", "false", "off", "0"):
return 0
else:
# 如果输入值无效,则引发异常
raise ValueError(f"invalid truth value {value}")
# 定义一个函数,从环境变量中解析布尔标志
def parse_flag_from_env(key, default=False):
try:
# 尝试从环境变量中获取值
value = os.environ[key]
except KeyError:
# 如果未设置环境变量,则使用默认值
_value = default
else:
# 如果没有设置 KEY,进入此分支,准备将值转换为布尔值
# 尝试将字符串值转换为布尔值(True 或 False)
try:
_value = str_to_bool(value)
except ValueError:
# 如果转换失败,抛出更具体的错误信息,提示值必须是 'yes' 或 'no'
raise ValueError(f"If set, {key} must be yes or no.")
# 返回转换后的布尔值
return _value
# 从环境变量中解析是否运行慢测试的标志,默认为 False
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
# 从环境变量中解析是否运行夜间测试的标志,默认为 False
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
# 从环境变量中解析是否运行编译测试的标志,默认为 False
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""创建一个随机的 float32 张量"""
# 如果没有提供随机数生成器,则使用全局随机数生成器
if rng is None:
rng = global_rng
# 初始化总维度为 1
total_dims = 1
# 计算张量的总元素数量
for dim in shape:
total_dims *= dim
# 初始化一个空列表以存储随机值
values = []
# 根据总元素数量生成随机值
for _ in range(total_dims):
values.append(rng.random() * scale)
# 将生成的值转换为张量,并按照指定形状调整,返回连续的张量
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
def slow(test_case):
"""
装饰器标记一个测试为慢测试。
慢测试默认被跳过。设置 RUN_SLOW 环境变量为真值以运行它们。
"""
# 如果 _run_slow_tests 为真,则返回原测试,否则跳过测试并提示信息
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def nightly(test_case):
"""
装饰器标记一个每晚在 diffusers CI 中运行的测试。
慢测试默认被跳过。设置 RUN_NIGHTLY 环境变量为真值以运行它们。
"""
# 如果 _run_nightly_tests 为真,则返回原测试,否则跳过测试并提示信息
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
def is_torch_compile(test_case):
"""
装饰器标记一个在 diffusers CI 中运行的编译测试。
编译测试默认被跳过。设置 RUN_COMPILE 环境变量为真值以运行它们。
"""
# 如果 _run_compile_tests 为真,则返回原测试,否则跳过测试并提示信息
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
def require_torch(test_case):
"""
装饰器标记一个需要 PyTorch 的测试。未安装 PyTorch 时这些测试将被跳过。
"""
# 如果 PyTorch 可用,则返回原测试,否则跳过测试并提示信息
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_torch_2(test_case):
"""
装饰器标记一个需要 PyTorch 2 的测试。这些测试在未安装 PyTorch 2 时将被跳过。
"""
# 检查 PyTorch 是否可用且版本大于等于 2.0.0,然后返回原测试,否则跳过测试
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
test_case
)
def require_torch_gpu(test_case):
"""装饰器标记一个需要 CUDA 和 PyTorch 的测试。"""
# 检查 PyTorch 是否可用且当前设备为 CUDA,然后返回原测试,否则跳过测试
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
test_case
)
# 这些装饰器用于特定加速器行为,而不是仅限于 GPU
def require_torch_accelerator(test_case):
"""装饰器标记一个需要加速器后端和 PyTorch 的测试。"""
# 检查 PyTorch 是否可用且当前设备不是 CPU,然后返回原测试,否则跳过测试
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
test_case
)
def require_torch_multi_gpu(test_case):
"""
装饰器标记一个需要多 GPU 设置的测试(在 PyTorch 中)。这些测试在没有多个 GPU 的机器上将被跳过。
若要仅运行 multi_gpu 测试,假设所有测试名称包含 multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
"""
# 检查是否可以使用 PyTorch
if not is_torch_available():
# 如果不可用,则跳过测试,给出提示
return unittest.skip("test requires PyTorch")(test_case)
# 导入 PyTorch 库
import torch
# 跳过测试,除非 GPU 设备数量大于 1
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
# 装饰器,标记需要支持 FP16 数据类型的加速器的测试
def require_torch_accelerator_with_fp16(test_case):
# 如果当前设备支持 FP16,则跳过此装饰器
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
test_case
)
# 装饰器,标记需要支持 FP64 数据类型的加速器的测试
def require_torch_accelerator_with_fp64(test_case):
# 如果当前设备支持 FP64,则跳过此装饰器
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
test_case
)
# 装饰器,标记需要支持训练的加速器的测试
def require_torch_accelerator_with_training(test_case):
# 如果当前设备可用且支持训练,则跳过此装饰器
return unittest.skipUnless(
is_torch_available() and backend_supports_training(torch_device),
"test requires accelerator with training support",
)(test_case)
# 装饰器,标记如果 torch_device 是 'mps' 则跳过测试
def skip_mps(test_case):
# 如果当前设备不是 'mps',则跳过此装饰器
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
# 装饰器,标记需要 JAX 和 Flax 的测试
def require_flax(test_case):
# 如果 Flax 可用,则跳过此装饰器
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
# 装饰器,标记需要 compel 库的测试
def require_compel(test_case):
# 如果 compel 可用,则跳过此装饰器
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
# 装饰器,标记需要 onnxruntime 的测试
def require_onnxruntime(test_case):
# 如果 onnxruntime 可用,则跳过此装饰器
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
# 装饰器,标记需要 note_seq 的测试
def require_note_seq(test_case):
# 如果 note_seq 可用,则跳过此装饰器
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
# 装饰器,标记需要 torchsde 的测试
def require_torchsde(test_case):
# 如果 torchsde 可用,则跳过此装饰器
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
# 装饰器,标记需要 PEFT 后端的测试
def require_peft_backend(test_case):
# 如果需要 PEFT 后端,则跳过此装饰器
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
# 装饰器,标记需要 timm 库的测试
def require_timm(test_case):
# 如果 timm 可用,则跳过此装饰器
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
# 装饰器,标记需要特定版本 PEFT 的测试
def require_peft_version_greater(peft_version):
# 装饰器标记一个测试,该测试要求特定版本的 PEFT 后端,需满足特定版本
"""
# 定义装饰器函数,接受一个测试用例作为参数
def decorator(test_case):
# 检查 PEFT 是否可用,并且当前版本是否大于指定的 PEFT 版本
correct_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(peft_version)
# 如果满足版本要求,则跳过此测试,并提供相应的提示信息
return unittest.skipUnless(
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
)(test_case)
# 返回装饰器函数
return decorator
# 定义一个装饰器,要求加速版本大于给定版本
def require_accelerate_version_greater(accelerate_version):
# 装饰器内部函数,用于装饰测试用例
def decorator(test_case):
# 检查 PEFT 是否可用,并解析当前加速库版本
correct_accelerate_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
# 根据版本判断是否跳过测试用例
return unittest.skipUnless(
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
)(test_case)
return decorator
# 定义一个装饰器,标记将在 PEFT 后端之后跳过的测试
def deprecate_after_peft_backend(test_case):
"""
装饰器,标记将在 PEFT 后端之后跳过的测试
"""
# 根据是否使用 PEFT 后端决定是否跳过测试
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
# 获取当前 Python 版本
def get_python_version():
# 获取系统的版本信息
sys_info = sys.version_info
# 提取主版本和次版本
major, minor = sys_info.major, sys_info.minor
# 返回主版本和次版本
return major, minor
# 加载 NumPy 数组,支持 URL 和本地路径
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
# 检查 arry 是否为字符串
if isinstance(arry, str):
if local_path is not None:
# local_path 用于修正测试的图像路径
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
# 检查字符串是否为 URL
elif arry.startswith("http://") or arry.startswith("https://"):
response = requests.get(arry) # 发送请求获取内容
response.raise_for_status() # 检查请求是否成功
arry = np.load(BytesIO(response.content)) # 从响应内容加载 NumPy 数组
# 检查字符串是否为有效的文件路径
elif os.path.isfile(arry):
arry = np.load(arry) # 从文件路径加载 NumPy 数组
else:
# 抛出路径或 URL 不正确的错误
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
)
# 检查 arry 是否为 NumPy 数组
elif isinstance(arry, np.ndarray):
pass # 如果是 NumPy 数组则不做任何处理
else:
# 抛出格式不正确的错误
raise ValueError(
"Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
" ndarray."
)
return arry # 返回处理后的数组
# 从给定 URL 加载 PyTorch 张量
def load_pt(url: str):
response = requests.get(url) # 发送请求获取内容
response.raise_for_status() # 检查请求是否成功
arry = torch.load(BytesIO(response.content)) # 从响应内容加载 PyTorch 张量
return arry # 返回加载的张量
# 加载图像,支持字符串和 PIL.Image.Image 类型
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
将 `image` 加载为 PIL 图像。
参数:
image (`str` 或 `PIL.Image.Image`):
要转换为 PIL 图像格式的图像。
返回:
`PIL.Image.Image`:
一个 PIL 图像。
"""
# 检查 image 是否为字符串
if isinstance(image, str):
# 检查字符串是否为 URL
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw) # 从 URL 加载图像
# 检查字符串是否为有效的文件路径
elif os.path.isfile(image):
image = PIL.Image.open(image) # 从文件路径加载图像
else:
# 抛出路径或 URL 不正确的错误
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
# 检查 image 是否为 PIL 图像
elif isinstance(image, PIL.Image.Image):
image = image # 如果是 PIL 图像则不做任何处理
# 如果不是正确的格式,则引发一个值错误
else:
# 提供详细的错误信息,说明接受的格式要求
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
# 对图像应用 EXIF 变换,调整其方向
image = PIL.ImageOps.exif_transpose(image)
# 将图像转换为 RGB 模式
image = image.convert("RGB")
# 返回处理后的图像
return image
# 预处理输入图像以适应特定模型需求
def preprocess_image(image: PIL.Image, batch_size: int):
# 获取图像的宽度和高度
w, h = image.size
# 调整宽度和高度为8的整数倍,以便于处理
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
# 根据新的宽高调整图像大小,使用高质量重采样
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
# 将图像数据转换为numpy数组并归一化到[0, 1]区间
image = np.array(image).astype(np.float32) / 255.0
# 扩展图像维度并复制以形成batch
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
# 将numpy数组转换为PyTorch张量
image = torch.from_numpy(image)
# 将图像值从[0, 1]范围缩放到[-1, 1]范围
return 2.0 * image - 1.0
# 将图像序列导出为GIF文件
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
# 如果未指定输出路径,则创建临时GIF文件
if output_gif_path is None:
output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
# 保存第一帧,并附加后续帧
image[0].save(
output_gif_path,
save_all=True,
append_images=image[1:],
optimize=False,
duration=100,
loop=0,
)
# 返回生成的GIF文件路径
return output_gif_path
# 上下文管理器,用于缓冲写入文件
@contextmanager
def buffered_writer(raw_f):
# 创建一个缓冲写入对象
f = io.BufferedWriter(raw_f)
# 暴露缓冲写入对象给上下文
yield f
# 刷新缓冲区,确保所有数据写入
f.flush()
# 导出网格为PLY文件
def export_to_ply(mesh, output_ply_path: str = None):
"""
写入一个网格的PLY文件。
"""
# 如果未指定输出路径,则创建临时PLY文件
if output_ply_path is None:
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
# 获取顶点坐标并转移到CPU和NumPy数组
coords = mesh.verts.detach().cpu().numpy()
# 获取面索引并转移到CPU和NumPy数组
faces = mesh.faces.cpu().numpy()
# 获取RGB颜色通道数据并堆叠为RGB格式
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
# 使用缓冲写入器打开PLY文件进行写入
with buffered_writer(open(output_ply_path, "wb")) as f:
# 写入PLY文件头
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
# 如果存在RGB数据,写入颜色属性
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
# 如果存在面数据,写入面索引
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
# 写入文件头结束标志
f.write(b"end_header\n")
# 如果存在RGB数据,处理顶点数据
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
# 定义数据打包格式
format = struct.Struct("<3f3B")
# 写入每个顶点的数据
for item in vertices:
f.write(format.pack(*item))
else:
# 定义顶点数据打包格式
format = struct.Struct("<3f")
# 写入每个顶点坐标
for vertex in coords.tolist():
f.write(format.pack(*vertex))
# 如果存在面数据,处理面索引数据
if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
# 返回生成的PLY文件路径
return output_ply_path
# 导出网格为OBJ文件
def export_to_obj(mesh, output_obj_path: str = None):
# 如果未指定输出路径,则创建临时OBJ文件
if output_obj_path is None:
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
# 获取顶点坐标并转移到CPU和NumPy数组
verts = mesh.verts.detach().cpu().numpy()
# 获取面索引并转移到CPU和NumPy数组
faces = mesh.faces.cpu().numpy()
# 将 mesh 中的顶点颜色通道提取并转换为 NumPy 数组,堆叠成一个数组
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
# 将顶点坐标和颜色组合成字符串列表,格式化为 OBJ 文件所需的顶点定义
vertices = [
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
]
# 将面数据格式化为 OBJ 文件所需的面定义,面索引从 0 转为从 1 开始
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
# 合并顶点数据和面数据,形成完整的 OBJ 数据列表
combined_data = ["v " + vertex for vertex in vertices] + faces
# 打开指定路径的文件,以写入模式创建文件对象 f
with open(output_obj_path, "w") as f:
# 将合并的数据写入文件,每个条目之间用换行符分隔
f.writelines("\n".join(combined_data))
# 导出视频帧为视频文件,返回输出视频的路径
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
# 检查 OpenCV 是否可用
if is_opencv_available():
# 导入 OpenCV 库
import cv2
else:
# 如果不可用,抛出导入错误
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
# 如果未指定输出视频路径,则创建临时文件
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
# 设置视频编码格式为 mp4
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# 获取第一帧的高度、宽度和通道数
h, w, c = video_frames[0].shape
# 创建视频写入对象
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
# 遍历每一帧视频
for i in range(len(video_frames)):
# 将帧从 RGB 转换为 BGR 格式
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
# 将转换后的帧写入视频
video_writer.write(img)
# 返回输出视频的路径
return output_video_path
# 加载 Hugging Face 上的 NumPy 数组
def load_hf_numpy(path) -> np.ndarray:
# 定义基础 URL
base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
# 如果路径不是以 http 或 https 开头,则拼接基础 URL
if not path.startswith("http://") and not path.startswith("https://"):
path = os.path.join(base_url, urllib.parse.quote(path))
# 加载并返回 NumPy 数组
return load_numpy(path)
# --- pytest 配置函数 --- #
# 避免在测试中多次调用,确保只调用一次
pytest_opt_registered = {}
def pytest_addoption_shared(parser):
"""
该函数从 `conftest.py` 调用,使用 `pytest_addoption` 包装器。
允许同时加载两个 `conftest.py` 文件,而不会因添加相同的 pytest 选项而失败。
"""
option = "--make-reports"
# 检查选项是否已注册
if option not in pytest_opt_registered:
# 添加选项到解析器
parser.addoption(
option,
action="store",
default=False,
help="生成报告文件。此选项的值用作报告名称的前缀",
)
# 标记选项已注册
pytest_opt_registered[option] = 1
def pytest_terminal_summary_main(tr, id):
"""
在测试套件运行结束时生成多个报告 - 每个报告存储在当前目录的独立文件中。
报告文件以测试套件名称为前缀。
模拟 --duration 和 -rA pytest 参数。
从 `conftest.py` 调用此函数。
Args:
- tr: 从 `conftest.py` 传递的 `terminalreporter`
- id: 像 `tests` 或 `examples` 的唯一 ID,将并入最终报告文件名
"""
from _pytest.config import create_terminal_writer
# 如果没有提供 ID,默认为 "tests"
if not len(id):
id = "tests"
# 获取配置和原始终端写入器
config = tr.config
orig_writer = config.get_terminal_writer()
orig_tbstyle = config.option.tbstyle
# 保存原始报告字符设置
orig_reportchars = tr.reportchars
# 定义报告文件存放目录
dir = "reports"
# 创建目录,父目录可选,如果已存在则不报错
Path(dir).mkdir(parents=True, exist_ok=True)
# 生成报告文件的字典,包含不同报告类型的文件名
report_files = {
k: f"{dir}/{id}_{k}.txt"
for k in [
"durations",
"errors",
"failures_long",
"failures_short",
"failures_line",
"passes",
"stats",
"summary_short",
"warnings",
]
}
# 自定义持续时间报告
# 注意:无需调用 pytest --durations=XX 获取此单独报告
# 来源于 https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
dlist = [] # 初始化持续时间列表
# 遍历测试结果中的统计数据
for replist in tr.stats.values():
for rep in replist:
# 如果报告有持续时间属性,则添加到列表中
if hasattr(rep, "duration"):
dlist.append(rep)
# 如果持续时间列表非空
if dlist:
# 按持续时间降序排序
dlist.sort(key=lambda x: x.duration, reverse=True)
# 打开文件以写入持续时间报告
with open(report_files["durations"], "w") as f:
durations_min = 0.05 # 最小持续时间(秒)
# 写入标题
f.write("slowest durations\n")
# 遍历报告并写入文件
for i, rep in enumerate(dlist):
# 如果报告的持续时间小于最小值,写入省略信息并退出循环
if rep.duration < durations_min:
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
break
# 写入每条报告的持续时间、执行时机和节点ID
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
# 定义简短失败摘要的函数
def summary_failures_short(tr):
# 假设报告使用了长格式,截取最后一帧
reports = tr.getreports("failed")
# 如果没有失败报告,则直接返回
if not reports:
return
# 写入分隔符和标题
tr.write_sep("=", "FAILURES SHORT STACK")
# 遍历失败报告
for rep in reports:
# 获取失败消息
msg = tr._getfailureheadline(rep)
# 写入分隔符和消息
tr.write_sep("_", msg, red=True, bold=True)
# 去掉可选的前导额外帧,只保留最后一帧
longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
tr._tw.line(longrepr)
# 注意:不打印任何 rep.sections 以保持报告简短
# 使用现成的报告函数,劫持文件句柄以记录到专用文件
# 来源于 https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
# 注意:某些 pytest 插件可能会干扰默认的 `terminalreporter`(例如,pytest-instafail)
# 以行/短/长样式报告失败
config.option.tbstyle = "auto" # 完整的回溯
# 打开文件以写入长失败报告
with open(report_files["failures_long"], "w") as f:
# 创建终端写入器,并劫持到文件
tr._tw = create_terminal_writer(config, f)
# 写入失败摘要
tr.summary_failures()
# config.option.tbstyle = "short" # 短回溯
# 打开文件以写入短失败报告
with open(report_files["failures_short"], "w") as f:
# 创建终端写入器,并劫持到文件
tr._tw = create_terminal_writer(config, f)
# 写入简短失败摘要
summary_failures_short(tr)
# 将回溯样式设置为每个错误一行
config.option.tbstyle = "line" # 一行每个错误
# 打开文件以写入行失败报告
with open(report_files["failures_line"], "w") as f:
# 创建终端写入器,并劫持到文件
tr._tw = create_terminal_writer(config, f)
# 写入失败摘要
tr.summary_failures()
# 打开指定的错误报告文件以写入模式
with open(report_files["errors"], "w") as f:
# 创建终端写入器并将其赋值给 tr._tw
tr._tw = create_terminal_writer(config, f)
# 输出错误摘要
tr.summary_errors()
# 打开指定的警告报告文件以写入模式
with open(report_files["warnings"], "w") as f:
# 创建终端写入器并将其赋值给 tr._tw
tr._tw = create_terminal_writer(config, f)
# 输出常规警告摘要
tr.summary_warnings() # normal warnings
# 输出最终警告摘要
tr.summary_warnings() # final warnings
# 设置报告字符,用于模拟 -rA 选项(在 summary_passes() 和 short_test_summary() 中使用)
tr.reportchars = "wPpsxXEf"
# 打开指定的通过报告文件以写入模式
with open(report_files["passes"], "w") as f:
# 创建终端写入器并将其赋值给 tr._tw
tr._tw = create_terminal_writer(config, f)
# 输出通过的摘要
tr.summary_passes()
# 打开指定的短摘要报告文件以写入模式
with open(report_files["summary_short"], "w") as f:
# 创建终端写入器并将其赋值给 tr._tw
tr._tw = create_terminal_writer(config, f)
# 输出短测试摘要
tr.short_test_summary()
# 打开指定的统计报告文件以写入模式
with open(report_files["stats"], "w") as f:
# 创建终端写入器并将其赋值给 tr._tw
tr._tw = create_terminal_writer(config, f)
# 输出统计摘要
tr.summary_stats()
# 恢复原始设置:
# 恢复终端写入器为原始写入器
tr._tw = orig_writer
# 恢复报告字符为原始字符
tr.reportchars = orig_reportchars
# 恢复配置的跟踪样式为原始样式
config.option.tbstyle = orig_tbstyle
# 从 GitHub 复制的代码,装饰器用于处理不稳定的测试
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
装饰不稳定的测试,失败时会重试。
参数:
max_attempts (`int`, *可选*, 默认值为 5):
重新尝试不稳定测试的最大尝试次数。
wait_before_retry (`float`, *可选*):
如果提供,重试测试前将等待该秒数。
description (`str`, *可选*):
描述情况的字符串(什么/在哪里/为什么不稳定,链接到 GitHub 问题/PR 评论、错误等)。
"""
# 装饰器函数,接收待装饰的测试函数
def decorator(test_func_ref):
# 包装函数,处理重试逻辑
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1 # 初始化重试计数
# 在最大尝试次数内循环
while retry_count < max_attempts:
try:
# 尝试调用测试函数并返回结果
return test_func_ref(*args, **kwargs)
except Exception as err: # 捕获测试函数的异常
# 打印错误信息和当前重试次数
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
if wait_before_retry is not None: # 如果提供了等待时间
time.sleep(wait_before_retry) # 等待指定的秒数
retry_count += 1 # 增加重试计数
# 达到最大重试次数后再次尝试调用测试函数并返回结果
return test_func_ref(*args, **kwargs)
return wrapper # 返回包装后的函数
return decorator # 返回装饰器
# 从 GitHub 复制的代码,用于在子进程中运行测试
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
"""
在子进程中运行测试。这可以避免 (GPU) 内存问题。
参数:
test_case (`unittest.TestCase`):
将运行 `target_func` 的测试用例。
target_func (`Callable`):
实现实际测试逻辑的函数。
inputs (`dict`, *可选*, 默认值为 `None`):
将通过(输入)队列传递给 `target_func` 的输入。
timeout (`int`, *可选*, 默认值为 `None`):
将传递给输入和输出队列的超时(以秒为单位)。如果未指定,将检查环境变量 `PYTEST_TIMEOUT`。如果仍为 `None`,其值将设置为 `600`。
"""
if timeout is None:
# 获取超时设置,如果环境变量未设置则默认为 600 秒
timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
start_methohd = "spawn" # 定义子进程的启动方法为 "spawn"
ctx = multiprocessing.get_context(start_methohd) # 获取指定上下文的 multiprocessing
input_queue = ctx.Queue(1) # 创建一个输入队列,最大大小为 1
output_queue = ctx.JoinableQueue(1) # 创建一个可加入的输出队列,最大大小为 1
# 我们不能将 `unittest.TestCase` 发送到子进程,否则会出现关于 pickle 的问题。
input_queue.put(inputs, timeout=timeout) # 将输入放入队列,指定超时时间
# 创建并启动子进程,指定目标函数和参数
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
process.start() # 启动子进程
# 如果不能及时从子进程获取输出,则终止子进程:否则,悬挂的子进程会阻塞
# 测试以正确的方式退出
try:
# 从输出队列中获取结果,设置超时
results = output_queue.get(timeout=timeout)
# 标记任务为已完成
output_queue.task_done()
except Exception as e:
# 处理异常,终止进程
process.terminate()
# 记录测试失败的原因
test_case.fail(e)
# 等待进程结束,设置超时
process.join(timeout=timeout)
# 检查结果中的错误信息是否存在
if results["error"] is not None:
# 记录测试失败的具体错误信息
test_case.fail(f'{results["error"]}')
# 定义一个上下文管理器类,用于捕获日志流
class CaptureLogger:
"""
参数:
上下文管理器,用于捕获 `logging` 流
logger: 'logging` 日志对象
返回:
捕获的输出可以通过 `self.out` 获取
示例:
```python
>>> from diffusers import logging
>>> from diffusers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg + "\n"
```py
"""
# 初始化 CaptureLogger 类,接受一个日志对象
def __init__(self, logger):
self.logger = logger # 保存日志对象
self.io = StringIO() # 创建一个字符串流用于捕获日志
self.sh = logging.StreamHandler(self.io) # 创建流处理器以写入字符串流
self.out = "" # 初始化捕获的输出为空字符串
# 进入上下文时添加日志处理器
def __enter__(self):
self.logger.addHandler(self.sh) # 将处理器添加到日志对象
return self # 返回当前实例
# 退出上下文时移除日志处理器并获取输出
def __exit__(self, *exc):
self.logger.removeHandler(self.sh) # 移除处理器
self.out = self.io.getvalue() # 获取捕获的日志输出
# 返回捕获的日志字符串的表示形式
def __repr__(self):
return f"captured: {self.out}\n"
# 启用全确定性以保证分布式训练中的可重现性
def enable_full_determinism():
"""
帮助函数以保证分布式训练期间的可重现行为。请参见
- https://pytorch.org/docs/stable/notes/randomness.html 以了解 PyTorch
"""
# 启用 PyTorch 确定性模式。可能需要设置环境变量 'CUDA_LAUNCH_BLOCKING' 或 'CUBLAS_WORKSPACE_CONFIG'
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # 设置 CUDA 同步执行
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" # 配置 CUBLAS 工作区
torch.use_deterministic_algorithms(True) # 启用确定性算法
# 启用 CUDNN 确定性模式
torch.backends.cudnn.deterministic = True # 设置 CUDNN 为确定性
torch.backends.cudnn.benchmark = False # 禁用基准优化
torch.backends.cuda.matmul.allow_tf32 = False # 禁用 TF32 精度
# 禁用全确定性以恢复非确定性行为
def disable_full_determinism():
os.environ["CUDA_LAUNCH_BLOCKING"] = "0" # 关闭 CUDA 同步执行
os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" # 清除 CUBLAS 工作区配置
torch.use_deterministic_algorithms(False) # 禁用确定性算法
# 检查给定设备上是否支持 FP16
def _is_torch_fp16_available(device):
if not is_torch_available(): # 如果 PyTorch 不可用,返回 False
return False
import torch # 导入 PyTorch
device = torch.device(device) # 将设备字符串转换为设备对象
try:
x = torch.zeros((2, 2), dtype=torch.float16).to(device) # 创建 FP16 张量并移动到指定设备
_ = torch.mul(x, x) # 进行乘法运算以检查支持情况
return True # 如果没有异常,返回 True
except Exception as e: # 捕获异常
if device.type == "cuda": # 如果设备类型为 cuda
raise ValueError(
f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
) # 抛出错误,提示 cuda 安装问题
return False # 其他设备返回 False
# 检查给定设备上是否支持 FP64
def _is_torch_fp64_available(device):
if not is_torch_available(): # 如果 PyTorch 不可用,返回 False
return False
import torch # 导入 PyTorch
device = torch.device(device) # 将设备字符串转换为设备对象
try:
x = torch.zeros((2, 2), dtype=torch.float64).to(device) # 创建 FP64 张量并移动到指定设备
_ = torch.mul(x, x) # 进行乘法运算以检查支持情况
return True # 如果没有异常,返回 True
# 捕获异常并处理
except Exception as e:
# 检查设备类型是否为 'cuda'
if device.type == "cuda":
# 引发值错误,提示 'cuda' 应该支持 'fp64',但似乎未正确安装
raise ValueError(
f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
)
# 返回 False,表示操作失败
return False
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
if is_torch_available(): # 检查 PyTorch 是否可用
# Behaviour flags
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True} # 定义支持训练的后端标志
# Function definitions
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None} # 定义后端清空缓存的函数
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0} # 定义获取设备数量的函数
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} # 定义设置随机种子的函数
# This dispatches a defined function according to the accelerator from the function definitions.
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): # 根据设备和函数表调度相应的函数
if device not in dispatch_table: # 如果设备不在调度表中
return dispatch_table["default"](*args, **kwargs) # 返回默认函数的结果
fn = dispatch_table[device] # 获取对应设备的函数
# Some device agnostic functions return values. Need to guard against 'None' instead at
# user level
if fn is None: # 如果函数为 None
return None # 返回 None
return fn(*args, **kwargs) # 调用并返回函数的结果
# These are callables which automatically dispatch the function specific to the accelerator
def backend_manual_seed(device: str, seed: int): # 设置特定设备的随机种子
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) # 调用调度函数
def backend_empty_cache(device: str): # 清空特定设备的缓存
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) # 调用调度函数
def backend_device_count(device: str): # 获取特定设备的数量
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) # 调用调度函数
# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str): # 检查特定设备是否支持训练
if not is_torch_available(): # 如果 PyTorch 不可用
return False # 返回 False
if device not in BACKEND_SUPPORTS_TRAINING: # 如果设备不在支持训练的字典中
device = "default" # 使用默认设备
return BACKEND_SUPPORTS_TRAINING[device] # 返回该设备的支持训练标志
# Guard for when Torch is not available
if is_torch_available(): # 检查 PyTorch 是否可用
# Update device function dict mapping
def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): # 更新设备函数字典映射
try:
# Try to import the function directly
spec_fn = getattr(device_spec_module, attribute_name) # 尝试从模块中获取指定属性
device_fn_dict[torch_device] = spec_fn # 更新设备函数字典
except AttributeError as e: # 捕获属性错误
# If the function doesn't exist, and there is no default, throw an error
if "default" not in device_fn_dict: # 如果字典中没有默认函数
raise AttributeError( # 抛出属性错误
f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
) from e # 追踪原始错误
# 检查环境变量中是否存在特定的设备规格
if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
# 获取设备规格文件的路径
device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
# 检查指定的设备规格文件路径是否是一个文件
if not Path(device_spec_path).is_file():
# 如果文件不存在,抛出值错误并给出提示
raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
try:
# 从文件路径中提取模块名称(去掉.py后缀)
import_name = device_spec_path[: device_spec_path.index(".py")]
except ValueError as e:
# 如果路径中没有找到.py,抛出值错误
raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
# 动态导入设备规格模块
device_spec_module = importlib.import_module(import_name)
try:
# 从模块中获取设备名称
device_name = device_spec_module.DEVICE_NAME
except AttributeError:
# 如果模块中没有DEVICE_NAME,抛出属性错误
raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
# 检查环境变量中是否存在另一设备名称,并与当前设备名称进行比较
if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
# 如果不匹配,构建错误信息并抛出值错误
msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
raise ValueError(msg)
# 将torch_device设置为提取的设备名称
torch_device = device_name
# 对每个`BACKEND_*`字典添加一条条目
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")


浙公网安备 33010602011771号