diffusers 源码解析(四十一)
.\diffusers\pipelines\pag\__init__.py
# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING
# 从相对路径的 utils 模块中导入多个工具函数和常量
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 用于慢速导入的常量
OptionalDependencyNotAvailable, # 可选依赖不可用异常
_LazyModule, # 用于懒加载模块的类
get_objects_from_module, # 从模块中获取对象的函数
is_flax_available, # 检查 flax 库是否可用的函数
is_torch_available, # 检查 torch 库是否可用的函数
is_transformers_available, # 检查 transformers 库是否可用的函数
)
# 定义一个空字典用于存放虚拟对象
_dummy_objects = {}
# 定义一个空字典用于存放导入结构
_import_structure = {}
try:
# 检查 transformers 和 torch 库是否可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,抛出可选依赖不可用异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 如果捕获到可选依赖不可用异常,从 utils 中导入虚拟对象
from ...utils import dummy_torch_and_transformers_objects # noqa F403
# 更新 _dummy_objects 字典,获取虚拟对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
# 如果库可用,更新 _import_structure 字典,添加各个管道的导入信息
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
# 如果在类型检查或慢速导入模式下
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 检查 transformers 和 torch 库是否可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,抛出可选依赖不可用异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 捕获异常时,从 utils 中导入虚拟对象
from ...utils.dummy_torch_and_transformers_objects import *
else:
# 如果库可用,导入具体的管道实现
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_kolors import KolorsPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
else:
# 如果不在类型检查或慢速导入模式下,导入 sys 模块
import sys
# 使用 _LazyModule 创建一个懒加载的模块
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"], # 当前模块的文件路径
_import_structure, # 导入结构
module_spec=__spec__, # 模块规范
)
# 将 _dummy_objects 中的虚拟对象添加到当前模块
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
.\diffusers\pipelines\paint_by_example\image_encoder.py
# 版权声明,说明代码的版权所有者及使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 该文件根据 Apache License, Version 2.0("许可证")授权;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件在“按现状”基础上分发,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参见许可证以了解有关权限和限制的具体语言。
import torch # 导入 PyTorch 库
from torch import nn # 从 PyTorch 导入神经网络模块
from transformers import CLIPPreTrainedModel, CLIPVisionModel # 导入 CLIP 相关模型
from ...models.attention import BasicTransformerBlock # 导入基本变换器块
from ...utils import logging # 导入日志工具
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,禁用 Pylint 对无效名称的警告
class PaintByExampleImageEncoder(CLIPPreTrainedModel): # 定义图像编码器类,继承自预训练的 CLIP 模型
def __init__(self, config, proj_size=None): # 初始化方法,接收配置和可选的投影大小
super().__init__(config) # 调用父类初始化方法
self.proj_size = proj_size or getattr(config, "projection_dim", 768) # 设置投影大小,默认为 768
self.model = CLIPVisionModel(config) # 创建 CLIP 视觉模型实例
self.mapper = PaintByExampleMapper(config) # 创建映射器实例
self.final_layer_norm = nn.LayerNorm(config.hidden_size) # 创建层归一化实例
self.proj_out = nn.Linear(config.hidden_size, self.proj_size) # 创建线性变换,用于输出投影
# 用于缩放的无条件向量
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) # 初始化无条件向量为随机值
def forward(self, pixel_values, return_uncond_vector=False): # 前向传播方法,接收像素值和是否返回无条件向量的标志
clip_output = self.model(pixel_values=pixel_values) # 通过模型处理像素值
latent_states = clip_output.pooler_output # 获取池化输出作为潜在状态
latent_states = self.mapper(latent_states[:, None]) # 使用映射器处理潜在状态
latent_states = self.final_layer_norm(latent_states) # 对潜在状态进行层归一化
latent_states = self.proj_out(latent_states) # 通过线性层进行投影
if return_uncond_vector: # 如果需要返回无条件向量
return latent_states, self.uncond_vector # 返回潜在状态和无条件向量
return latent_states # 否则仅返回潜在状态
class PaintByExampleMapper(nn.Module): # 定义映射器类,继承自 PyTorch 的 nn.Module
def __init__(self, config): # 初始化方法,接收配置
super().__init__() # 调用父类初始化方法
num_layers = (config.num_hidden_layers + 1) // 5 # 计算层数,确保至少为 1
hid_size = config.hidden_size # 获取隐藏层大小
num_heads = 1 # 设置注意力头的数量为 1
self.blocks = nn.ModuleList( # 创建模块列表,包含多个变换器块
[
BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) # 添加变换器块
for _ in range(num_layers) # 根据层数创建多个块
]
)
def forward(self, hidden_states): # 前向传播方法,接收隐藏状态
for block in self.blocks: # 遍历所有变换器块
hidden_states = block(hidden_states) # 依次通过每个块处理隐藏状态
return hidden_states # 返回最终的隐藏状态
.\diffusers\pipelines\paint_by_example\pipeline_paint_by_example.py
# 版权声明,表明该文件的所有权和使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证,使用该文件的条件
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 可以在此获取许可证的副本
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则软件按“原样”分发,不提供任何形式的保证
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect # 导入inspect模块,用于获取对象的信息
from typing import Callable, List, Optional, Union # 导入类型提示功能
import numpy as np # 导入numpy库,用于数值计算
import PIL.Image # 导入PIL库,用于图像处理
import torch # 导入PyTorch库,用于深度学习
from transformers import CLIPImageProcessor # 导入CLIP图像处理器
from ...image_processor import VaeImageProcessor # 从上级模块导入VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel # 导入模型相关类
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler # 导入调度器
from ...utils import deprecate, logging # 导入工具函数和日志记录
from ...utils.torch_utils import randn_tensor # 从工具模块导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin # 导入扩散管道和混合类
from ..stable_diffusion import StableDiffusionPipelineOutput # 导入稳定扩散管道的输出类
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker # 导入安全检查器
from .image_encoder import PaintByExampleImageEncoder # 从当前模块导入图像编码器
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,便于调试
# pylint: disable=invalid-name # 禁用pylint关于名称的无效警告
# 从diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img复制的函数
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
# 检查encoder_output是否具有latent_dist属性且采样模式为'sample'
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
# 返回latent分布的样本
return encoder_output.latent_dist.sample(generator)
# 检查encoder_output是否具有latent_dist属性且采样模式为'argmax'
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
# 返回latent分布的众数
return encoder_output.latent_dist.mode()
# 检查encoder_output是否具有latents属性
elif hasattr(encoder_output, "latents"):
# 返回latents属性
return encoder_output.latents
# 如果都不满足,抛出异常
else:
raise AttributeError("Could not access latents of provided encoder_output")
# 准备图像和掩码以供“按示例绘制”管道使用
def prepare_mask_and_masked_image(image, mask):
"""
准备一对 (image, mask),使其可以被 Paint by Example 管道使用。
这意味着这些输入将转换为``torch.Tensor``,形状为``batch x channels x height x width``,
其中``channels``为``3``(对于``image``)和``1``(对于``mask``)。
``image`` 将转换为 ``torch.float32`` 并归一化为 ``[-1, 1]``。
``mask`` 将被二值化(``mask > 0.5``)并同样转换为 ``torch.float32``。
```
# 函数参数说明
Args:
# 输入图像,类型可以是 np.array、PIL.Image 或 torch.Tensor
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
# 描述图像的不同可能格式,包括 PIL.Image、np.array 或 torch.Tensor
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
# 掩码,用于指定需要修复的区域
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
# 描述掩码的不同可能格式,类似于图像
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
# 异常说明
Raises:
# 触发条件为 torch.Tensor 格式图像或掩码的数值范围不正确
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
# 类型错误,当图像和掩码类型不匹配时抛出
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
# 返回值说明
Returns:
# 返回一个包含掩码和修复图像的元组,均为 torch.Tensor 格式,具有 4 个维度
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
# 检查输入图像是否为 torch.Tensor 类型
if isinstance(image, torch.Tensor):
# 如果掩码不是 torch.Tensor,抛出类型错误
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# 如果图像为单个图像,将其转换为批处理格式
# Batch single image
if image.ndim == 3:
# 确保单个图像的形状为 (3, H, W)
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
# 在第一个维度添加批处理维度
image = image.unsqueeze(0)
# 如果掩码为二维,添加批处理和通道维度
# Batch and add channel dim for single mask
if mask.ndim == 2:
# 在前面添加两个维度
mask = mask.unsqueeze(0).unsqueeze(0)
# 如果掩码为三维,检查其与图像的批次匹配
# Batch single mask or add channel dim
if mask.ndim == 3:
# Batched mask
if mask.shape[0] == image.shape[0]:
# 如果掩码的批次与图像相同,添加通道维度
mask = mask.unsqueeze(1)
else:
# 否则,在前面添加批处理维度
mask = mask.unsqueeze(0)
# 确保图像和掩码都是四维
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
# 确保图像和掩码的空间维度相同
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
# 确保图像和掩码的批处理大小相同
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# 确保掩码只有一个通道
assert mask.shape[1] == 1, "Mask image must have a single channel"
# 检查图像的数值范围是否在 [-1, 1] 之间
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# 检查掩码的数值范围是否在 [0, 1] 之间
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# 对掩码进行反转,以便于修复
# paint-by-example inverses the mask
mask = 1 - mask
# 二值化掩码
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# 将图像转换为 float32 类型
# Image as float32
image = image.to(dtype=torch.float32)
# 如果掩码是 torch.Tensor 类型,但图像不是,则抛出类型错误
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# 如果输入的 image 是 PIL 图像对象,则将其转换为列表
if isinstance(image, PIL.Image.Image):
image = [image]
# 将每个图像转换为 RGB 格式,并拼接成一个数组,增加维度以适应后续处理
image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
# 将图像数组的维度顺序调整为 (批量, 通道, 高, 宽)
image = image.transpose(0, 3, 1, 2)
# 将 NumPy 数组转换为 PyTorch 张量并归一化到 [-1, 1] 范围
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# 处理 mask
# 如果输入的 mask 是 PIL 图像对象,则将其转换为列表
if isinstance(mask, PIL.Image.Image):
mask = [mask]
# 将每个掩膜图像转换为灰度格式,并拼接成一个数组,增加维度以适应后续处理
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
# 将掩膜数组转换为 float32 类型并归一化到 [0, 1] 范围
mask = mask.astype(np.float32) / 255.0
# paint-by-example 方法反转掩膜
mask = 1 - mask
# 将掩膜中低于 0.5 的值设置为 0,高于或等于 0.5 的值设置为 1
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# 将 NumPy 数组转换为 PyTorch 张量
mask = torch.from_numpy(mask)
# 将图像与掩膜相乘,得到被掩膜处理的图像
masked_image = image * mask
# 返回掩膜和被掩膜处理的图像
return mask, masked_image
# 定义一个名为 PaintByExamplePipeline 的类,继承自 DiffusionPipeline 和 StableDiffusionMixin
class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
r"""
# 警告提示,表示这是一个实验性特性
<Tip warning={true}>
🧪 This is an experimental feature!
</Tip>
# 使用 Stable Diffusion 进行图像引导的图像修补的管道。
# 该模型从 [`DiffusionPipeline`] 继承。检查超类文档以获取所有管道的通用方法
# (下载、保存、在特定设备上运行等)。
# 参数说明:
vae ([`AutoencoderKL`]):
用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
image_encoder ([`PaintByExampleImageEncoder`]):
编码示例输入图像。`unet` 是基于示例图像而非文本提示进行条件处理。
tokenizer ([`~transformers.CLIPTokenizer`]):
用于文本分词的 `CLIPTokenizer`。
unet ([`UNet2DConditionModel`]):
用于去噪编码图像潜在的 `UNet2DConditionModel`。
scheduler ([`SchedulerMixin`]):
与 `unet` 结合使用以去噪编码图像潜在的调度器,可以是
[`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]。
safety_checker ([`StableDiffusionSafetyChecker`]):
估计生成图像是否可能被视为冒犯或有害的分类模块。
请参考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以获取有关模型潜在危害的更多细节。
feature_extractor ([`~transformers.CLIPImageProcessor`]):
用于从生成图像中提取特征的 `CLIPImageProcessor`;用作 `safety_checker` 的输入。
"""
# TODO: 如果管道没有 feature_extractor,则需要在初始图像(如果为 PIL 格式)编码时给出描述性消息。
# 定义模型在 CPU 上卸载的顺序,指定 'unet' 在前,'vae' 在后
model_cpu_offload_seq = "unet->vae"
# 定义在 CPU 卸载时排除的组件,指定 'image_encoder' 不参与卸载
_exclude_from_cpu_offload = ["image_encoder"]
# 定义可选组件,指定 'safety_checker' 为可选
_optional_components = ["safety_checker"]
# 初始化方法,设置管道的主要组件
def __init__(
self,
vae: AutoencoderKL,
image_encoder: PaintByExampleImageEncoder,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = False,
):
# 调用父类的初始化方法
super().__init__()
# 注册各个模块,设置管道的组成部分
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# 计算 VAE 的缩放因子,基于 VAE 配置中的块输出通道数
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 初始化 VaeImageProcessor,使用计算出的缩放因子
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# 将是否需要安全检查器的信息注册到配置中
self.register_to_config(requires_safety_checker=requires_safety_checker)
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制而来
def run_safety_checker(self, image, device, dtype):
# 如果安全检查器不存在,将有害概念标记为 None
if self.safety_checker is None:
has_nsfw_concept = None
else:
# 如果输入图像是张量,使用图像处理器后处理为 PIL 格式
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果输入图像不是张量,将其转换为 PIL 格式
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 使用特征提取器处理图像并转移到指定设备
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 运行安全检查器,返回处理后的图像和有害概念的存在情况
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
# 返回处理后的图像和有害概念
return image, has_nsfw_concept
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制而来
def prepare_extra_step_kwargs(self, generator, eta):
# 准备调度器步骤的额外参数,因为并非所有调度器都有相同的签名
# eta(η)仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
# eta 对应于 DDIM 论文中的 η:https://arxiv.org/abs/2010.02502
# 应在 [0, 1] 之间
# 检查调度器是否接受 eta 参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
# 如果接受 eta,将其添加到额外参数中
extra_step_kwargs["eta"] = eta
# 检查调度器是否接受 generator 参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
# 如果接受 generator,将其添加到额外参数中
extra_step_kwargs["generator"] = generator
# 返回准备好的额外参数
return extra_step_kwargs
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 复制而来
def decode_latents(self, latents):
# 警告信息,提示 decode_latents 方法已弃用,将在 1.0.0 中移除
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
# 记录弃用警告
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
# 根据 VAE 的缩放因子调整潜在向量
latents = 1 / self.vae.config.scaling_factor * latents
# 解码潜在向量,返回图像
image = self.vae.decode(latents, return_dict=False)[0]
# 将图像缩放到 [0, 1] 范围内
image = (image / 2 + 0.5).clamp(0, 1)
# 将图像转换为 float32 格式以兼容 bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# 返回处理后的图像
return image
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs 复制而来
# 检查输入参数的有效性
def check_inputs(self, image, height, width, callback_steps):
# 检查 `image` 是否为有效类型,必须是 `torch.Tensor`、`PIL.Image.Image` 或列表
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
# 如果 `image` 类型不符合,抛出错误
raise ValueError(
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
# 检查 `height` 和 `width` 是否能被 8 整除
if height % 8 != 0 or width % 8 != 0:
# 如果不能整除,抛出错误
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 检查 `callback_steps` 是否有效,必须是正整数或 None
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
# 如果无效,抛出错误
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 从 StableDiffusionPipeline 复制的准备潜在变量的方法
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 定义潜在变量的形状
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 检查生成器列表的长度是否与批量大小匹配
if isinstance(generator, list) and len(generator) != batch_size:
# 如果不匹配,抛出错误
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果没有提供潜在变量,则生成随机潜在变量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 如果提供了潜在变量,将其转移到指定设备
latents = latents.to(device)
# 将初始噪声缩放到调度器所需的标准差
latents = latents * self.scheduler.init_noise_sigma
# 返回准备好的潜在变量
return latents
# 从 StableDiffusionInpaintPipeline 复制的准备掩膜潜在变量的方法
def prepare_mask_latents(
# 定义方法的输入参数,包括掩膜和其他信息
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# 将掩码调整为与潜变量形状相同,以便将掩码与潜变量拼接
# 这样做可以避免在使用 cpu_offload 和半精度时出现问题
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
) # 通过插值调整掩码的大小
mask = mask.to(device=device, dtype=dtype) # 将掩码移动到指定设备并转换数据类型
masked_image = masked_image.to(device=device, dtype=dtype) # 将掩码图像移动到指定设备并转换数据类型
if masked_image.shape[1] == 4: # 检查掩码图像是否为四通道
masked_image_latents = masked_image # 如果是,直接将其赋值给潜变量
else:
masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # 否则,使用 VAE 编码图像
# 针对每个提示重复掩码和掩码图像潜变量,使用适合 MPS 的方法
if mask.shape[0] < batch_size: # 检查掩码的数量是否少于批处理大小
if not batch_size % mask.shape[0] == 0: # 检查掩码数量是否可整除批处理大小
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
) # 如果不匹配,抛出值错误
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) # 重复掩码以匹配批处理大小
if masked_image_latents.shape[0] < batch_size: # 检查潜变量数量是否少于批处理大小
if not batch_size % masked_image_latents.shape[0] == 0: # 检查潜变量数量是否可整除批处理大小
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
) # 如果不匹配,抛出值错误
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) # 重复潜变量以匹配批处理大小
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask # 根据是否使用无分类器引导选择掩码
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
) # 根据是否使用无分类器引导选择潜变量
# 调整设备以防止与潜变量模型输入拼接时出现设备错误
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) # 将潜变量移动到指定设备并转换数据类型
return mask, masked_image_latents # 返回处理后的掩码和潜变量
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image 复制
# 定义一个编码变分自编码器图像的私有方法
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
# 检查生成器是否为列表类型
if isinstance(generator, list):
# 对每个图像编码并获取潜在表示,使用对应的生成器
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
# 将所有潜在表示在第0维上拼接成一个张量
image_latents = torch.cat(image_latents, dim=0)
else:
# 如果生成器不是列表,直接编码图像并获取潜在表示
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
# 将潜在表示乘以缩放因子
image_latents = self.vae.config.scaling_factor * image_latents
# 返回编码后的潜在表示
return image_latents
# 定义一个编码图像的私有方法
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
# 获取图像编码器参数的数据类型
dtype = next(self.image_encoder.parameters()).dtype
# 检查输入图像是否为张量,如果不是则提取特征
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
# 将图像移动到指定设备,并转换为正确的数据类型
image = image.to(device=device, dtype=dtype)
# 对图像进行编码,获取图像嵌入和负提示嵌入
image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)
# 复制图像嵌入以适应每个提示的生成数量
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
# 重塑嵌入张量的形状
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# 检查是否使用无分类器引导
if do_classifier_free_guidance:
# 复制负提示嵌入以匹配图像嵌入的数量
negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)
# 为无分类器引导执行两个前向传播,通过拼接无条件和文本嵌入来避免两个前向传播
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
# 返回编码后的图像嵌入
return image_embeddings
# 定义一个调用方法,禁用梯度计算以提高效率
@torch.no_grad()
def __call__(
# 接收示例图像和图像的参数,允许不同类型的输入
example_image: Union[torch.Tensor, PIL.Image.Image],
image: Union[torch.Tensor, PIL.Image.Image],
mask_image: Union[torch.Tensor, PIL.Image.Image],
# 可选参数定义图像的高度和宽度
height: Optional[int] = None,
width: Optional[int] = None,
# 定义推理步骤的数量和引导缩放比例
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
# 负提示的可选输入
negative_prompt: Optional[Union[str, List[str]]] = None,
# 每个提示生成的图像数量
num_images_per_prompt: Optional[int] = 1,
# 控制采样多样性的参数
eta: float = 0.0,
# 生成器的可选输入
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 可选的潜在变量输入
latents: Optional[torch.Tensor] = None,
# 输出类型的可选参数
output_type: Optional[str] = "pil",
# 是否返回字典格式的结果
return_dict: bool = True,
# 可选的回调函数用于处理中间结果
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
# 每隔多少步调用一次回调
callback_steps: int = 1,
.\diffusers\pipelines\paint_by_example\__init__.py
# 从 dataclasses 模块导入 dataclass 装饰器,用于简化数据类的定义
from dataclasses import dataclass
# 从 typing 模块导入类型检查所需的类型提示
from typing import TYPE_CHECKING, List, Optional, Union
# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 PIL 库,处理图像
import PIL
# 从 PIL 导入 Image 类,用于图像处理
from PIL import Image
# 从相对路径导入 utils 模块中的多个工具函数和常量
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 常量,指示是否需要慢速导入
OptionalDependencyNotAvailable, # 异常类,表示可选依赖不可用
_LazyModule, # 类,延迟加载模块的实现
get_objects_from_module, # 函数,从模块中获取对象
is_torch_available, # 函数,检查 PyTorch 是否可用
is_transformers_available, # 函数,检查 Transformers 是否可用
)
# 初始化一个空字典,存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典,用于存储模块导入结构
_import_structure = {}
# 尝试检查可选依赖项是否可用
try:
# 如果 Transformers 和 PyTorch 不可用,抛出异常
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从 utils 模块导入虚拟对象,避免错误使用
from ...utils import dummy_torch_and_transformers_objects # noqa F403
# 更新虚拟对象字典,添加从虚拟模块获取的对象
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果没有异常,更新导入结构字典
else:
_import_structure["image_encoder"] = ["PaintByExampleImageEncoder"] # 指定图像编码器模块
_import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"] # 指定图像处理管道模块
# 如果正在进行类型检查或需要慢速导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
# 尝试再次检查可选依赖项是否可用
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
# 从虚拟对象模块导入所有内容
from ...utils.dummy_torch_and_transformers_objects import *
# 如果没有异常,从实际模块导入所需的类
else:
from .image_encoder import PaintByExampleImageEncoder # 导入图像编码器
from .pipeline_paint_by_example import PaintByExamplePipeline # 导入图像处理管道
# 如果不是类型检查且不需要慢速导入
else:
import sys # 导入 sys 模块以访问 Python 运行时
# 使用 LazyModule 创建一个延迟加载的模块,减少启动时间
sys.modules[__name__] = _LazyModule(
__name__, # 当前模块名
globals()["__file__"], # 当前模块的文件路径
_import_structure, # 导入结构
module_spec=__spec__, # 模块规范
)
# 将虚拟对象添加到当前模块中
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value) # 设置属性
.\diffusers\pipelines\pia\pipeline_pia.py
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件
# 根据许可证分发是在“按现状”基础上提供的,
# 不附带任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和
# 限制的特定语言。
import inspect # 导入 inspect 模块,用于检查对象的类型和属性
from dataclasses import dataclass # 从 dataclasses 导入 dataclass 装饰器,用于简化类的定义
from typing import Any, Callable, Dict, List, Optional, Union # 导入类型注解,用于类型提示
import numpy as np # 导入 NumPy 库,通常用于数值计算和数组操作
import PIL # 导入 PIL 库,用于图像处理
import torch # 导入 PyTorch 库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection # 从 transformers 导入 CLIP 相关类
from ...image_processor import PipelineImageInput # 从本地模块导入 PipelineImageInput 类
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin # 从本地模块导入各种混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel # 从本地模块导入模型类
from ...models.lora import adjust_lora_scale_text_encoder # 从本地模块导入调整 Lora 缩放的函数
from ...models.unets.unet_motion_model import MotionAdapter # 从本地模块导入 MotionAdapter 类
from ...schedulers import ( # 从本地模块导入调度器类
DDIMScheduler, # 导入 DDIM 调度器
DPMSolverMultistepScheduler, # 导入 DPM 多步调度器
EulerAncestralDiscreteScheduler, # 导入 Euler 祖先离散调度器
EulerDiscreteScheduler, # 导入 Euler 离散调度器
LMSDiscreteScheduler, # 导入 LMS 离散调度器
PNDMScheduler, # 导入 PNDM 调度器
)
from ...utils import ( # 从本地模块导入各种工具函数和常量
USE_PEFT_BACKEND, # 导入标识是否使用 PEFT 后端的常量
BaseOutput, # 导入 BaseOutput 类,通常用于输出格式
logging, # 导入 logging 模块,用于记录日志
replace_example_docstring, # 导入替换示例文档字符串的函数
scale_lora_layers, # 导入缩放 Lora 层的函数
unscale_lora_layers, # 导入取消缩放 Lora 层的函数
)
from ...utils.torch_utils import randn_tensor # 从本地模块导入生成随机张量的函数
from ...video_processor import VideoProcessor # 从本地模块导入视频处理类
from ..free_init_utils import FreeInitMixin # 从上级模块导入 FreeInitMixin 类
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin # 从上级模块导入扩散管道和稳定扩散混合类
logger = logging.get_logger(__name__) # 创建一个记录器实例,用于记录当前模块的日志,禁用 pylint 对名称的警告
EXAMPLE_DOC_STRING = """ # 定义一个示例文档字符串的常量
``` # 该常量的开始部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量的结束部分
```py # 该常量的结束部分
``` # 该常量
# 示例代码的使用说明
Examples:
# 导入所需的库
```py
>>> import torch # 导入 PyTorch 库
>>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline # 从 diffusers 导入相关类
>>> from diffusers.utils import export_to_gif, load_image # 导入工具函数
# 从预训练模型加载运动适配器
>>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
# 从预训练模型创建 PIAPipeline 对象,并指定适配器和数据类型
>>> pipe = PIAPipeline.from_pretrained(
... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
... )
# 设置调度器为 EulerDiscreteScheduler,并使用现有配置
>>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# 从指定 URL 加载图像
>>> image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
... )
# 调整图像大小为 512x512 像素
>>> image = image.resize((512, 512))
# 定义正向提示内容
>>> prompt = "cat in a hat"
# 定义反向提示内容,以减少不良效果
>>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted"
# 创建一个随机数生成器并设置种子
>>> generator = torch.Generator("cpu").manual_seed(0)
# 生成输出图像,通过管道处理输入图像和提示
>>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
# 获取输出结果中的第一帧
>>> frames = output.frames[0]
# 将帧导出为 GIF 动画文件
>>> export_to_gif(frames, "pia-animation.gif")
"""
# 定义一个包含不同运动范围的列表,每个子列表代表不同类型的运动
RANGE_LIST = [
[1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 小运动
[1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # 中等运动
[1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # 大运动
[1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # 循环
[1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # 循环
[1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # 循环
[0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # 风格迁移候选小运动
[0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # 风格迁移中等运动
[0.5, 0.2], # 风格迁移大运动
]
# 定义一个函数,根据统计信息准备掩码系数
def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int):
# 确保视频帧数大于 0
assert num_frames > 0, "video_length should be greater than 0"
# 确保视频帧数大于条件帧
assert num_frames > cond_frame, "video_length should be greater than cond_frame"
# 将 RANGE_LIST 赋值给 range_list
range_list = RANGE_LIST
# 确保运动缩放类型在范围列表中可用
assert motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented"
# 根据运动缩放类型获取对应的系数
coef = range_list[motion_scale]
# 用最后一个系数填充至 num_frames 长度
coef = coef + ([coef[-1]] * (num_frames - len(coef)))
# 计算每帧与条件帧的距离
order = [abs(i - cond_frame) for i in range(num_frames)]
# 根据距离重新排列系数
coef = [coef[order[i]] for i in range(num_frames)]
# 返回重新排列后的系数
return coef
@dataclass
# 定义一个数据类,用于 PIAPipeline 的输出
class PIAPipelineOutput(BaseOutput):
r"""
PIAPipeline 的输出类。
参数:
frames (`torch.Tensor`, `np.ndarray`, 或 List[List[PIL.Image.Image]]):
长度为 `batch_size` 的嵌套列表,包含每个 `num_frames` 的去噪 PIL 图像序列,形状为
`(batch_size, num_frames, channels, height, width)` 的 NumPy 数组,或形状为
`(batch_size, num_frames, channels, height, width)` 的 Torch 张量。
"""
# 输出帧,可以是多种数据类型
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
# 定义一个用于文本到视频生成的管道类
class PIAPipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
FreeInitMixin,
):
r"""
文本到视频生成的管道。
此模型继承自 [`DiffusionPipeline`]. 请查看超类文档以了解所有管道实现的通用方法
(下载、保存、在特定设备上运行等)。
此管道还继承以下加载方法:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用于加载文本反转嵌入
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重
- [`~loaders.IPAdapterMixin.load_ip_adapter`] 用于加载 IP 适配器
# 函数参数说明
Args:
vae ([`AutoencoderKL`]):
# 变分自编码器 (VAE) 模型,用于将图像编码和解码为潜在表示
text_encoder ([`CLIPTextModel`]):
# 冻结的文本编码器,使用 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
tokenizer (`CLIPTokenizer`):
# [`~transformers.CLIPTokenizer`] 用于对文本进行标记化
unet ([`UNet2DConditionModel`]):
# [`UNet2DConditionModel`] 用于创建 UNetMotionModel,以去噪编码后的视频潜在特征
motion_adapter ([`MotionAdapter`]):
# [`MotionAdapter`] 与 `unet` 结合使用,以去噪编码后的视频潜在特征
scheduler ([`SchedulerMixin`]):
# 与 `unet` 结合使用的调度器,用于去噪编码后的图像潜在特征。可以是
# [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]
"""
# 定义模型 CPU 卸载顺序
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
# 定义可选组件列表
_optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
# 定义回调张量输入列表
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
# 初始化方法的参数列表
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetMotionModel],
scheduler: Union[
# 允许的调度器类型
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
motion_adapter: Optional[MotionAdapter] = None,
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
):
# 调用父类的初始化方法
super().__init__()
# 如果 unet 是 UNet2DConditionModel 的实例,则转换为 UNetMotionModel
if isinstance(unet, UNet2DConditionModel):
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
# 注册模块,初始化各个组件
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
# 计算 VAE 的缩放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 初始化视频处理器,设置是否调整大小和 VAE 缩放因子
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 中复制,num_images_per_prompt -> num_videos_per_prompt
# 定义一个编码提示的函数,接受多个参数
def encode_prompt(
self, # 类的实例
prompt, # 输入的提示文本
device, # 计算设备(如 CPU 或 GPU)
num_images_per_prompt, # 每个提示生成的图像数量
do_classifier_free_guidance, # 是否使用无分类器引导
negative_prompt=None, # 可选的负面提示文本
prompt_embeds: Optional[torch.Tensor] = None, # 可选的提示嵌入张量
negative_prompt_embeds: Optional[torch.Tensor] = None, # 可选的负面提示嵌入张量
lora_scale: Optional[float] = None, # 可选的 LoRA 缩放因子
clip_skip: Optional[int] = None, # 可选的剪辑跳过参数
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image 复制的函数
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): # 定义图像编码函数
dtype = next(self.image_encoder.parameters()).dtype # 获取图像编码器参数的数据类型
if not isinstance(image, torch.Tensor): # 如果输入的图像不是张量
image = self.feature_extractor(image, return_tensors="pt").pixel_values # 使用特征提取器将其转换为张量
image = image.to(device=device, dtype=dtype) # 将图像张量移动到指定设备并设置数据类型
if output_hidden_states: # 如果需要输出隐藏状态
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] # 编码图像并获取倒数第二个隐藏状态
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) # 根据生成图像数量重复隐藏状态
uncond_image_enc_hidden_states = self.image_encoder( # 编码全零图像以获取无条件隐藏状态
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2] # 获取无条件图像的倒数第二个隐藏状态
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( # 重复无条件隐藏状态
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states # 返回编码的图像隐藏状态和无条件隐藏状态
else: # 如果不需要输出隐藏状态
image_embeds = self.image_encoder(image).image_embeds # 编码图像以获取图像嵌入
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) # 根据生成图像数量重复图像嵌入
uncond_image_embeds = torch.zeros_like(image_embeds) # 创建与图像嵌入相同形状的全零无条件嵌入
return image_embeds, uncond_image_embeds # 返回图像嵌入和无条件嵌入
# 从 diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents 复制的函数
def decode_latents(self, latents): # 定义解码潜变量的函数
latents = 1 / self.vae.config.scaling_factor * latents # 按缩放因子调整潜变量
batch_size, channels, num_frames, height, width = latents.shape # 获取潜变量的形状信息
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) # 重新排列并重塑潜变量
image = self.vae.decode(latents).sample # 使用 VAE 解码潜变量以获取图像样本
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) # 处理图像以生成视频格式
# 我们总是将其转换为 float32,因为这不会造成显著开销,并且与 bfloat16 兼容
video = video.float() # 将视频数据转换为 float32 类型
return video # 返回解码后的视频数据
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的函数
# 定义准备额外参数的方法,供调度器步骤使用
def prepare_extra_step_kwargs(self, generator, eta):
# 为调度器步骤准备额外的关键字参数,因为不同调度器的参数签名可能不同
# eta(η)仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
# eta 对应于 DDIM 论文中的 η:https://arxiv.org/abs/2010.02502
# 并且应该在 [0, 1] 之间
# 检查调度器的步骤函数是否接受 eta 参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 初始化额外步骤参数字典
extra_step_kwargs = {}
# 如果接受 eta 参数,则将其添加到额外参数字典中
if accepts_eta:
extra_step_kwargs["eta"] = eta
# 检查调度器的步骤函数是否接受 generator 参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
# 如果接受 generator 参数,则将其添加到额外参数字典中
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 返回准备好的额外步骤参数字典
return extra_step_kwargs
# 定义检查输入参数的方法
def check_inputs(
self,
prompt, # 输入提示
height, # 输出图像的高度
width, # 输出图像的宽度
negative_prompt=None, # 可选的负面提示
prompt_embeds=None, # 可选的提示嵌入
negative_prompt_embeds=None, # 可选的负面提示嵌入
ip_adapter_image=None, # 可选的图像适配器输入图像
ip_adapter_image_embeds=None, # 可选的图像适配器输入嵌入
callback_on_step_end_tensor_inputs=None, # 可选的步骤结束回调输入
):
# 检查高度和宽度是否能被8整除
if height % 8 != 0 or width % 8 != 0:
# 抛出值错误,提示高度和宽度不符合要求
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 检查回调输入是否存在且不全在已注册的回调输入中
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
# 抛出值错误,提示未找到有效的回调输入
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# 检查是否同时传入了 prompt 和 prompt_embeds
if prompt is not None and prompt_embeds is not None:
# 抛出值错误,提示不能同时传入两者
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查是否同时未传入 prompt 和 prompt_embeds
elif prompt is None and prompt_embeds is None:
# 抛出值错误,提示至少需要提供一个
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查 prompt 的类型
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 抛出值错误,提示 prompt 的类型不正确
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查是否同时传入了 negative_prompt 和 negative_prompt_embeds
if negative_prompt is not None and negative_prompt_embeds is not None:
# 抛出值错误,提示不能同时传入两者
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查 prompt_embeds 和 negative_prompt_embeds 是否形状一致
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 抛出值错误,提示两者形状不一致
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 检查是否同时传入了 ip_adapter_image 和 ip_adapter_image_embeds
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
# 抛出值错误,提示不能同时传入两者
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
# 检查 ip_adapter_image_embeds 是否存在
if ip_adapter_image_embeds is not None:
# 检查类型是否为列表
if not isinstance(ip_adapter_image_embeds, list):
# 抛出值错误,提示类型不正确
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
# 检查列表中的第一个元素的维度
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
# 抛出值错误,提示维度不符合要求
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 复制的代码
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
# 初始化图像嵌入列表
image_embeds = []
# 如果启用分类器自由引导,初始化负图像嵌入列表
if do_classifier_free_guidance:
negative_image_embeds = []
# 如果没有提供图像嵌入
if ip_adapter_image_embeds is None:
# 确保输入的图像是列表格式
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
# 检查图像数量是否与适配器数量匹配
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
# 遍历每个图像和对应的投影层
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
# 确定是否输出隐藏状态
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
# 编码单个图像,获取嵌入
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
# 将嵌入添加到列表中
image_embeds.append(single_image_embeds[None, :])
# 如果启用分类器自由引导,添加负嵌入
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
# 如果提供了图像嵌入,遍历嵌入列表
for single_image_embeds in ip_adapter_image_embeds:
# 处理负图像嵌入(如果启用引导)
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
# 将嵌入添加到列表中
image_embeds.append(single_image_embeds)
# 初始化适配器图像嵌入列表
ip_adapter_image_embeds = []
# 遍历图像嵌入,为每个图像复制指定数量
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
# 如果启用分类器自由引导,处理负嵌入
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
# 将嵌入移动到指定设备
single_image_embeds = single_image_embeds.to(device=device)
# 添加到适配器图像嵌入列表
ip_adapter_image_embeds.append(single_image_embeds)
# 返回适配器图像嵌入
return ip_adapter_image_embeds
# 从 diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents 复制的代码
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# 定义输出张量的形状,包括批量大小、通道数、帧数以及缩放后的高度和宽度
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
# 检查生成器是否为列表且其长度与批量大小不匹配
if isinstance(generator, list) and len(generator) != batch_size:
# 抛出错误,提示生成器的长度与请求的有效批量大小不匹配
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果潜变量为 None,则生成随机张量作为潜变量
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 将给定的潜变量转移到指定设备
latents = latents.to(device)
# 将初始噪声按调度器要求的标准差进行缩放
latents = latents * self.scheduler.init_noise_sigma
# 返回处理后的潜变量
return latents
def prepare_masked_condition(
self,
image,
batch_size,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
motion_scale=0,
):
# 定义输出张量的形状,包括批量大小、通道数、帧数以及缩放后的高度和宽度
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
# 解包形状信息,获取缩放后的高度和宽度
_, _, _, scaled_height, scaled_width = shape
# 对输入图像进行预处理
image = self.video_processor.preprocess(image)
# 将预处理后的图像转移到指定设备并转换为指定数据类型
image = image.to(device, dtype)
# 如果生成器是列表,则逐个编码每个图像并采样
if isinstance(generator, list):
image_latent = [
self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
]
# 将所有潜变量张量按维度 0 连接起来
image_latent = torch.cat(image_latent, dim=0)
else:
# 否则直接编码图像并采样
image_latent = self.vae.encode(image).latent_dist.sample(generator)
# 将潜变量转移到指定设备并转换为指定数据类型
image_latent = image_latent.to(device=device, dtype=dtype)
# 对潜变量进行插值调整,改变大小到缩放后的高度和宽度
image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width])
# 复制潜变量并按配置的缩放因子进行缩放
image_latent_padding = image_latent.clone() * self.vae.config.scaling_factor
# 创建一个全零的掩码张量,大小与批量大小、帧数、缩放后的高度和宽度匹配
mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype)
# 根据统计信息准备掩码系数
mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale)
# 创建一个全零的张量用于存储被掩盖的图像,形状与 image_latent_padding 匹配
masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to(
device=device, dtype=self.unet.dtype
)
# 遍历每一帧,更新掩码和被掩盖的图像
for f in range(num_frames):
mask[:, :, f, :, :] = mask_coef[f]
masked_image[:, :, f, :, :] = image_latent_padding.clone()
# 根据条件决定是否复制掩码以支持分类器自由引导
mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
# 根据条件决定是否复制被掩盖的图像以支持分类器自由引导
masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image
# 返回掩码和被掩盖的图像
return mask, masked_image
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps 复制
# 定义获取时间步长的方法,接受推理步数、强度和设备作为参数
def get_timesteps(self, num_inference_steps, strength, device):
# 计算初始时间步,确保不超过总推理步数
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
# 计算开始时间步,确保不小于零
t_start = max(num_inference_steps - init_timestep, 0)
# 从调度器中获取相应的时间步长
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
# 如果调度器具有设置开始索引的方法,则调用该方法
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
# 返回计算得到的时间步和剩余推理步数
return timesteps, num_inference_steps - t_start
# 定义一个属性,用于获取引导比例
@property
def guidance_scale(self):
return self._guidance_scale
# 定义一个属性,用于获取剪切跳过参数
@property
def clip_skip(self):
return self._clip_skip
# 定义一个属性,判断是否使用无分类器引导
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
# 定义一个属性,用于获取交叉注意力参数
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
# 定义一个属性,用于获取时间步数
@property
def num_timesteps(self):
return self._num_timesteps
# 禁用梯度计算,装饰器用于优化性能
@torch.no_grad()
# 替换示例文档字符串
@replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义调用方法,接受多个输入参数
def __call__(
self,
image: PipelineImageInput,
prompt: Union[str, List[str]] = None,
strength: float = 1.0,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
motion_scale: int = 0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
.\diffusers\pipelines\pia\__init__.py
# 导入类型检查相关的常量
from typing import TYPE_CHECKING
# 从 utils 模块导入所需的常量和函数
from ...utils import (
DIFFUSERS_SLOW_IMPORT, # 慢导入的标志
OptionalDependencyNotAvailable, # 可选依赖不可用的异常
_LazyModule, # 懒加载模块的类
get_objects_from_module, # 从模块中获取对象的函数
is_torch_available, # 检查是否可用 PyTorch 的函数
is_transformers_available, # 检查是否可用 Transformers 的函数
)
# 初始化一个空字典用于存储假对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}
try:
# 检查 Transformers 和 PyTorch 是否可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,则抛出异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 如果捕获到可选依赖不可用的异常,则导入假对象
from ...utils import dummy_torch_and_transformers_objects
# 将假对象更新到 _dummy_objects 字典
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
# 如果依赖可用,则更新导入结构
_import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"]
# 检查类型检查或慢导入的标志
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
# 再次检查 Transformers 和 PyTorch 是否可用
if not (is_transformers_available() and is_torch_available()):
# 如果不可用,则抛出异常
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
# 如果捕获到可选依赖不可用的异常,则导入假对象
from ...utils.dummy_torch_and_transformers_objects import *
else:
# 如果依赖可用,则从 pipeline_pia 导入相关类
from .pipeline_pia import PIAPipeline, PIAPipelineOutput
else:
# 如果不是类型检查或慢导入,则使用懒加载模块
import sys
# 将当前模块替换为懒加载模块实例
sys.modules[__name__] = _LazyModule(
__name__, # 模块名称
globals()["__file__"], # 模块文件路径
_import_structure, # 导入结构
module_spec=__spec__, # 模块规格
)
# 将假对象设置到当前模块中
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
.\diffusers\pipelines\pipeline_flax_utils.py
# 指定编码为 UTF-8
# coding=utf-8
# 版权声明,表明文件归 HuggingFace Inc. 团队所有
# 版权声明,表明文件归 NVIDIA CORPORATION 所有
#
# 根据 Apache License, Version 2.0 许可使用本文件
# 只能在遵循许可的情况下使用本文件
# 可以在以下网址获取许可
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则软件以“现状”分发
# 不提供任何明示或暗示的保证或条件
# 详见许可证中关于权限和限制的具体条款
# 导入 importlib 模块,用于动态导入模块
import importlib
# 导入 inspect 模块,用于获取对象的内部信息
import inspect
# 导入 os 模块,用于与操作系统交互
import os
# 从 typing 模块导入各种类型提示
from typing import Any, Dict, List, Optional, Union
# 导入 flax 框架
import flax
# 导入 numpy 库,用于数值计算
import numpy as np
# 导入 PIL.Image,用于图像处理
import PIL.Image
# 从 flax.core.frozen_dict 导入 FrozenDict,用于不可变字典
from flax.core.frozen_dict import FrozenDict
# 从 huggingface_hub 导入创建仓库和下载快照的函数
from huggingface_hub import create_repo, snapshot_download
# 从 huggingface_hub.utils 导入参数验证函数
from huggingface_hub.utils import validate_hf_hub_args
# 从 PIL 导入 Image 用于图像处理
from PIL import Image
# 从 tqdm.auto 导入进度条显示
from tqdm.auto import tqdm
# 从上层模块导入 ConfigMixin 类
from ..configuration_utils import ConfigMixin
# 从上层模块导入与模型相关的常量和类
from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
# 从上层模块导入与调度器相关的常量和类
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
# 从上层模块导入一些工具函数和常量
from ..utils import (
CONFIG_NAME, # 配置文件名常量
BaseOutput, # 基础输出类
PushToHubMixin, # 用于推送到 Hugging Face Hub 的混合类
http_user_agent, # HTTP 用户代理字符串
is_transformers_available, # 检查 transformers 库是否可用的函数
logging, # 日志模块
)
# 检查 transformers 库是否可用,如果可用则导入 FlaxPreTrainedModel
if is_transformers_available():
from transformers import FlaxPreTrainedModel
# 定义加载模型的文件名常量
INDEX_FILE = "diffusion_flax_model.bin"
# 创建日志记录器
logger = logging.get_logger(__name__)
# 定义可加载的类及其方法
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
}
# 创建一个字典以存储所有可导入的类
ALL_IMPORTABLE_CLASSES = {}
# 遍历可加载的类并将其更新到 ALL_IMPORTABLE_CLASSES 字典中
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
# 定义导入 Flax 模型的函数
def import_flax_or_no_model(module, class_name):
try:
# 1. 首先确保如果存在 Flax 对象,则导入该对象
class_obj = getattr(module, "Flax" + class_name)
except AttributeError:
# 2. 如果失败,则说明没有模型,不附加 "Flax"
class_obj = getattr(module, class_name)
except AttributeError:
# 如果两者都不存在,抛出错误
raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
# 返回找到的类对象
return class_obj
# 定义 FlaxImagePipelineOutput 类,继承自 BaseOutput
@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
"""
图像管道的输出类。
# 定义函数参数的文档字符串
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
# 输入参数 images 是去噪后的 PIL 图像列表,长度为 `batch_size` 或形状为 `(batch_size, height, width, num_channels)` 的 NumPy 数组。
"""
# 声明 images 变量的类型,可以是 PIL 图像列表或 NumPy 数组
images: Union[List[PIL.Image.Image], np.ndarray]
# FlaxDiffusionPipeline类是Flax基础管道的基类
class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
r"""
Flax基础管道的基类。
[`FlaxDiffusionPipeline`]存储扩散管道的所有组件(模型、调度器和处理器),并提供加载、下载和保存模型的方法。它还包括以下方法:
- 启用/禁用去噪迭代的进度条
类属性:
- **config_name** ([`str`]) -- 存储扩散管道组件的类和模块名称的配置文件名。
"""
# 定义配置文件名
config_name = "model_index.json"
# 注册模块的方法,接收任意关键字参数
def register_modules(self, **kwargs):
# 为避免循环导入,在此处导入
from diffusers import pipelines
# 遍历传入的模块
for name, module in kwargs.items():
# 如果模块为None,注册字典为None值
if module is None:
register_dict = {name: (None, None)}
else:
# 获取模块的库名
library = module.__module__.split(".")[0]
# 检查模块是否为管道模块
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# 如果库不在LOADABLE_CLASSES中,或模块是管道模块,则将库名设为管道目录名
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# 获取类名
class_name = module.__class__.__name__
# 注册字典为库名和类名的元组
register_dict = {name: (library, class_name)}
# 保存模型索引配置
self.register_to_config(**register_dict)
# 将模块设置为当前对象的属性
setattr(self, name, module)
# 保存预训练模型的方法,接收目录路径和参数等
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict],
push_to_hub: bool = False,
**kwargs,
@classmethod
# 类方法装饰器,表示该方法是类级别的方法
@validate_hf_hub_args
@classmethod
# 获取对象初始化方法的签名参数
def _get_signature_keys(cls, obj):
# 获取对象初始化方法的参数
parameters = inspect.signature(obj.__init__).parameters
# 获取必需参数
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
# 获取可选参数
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
# 计算期望的模块名
expected_modules = set(required_parameters.keys()) - {"self"}
# 返回期望的模块名和可选参数
return expected_modules, optional_parameters
# 属性装饰器,表明该方法为属性
# 定义一个返回管道组件的字典的方法
def components(self) -> Dict[str, Any]:
r"""
`self.components` 属性对于使用相同权重和配置运行不同的管道非常有用,避免重新分配内存。
示例:
```py
>>> from diffusers import (
... FlaxStableDiffusionPipeline,
... FlaxStableDiffusionImg2ImgPipeline,
... )
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```py
返回:
包含初始化管道所需所有模块的字典。
"""
# 获取期望的模块和可选参数
expected_modules, optional_parameters = self._get_signature_keys(self)
# 创建一个字典,包含配置中的所有模块,但排除以“_”开头的和可选参数
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
# 检查组件的键是否与期望的模块一致
if set(components.keys()) != expected_modules:
# 如果不一致,抛出错误并显示期望和实际定义的模块
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
)
# 返回组件字典
return components
# 静态方法:将 NumPy 图像或图像批次转换为 PIL 图像
@staticmethod
def numpy_to_pil(images):
"""
将 NumPy 图像或图像批次转换为 PIL 图像。
"""
# 如果图像的维度为 3,增加一个新的维度
if images.ndim == 3:
images = images[None, ...]
# 将图像值缩放到 0-255 之间并转换为无符号整数类型
images = (images * 255).round().astype("uint8")
# 如果图像的最后一个维度是 1,表示灰度图像
if images.shape[-1] == 1:
# 特殊情况处理单通道灰度图像
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
# 对于其他类型图像,直接转换
pil_images = [Image.fromarray(image) for image in images]
# 返回 PIL 图像列表
return pil_images
# TODO: 使其兼容 jax.lax
# 定义一个进度条的方法,接受可迭代对象
def progress_bar(self, iterable):
# 如果没有进度条配置,则初始化为空字典
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
# 如果已有配置,检查其类型是否为字典
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
# 返回带有进度条的可迭代对象
return tqdm(iterable, **self._progress_bar_config)
# 设置进度条的配置参数
def set_progress_bar_config(self, **kwargs):
# 将配置参数赋值给进度条配置
self._progress_bar_config = kwargs
.\diffusers\pipelines\pipeline_loading_utils.py
# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,说明版权归 HuggingFace Inc. 团队所有
# Copyright 2024 The HuggingFace Inc. team.
#
# 根据 Apache 许可证第 2.0 版进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 除非遵守许可证,否则不得使用本文件
# you may not use this file except in compliance with the License.
# 可以通过以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则软件在"按原样"基础上分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,包括明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参阅许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入必要的库
import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
# 导入 PyTorch 库
import torch
# 从 huggingface_hub 导入模型信息函数
from huggingface_hub import model_info
# 导入 Hugging Face Hub 的参数验证工具
from huggingface_hub.utils import validate_hf_hub_args
# 导入版本管理工具
from packaging import version
# 导入当前模块的版本信息
from .. import __version__
# 从 utils 模块导入一些工具函数和常量
from ..utils import (
FLAX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
get_class_from_dynamic_module,
is_accelerate_available,
is_peft_available,
is_transformers_available,
logging,
)
# 从 torch_utils 模块导入编译模块检查函数
from ..utils.torch_utils import is_compiled_module
# 检查 transformers 库是否可用
if is_transformers_available():
# 导入 transformers 库
import transformers
# 从 transformers 中导入预训练模型基类
from transformers import PreTrainedModel
# 导入 transformers 中的权重名称常量
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
# 检查 accelerate 库是否可用
if is_accelerate_available():
# 导入 accelerate 库
import accelerate
# 从 accelerate 中导入模型调度函数
from accelerate import dispatch_model
# 导入用于从模块中移除钩子的工具
from accelerate.hooks import remove_hook_from_module
# 从 accelerate 中导入计算模块大小和获取最大内存的工具
from accelerate.utils import compute_module_sizes, get_max_memory
# 定义加载模型时使用的索引文件名
INDEX_FILE = "diffusion_pytorch_model.bin"
# 定义自定义管道文件名
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
# 定义虚拟模块的文件夹路径
DUMMY_MODULES_FOLDER = "diffusers.utils"
# 定义 transformers 虚拟模块的文件夹路径
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
# 定义连接管道的关键字列表
CONNECTED_PIPES_KEYS = ["prior"]
# 创建一个记录器实例,用于记录日志信息
logger = logging.get_logger(__name__)
# 定义可加载类的字典,映射库名到相应的类和方法
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
"onnxruntime.training": {
"ORTModule": ["save_pretrained", "from_pretrained"],
},
}
# 初始化一个空字典,用于存储所有可导入的类
ALL_IMPORTABLE_CLASSES = {}
# 遍历 LOADABLE_CLASSES 字典中的每个库
for library in LOADABLE_CLASSES:
# 将指定库中的可加载类更新到所有可导入的类集合中
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
# 检查文件名是否与 safetensors 兼容,返回布尔值
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
"""
检查 safetensors 兼容性:
- 默认情况下,所有模型使用默认 pytorch 序列化保存,因此我们使用默认 pytorch 文件列表来了解所需的 safetensors 文件。
- 仅当每个默认 pytorch 文件都有匹配的 safetensors 文件时,模型才与 safetensors 兼容。
将默认 pytorch 序列化文件名转换为 safetensors 序列化文件名:
- 对于来自 diffusers 库的模型,仅需将 ".bin" 扩展名替换为 ".safetensors"
- 对于来自 transformers 库的模型,文件名从 "pytorch_model" 更改为 "model",并将 ".bin" 扩展名替换为 ".safetensors"
"""
# 初始化一个空列表,用于存储默认 pytorch 文件名
pt_filenames = []
# 初始化一个空集合,用于存储 safetensors 文件名
sf_filenames = set()
# 如果未传递组件,则将其设置为空列表
passed_components = passed_components or []
# 遍历输入的文件名
for filename in filenames:
# 分离文件名和扩展名
_, extension = os.path.splitext(filename)
# 如果文件在传递的组件中,跳过处理
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
continue
# 如果扩展名为 .bin,则添加到 pytorch 文件名列表中
if extension == ".bin":
pt_filenames.append(os.path.normpath(filename))
# 如果扩展名为 .safetensors,则添加到 safetensors 文件名集合中
elif extension == ".safetensors":
sf_filenames.add(os.path.normpath(filename))
# 遍历所有默认 pytorch 文件名
for filename in pt_filenames:
# 拆分路径和文件名
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)
# 如果文件名以 "pytorch_model" 开头,则进行替换
if filename.startswith("pytorch_model"):
filename = filename.replace("pytorch_model", "model")
else:
filename = filename
# 构建预期的 safetensors 文件名
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
# 检查预期的 safetensors 文件名是否在集合中
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
# 如果所有检查通过,返回 True
return True
# 检查文件名是否与 variant 兼容,返回文件路径列表或字符串
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
# 定义一个权重文件名的列表
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
# 如果 transformers 可用,添加更多权重文件名
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# 从权重文件名中提取前缀
weight_prefixes = [w.split(".")[0] for w in weight_names]
# 从权重文件名中提取后缀
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# 定义 transformers 索引格式的正则表达式
transformers_index_format = r"\d{5}-of-\d{5}"
# 如果 variant 不为 None,表示需要处理变体文件
if variant is not None:
# 定义一个正则表达式,匹配带有变体的权重文件名
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
# 定义一个正则表达式,匹配带有变体的索引文件名
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
# 定义一个正则表达式,匹配不带变体的权重文件名
non_variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
# 定义一个正则表达式,匹配不带变体的索引文件名
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
# 如果 variant 不为 None,获取所有变体权重和索引文件名
if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
# 合并变体权重和索引文件名
variant_filenames = variant_weights | variant_indexes
else:
# 如果没有变体,则变体文件名集合为空
variant_filenames = set()
# 获取所有不带变体的权重文件名
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
# 获取所有不带变体的索引文件名
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
# 合并不带变体的权重和索引文件名
non_variant_filenames = non_variant_weights | non_variant_indexes
# 默认情况下使用所有变体文件名
usable_filenames = set(variant_filenames)
# 定义一个函数,将文件名转换为对应的变体文件名
def convert_to_variant(filename):
# 如果文件名中包含 'index',则替换为带变体的索引文件名
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
# 如果文件名符合特定格式,则转换为带变体的文件名
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
# 否则默认按变体格式修改文件名
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
# 返回变体文件名
return variant_filename
# 遍历所有不带变体的文件名
for f in non_variant_filenames:
# 转换为对应的变体文件名
variant_filename = convert_to_variant(f)
# 如果该变体文件名不在可用文件名集合中,则添加
if variant_filename not in usable_filenames:
usable_filenames.add(f)
# 返回可用文件名和变体文件名的集合
return usable_filenames, variant_filenames
# 装饰器,用于验证 Hugging Face Hub 参数
@validate_hf_hub_args
# 定义一个函数,用于发出关于过时模型变体的警告
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
# 获取模型信息,包括预训练模型的路径和其他参数
info = model_info(
pretrained_model_name_or_path,
token=token,
revision=None,
)
# 从模型信息中提取所有文件名
filenames = {sibling.rfilename for sibling in info.siblings}
# 获取与指定变体兼容的模型文件名
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
# 去掉文件名中的版本信息,生成新文件名列表
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
# 检查给定的模型文件名是否是兼容文件名的子集
if set(model_filenames).issubset(set(comp_model_filenames)):
# 发出关于通过修订加载模型变体的警告
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
# 发出关于不正确加载模型变体的警告,并请求用户报告缺失文件的问题
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
# 定义一个函数,用于解包模型
def _unwrap_model(model):
"""Unwraps a model."""
# 检查模型是否为编译模块,如果是则解包
if is_compiled_module(model):
model = model._orig_mod
# 检查 PEFT 是否可用
if is_peft_available():
from peft import PeftModel
# 如果模型是 PeftModel 类型,则解包至基础模型
if isinstance(model, PeftModel):
model = model.base_model.model
# 返回解包后的模型
return model
# 定义一个简单的帮助函数,用于在不正确模块时抛出或发出警告
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
"""Simple helper method to raise or warn in case incorrect module has been passed"""
# 如果当前模块不是管道模块
if not is_pipeline_module:
# 动态导入指定的库
library = importlib.import_module(library_name)
# 获取库中指定名称的类对象
class_obj = getattr(library, class_name)
# 遍历可导入类,构建文件名到类对象的字典
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
# 初始化期望的类对象为 None
expected_class_obj = None
# 遍历类候选者,检查是否与目标类兼容
for class_name, class_candidate in class_candidates.items():
# 如果候选类不为 None 且是目标类的子类
if class_candidate is not None and issubclass(class_obj, class_candidate):
# 将其设置为期望的类对象
expected_class_obj = class_candidate
# Dynamo 将原始模型包装在一个私有类中。
# 没有找到公共 API 获取原始类。
# 从传入的类对象中获取子模型
sub_model = passed_class_obj[name]
# 解包子模型,获取原始模型
unwrapped_sub_model = _unwrap_model(sub_model)
# 获取解包后模型的类
model_cls = unwrapped_sub_model.__class__
# 检查解包模型类是否是期望类的子类
if not issubclass(model_cls, expected_class_obj):
# 如果不是,抛出值错误,指明类型不匹配
raise ValueError(
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
)
else:
# 如果是管道模块,记录警告信息
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# 定义一个获取类对象和候选类的辅助方法
def get_class_obj_and_candidates(
# 传入库名、类名、可导入的类、管道、是否为管道模块、组件名及缓存目录
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""简单的辅助方法来检索模块的类对象以及潜在的父类对象"""
# 构建组件文件夹路径
component_folder = os.path.join(cache_dir, component_name)
# 如果是管道模块
if is_pipeline_module:
# 获取指定库名的管道模块
pipeline_module = getattr(pipelines, library_name)
# 获取指定类对象
class_obj = getattr(pipeline_module, class_name)
# 创建类候选字典,键为可导入类名,值为类对象
class_candidates = {c: class_obj for c in importable_classes.keys()}
# 如果组件文件存在
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# 从动态模块加载自定义组件
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
# 创建类候选字典
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# 否则从库中导入
library = importlib.import_module(library_name)
# 获取指定类对象
class_obj = getattr(library, class_name)
# 创建类候选字典
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
# 返回类对象及候选类字典
return class_obj, class_candidates
# 定义一个获取自定义管道类的辅助方法
def _get_custom_pipeline_class(
# 传入自定义管道及其相关参数
custom_pipeline,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
# 如果自定义管道是一个文件
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# 分解为文件夹和文件名
file_name = path.name
custom_pipeline = path.parent.absolute()
# 如果提供了仓库 ID
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
# 默认文件名
file_name = CUSTOM_PIPELINE_FILE_NAME
# 如果提供了仓库 ID 和修订版本
if repo_id is not None and hub_revision is not None:
# 从 Hub 加载管道代码时,确保覆盖修订版本
revision = hub_revision
# 返回从动态模块获取类
return get_class_from_dynamic_module(
custom_pipeline,
module_file=file_name,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
)
# 定义一个获取管道类的辅助方法
def _get_pipeline_class(
# 传入类对象及其他参数
class_obj,
config=None,
load_connected_pipeline=False,
custom_pipeline=None,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
# 如果提供了自定义管道
if custom_pipeline is not None:
# 调用获取自定义管道类的方法
return _get_custom_pipeline_class(
custom_pipeline,
repo_id=repo_id,
hub_revision=hub_revision,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
)
# 如果类对象不是 DiffusionPipeline
if class_obj.__name__ != "DiffusionPipeline":
# 直接返回类对象
return class_obj
# 导入 diffusers 模块
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
# 获取类名,如果未提供则从配置中获取
class_name = class_name or config["_class_name"]
# 如果类名不存在,抛出错误
if not class_name:
raise ValueError(
"在配置文件中找不到类名。请确保传入正确的 `class_name`。"
)
# 如果类名以 "Flax" 开头,则去掉前四个字符,否则保持原样
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
# 从 diffusers_module 动态获取指定名称的类
pipeline_cls = getattr(diffusers_module, class_name)
# 如果需要加载连接的管道
if load_connected_pipeline:
# 从 auto_pipeline 导入获取连接管道的函数
from .auto_pipeline import _get_connected_pipeline
# 获取与管道类相关联的连接管道类
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
# 如果找到了连接的管道类
if connected_pipeline_cls is not None:
# 记录加载的连接管道类的信息
logger.info(
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
)
else:
# 记录没有找到连接管道类的信息
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
# 使用找到的连接管道类或保持原管道类
pipeline_cls = connected_pipeline_cls or pipeline_cls
# 返回最终的管道类
return pipeline_cls
# 加载一个空模型的函数,接受多种参数
def _load_empty_model(
library_name: str, # 库的名称
class_name: str, # 类的名称
importable_classes: List[Any], # 可导入的类列表
pipelines: Any, # 管道相关信息
is_pipeline_module: bool, # 是否为管道模块的布尔值
name: str, # 模型名称
torch_dtype: Union[str, torch.dtype], # Torch 数据类型
cached_folder: Union[str, os.PathLike], # 缓存文件夹路径
**kwargs, # 其他额外参数
):
# 检索类对象
class_obj, _ = get_class_obj_and_candidates(
library_name, # 库的名称
class_name, # 类的名称
importable_classes, # 可导入的类列表
pipelines, # 管道相关信息
is_pipeline_module, # 是否为管道模块的布尔值
component_name=name, # 组件名称
cache_dir=cached_folder, # 缓存目录
)
# 检查是否可用 transformers 库
if is_transformers_available():
# 解析 transformers 的版本号
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
# 如果不可用,版本设置为 "N/A"
transformers_version = "N/A"
# 确定库的类型
is_transformers_model = (
is_transformers_available() # 检查 transformers 库是否可用
and issubclass(class_obj, PreTrainedModel) # 检查类是否为 PreTrainedModel 的子类
and transformers_version >= version.parse("4.20.0") # 检查版本号是否符合要求
)
# 导入 diffusers 模块
diffusers_module = importlib.import_module(__name__.split(".")[0])
# 检查类是否为 diffusers 模型的子类
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
model = None # 初始化模型为 None
config_path = cached_folder # 设置配置路径为缓存文件夹
# 设置用户代理信息
user_agent = {
"diffusers": __version__, # 当前 diffusers 的版本
"file_type": "model", # 文件类型为模型
"framework": "pytorch", # 框架为 PyTorch
}
# 如果是 diffusers 模型
if is_diffusers_model:
# 加载配置,然后在元信息上加载模型
config, unused_kwargs, commit_hash = class_obj.load_config(
os.path.join(config_path, name), # 配置文件路径
cache_dir=cached_folder, # 缓存目录
return_unused_kwargs=True, # 返回未使用的关键字参数
return_commit_hash=True, # 返回提交哈希值
force_download=kwargs.pop("force_download", False), # 强制下载标志
proxies=kwargs.pop("proxies", None), # 代理设置
local_files_only=kwargs.pop("local_files_only", False), # 仅使用本地文件标志
token=kwargs.pop("token", None), # 认证令牌
revision=kwargs.pop("revision", None), # 版本修订信息
subfolder=kwargs.pop("subfolder", None), # 子文件夹
user_agent=user_agent, # 用户代理信息
)
# 初始化空权重
with accelerate.init_empty_weights():
model = class_obj.from_config(config, **unused_kwargs) # 从配置中创建模型
# 如果是 transformers 模型
elif is_transformers_model:
config_class = getattr(class_obj, "config_class", None) # 获取配置类
# 如果配置类为空,抛出错误
if config_class is None:
raise ValueError("`config_class` cannot be None. Please double-check the model.")
# 从预训练配置加载模型配置
config = config_class.from_pretrained(
cached_folder, # 缓存文件夹
subfolder=name, # 子文件夹
force_download=kwargs.pop("force_download", False), # 强制下载标志
proxies=kwargs.pop("proxies", None), # 代理设置
local_files_only=kwargs.pop("local_files_only", False), # 仅使用本地文件标志
token=kwargs.pop("token", None), # 认证令牌
revision=kwargs.pop("revision", None), # 版本修订信息
user_agent=user_agent, # 用户代理信息
)
# 初始化空权重
with accelerate.init_empty_weights():
model = class_obj(config) # 使用配置创建模型
# 如果模型已创建
if model is not None:
model = model.to(dtype=torch_dtype) # 将模型转换为指定的数据类型
return model # 返回加载的模型
# 定义一个包含模块大小的字典,键为模块名称,值为模块大小(以浮点数表示)
module_sizes: Dict[str, float],
# 定义一个包含设备内存的字典,键为设备名称,值为设备内存(以浮点数表示)
device_memory: Dict[str, float],
# 定义设备映射策略的字符串,默认值为 "balanced"
device_mapping_strategy: str = "balanced"
):
# 获取设备内存字典的所有设备 ID,并转换为列表
device_ids = list(device_memory.keys())
# 创建一个设备循环列表,包含设备 ID 的正序和反序
device_cycle = device_ids + device_ids[::-1]
# 复制设备内存字典,以避免修改原始字典
device_memory = device_memory.copy()
# 初始化设备 ID 和组件的映射字典
device_id_component_mapping = {}
# 当前设备索引,初始化为 0
current_device_index = 0
# 遍历模块大小字典
for component in module_sizes:
# 根据当前索引获取对应的设备 ID
device_id = device_cycle[current_device_index % len(device_cycle)]
# 获取当前组件所需的内存大小
component_memory = module_sizes[component]
# 获取当前设备的可用内存
curr_device_memory = device_memory[device_id]
# 如果 GPU 的内存不足以容纳当前组件,则将其转移到 CPU
if component_memory > curr_device_memory:
device_id_component_mapping["cpu"] = [component]
else:
# 如果设备 ID 不在映射中,则初始化该设备的组件列表
if device_id not in device_id_component_mapping:
device_id_component_mapping[device_id] = [component]
else:
# 如果设备 ID 已存在,则将组件添加到该设备的组件列表
device_id_component_mapping[device_id].append(component)
# 更新设备的剩余内存
device_memory[device_id] -= component_memory
# 移动到下一个设备索引
current_device_index += 1
# 返回设备 ID 到组件的映射字典
return device_id_component_mapping
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
# 为了避免循环导入问题,导入管道模块
from diffusers import pipelines
# 从关键字参数中获取 torch 数据类型,默认值为 torch.float32
torch_dtype = kwargs.get("torch_dtype", torch.float32)
# 在一个元设备上加载管道中的每个模块,以便推导出设备映射
init_empty_modules = {}
# 遍历初始化字典中的每个名称及其对应的库名称和类名称
for name, (library_name, class_name) in init_dict.items():
# 如果类名以 "Flax" 开头,抛出不支持的错误
if class_name.startswith("Flax"):
raise ValueError("Flax pipelines are not supported with `device_map`.")
# 定义所有可导入的类
is_pipeline_module = hasattr(pipelines, library_name) # 检查 pipelines 是否有对应的库名称
importable_classes = ALL_IMPORTABLE_CLASSES # 获取所有可导入类的集合
loaded_sub_model = None # 初始化已加载的子模型为 None
# 使用传入的子模型或从库名称加载类名称
if name in passed_class_obj: # 如果传入的类对象中有该名称
# 如果模型在管道模块中,则从管道加载模型
# 检查传入的类对象是否具有正确的父类
maybe_raise_or_warn(
library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
)
with accelerate.init_empty_weights(): # 在初始化空权重的上下文中
loaded_sub_model = passed_class_obj[name] # 从传入的对象中加载模型
else:
# 加载空模型
loaded_sub_model = _load_empty_model(
library_name=library_name, # 库名称
class_name=class_name, # 类名称
importable_classes=importable_classes, # 可导入类集合
pipelines=pipelines, # 管道
is_pipeline_module=is_pipeline_module, # 是否为管道模块
pipeline_class=pipeline_class, # 管道类
name=name, # 名称
torch_dtype=torch_dtype, # torch 数据类型
cached_folder=kwargs.get("cached_folder", None), # 缓存文件夹
force_download=kwargs.get("force_download", None), # 强制下载标志
proxies=kwargs.get("proxies", None), # 代理设置
local_files_only=kwargs.get("local_files_only", None), # 仅限本地文件标志
token=kwargs.get("token", None), # 访问令牌
revision=kwargs.get("revision", None), # 版本修订号
)
# 如果已加载子模型不为 None,将其添加到初始化空模块中
if loaded_sub_model is not None:
init_empty_modules[name] = loaded_sub_model
# 确定设备映射
# 获取一个按大小排序的字典,用于映射模型级组件
# 到其大小。
module_sizes = {
module_name: compute_module_sizes(module, dtype=torch_dtype)[""] # 计算每个模块的大小
for module_name, module in init_empty_modules.items() # 遍历初始化空模块中的每个模块
if isinstance(module, torch.nn.Module) # 仅考虑 PyTorch 模块
}
# 对模块大小字典进行排序,按值降序排列
module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
# 获取每个设备(仅限 GPU)的最大可用内存
max_memory = get_max_memory(max_memory) # 获取最大内存信息
# 对最大内存字典进行排序,按值降序排列
max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
# 从最大内存字典中移除 CPU 条目
max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
# 获取一个字典,用于将模型级组件映射到基于最大内存和模型大小的可用设备
final_device_map = None # 初始化最终设备映射为 None
# 检查最大内存的长度是否大于 0
if len(max_memory) > 0:
# 分配组件到设备,并返回设备与组件的映射
device_id_component_mapping = _assign_components_to_devices(
module_sizes, max_memory, device_mapping_strategy=device_map
)
# 初始化最终设备映射字典
final_device_map = {}
# 遍历设备 ID 和其对应的组件
for device_id, components in device_id_component_mapping.items():
# 遍历每个组件,将其映射到设备 ID
for component in components:
final_device_map[component] = device_id
# 返回最终的设备映射字典
return final_device_map
# 定义一个加载子模型的辅助方法,接受多个参数来配置模型加载
def load_sub_model(
# 模型所在库的名称
library_name: str,
# 模型类的名称
class_name: str,
# 可导入类的列表
importable_classes: List[Any],
# 管道相关参数
pipelines: Any,
# 是否为管道模块的标志
is_pipeline_module: bool,
# 管道类
pipeline_class: Any,
# 指定的 torch 数据类型
torch_dtype: torch.dtype,
# 提供者参数
provider: Any,
# 会话选项
sess_options: Any,
# 设备映射,可能是字典或字符串
device_map: Optional[Union[Dict[str, torch.device], str]],
# 最大内存使用配置
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
# 离线文件夹路径
offload_folder: Optional[Union[str, os.PathLike]],
# 是否离线保存状态字典
offload_state_dict: bool,
# 模型变体的字典
model_variants: Dict[str, str],
# 模型名称
name: str,
# 是否从 Flax 框架加载
from_flax: bool,
# 模型变体名称
variant: str,
# 是否使用低 CPU 内存使用模式
low_cpu_mem_usage: bool,
# 缓存文件夹路径
cached_folder: Union[str, os.PathLike],
):
"""从指定库和类名加载模块 `name` 的辅助方法"""
# 获取类对象及候选类列表
class_obj, class_candidates = get_class_obj_and_candidates(
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
)
load_method_name = None
# 获取加载方法名称
for class_name, class_candidate in class_candidates.items():
# 如果候选类不为 None,且 class_obj 是其子类
if class_candidate is not None and issubclass(class_obj, class_candidate):
# 从可导入类中获取加载方法名称
load_method_name = importable_classes[class_name][1]
# 如果加载方法名称为 None,说明是一个虚拟模块 -> 抛出错误
if load_method_name is None:
none_module = class_obj.__module__
# 检查模块路径是否属于虚拟模块
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
# 如果是虚拟模块,调用 class_obj 以获得友好的错误信息
if is_dummy_path and "dummy" in none_module:
class_obj()
# 抛出值错误,说明没有定义任何加载方法
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
# 获取加载方法
load_method = getattr(class_obj, load_method_name)
# 为加载方法添加关键字参数
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
# 如果类是 PyTorch 模块,则添加 torch_dtype 参数
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
# 如果类是 OnnxRuntimeModel,则添加 provider 和 sess_options 参数
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
# 检查类是否为 diffusers 模型
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
# 检查 transformers 是否可用
if is_transformers_available():
# 获取 transformers 的版本
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
# 如果不可用,则版本标记为 "N/A"
transformers_version = "N/A"
# 检查类是否为 transformers 模型,并且版本满足要求
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
# 加载 transformers 模型时,如果 device_map 为 None,则权重将被初始化,而不是 diffusers.
# 为了加快默认加载速度,设置 `low_cpu_mem_usage=low_cpu_mem_usage` 标志,默认为 `True`。
# 确保权重不会被初始化,从而显著加快加载速度。
if is_diffusers_model or is_transformers_model:
# 将设备映射添加到加载参数中
loading_kwargs["device_map"] = device_map
# 设置最大内存使用限制
loading_kwargs["max_memory"] = max_memory
# 设置用于卸载的文件夹
loading_kwargs["offload_folder"] = offload_folder
# 设置卸载状态字典的标志
loading_kwargs["offload_state_dict"] = offload_state_dict
# 从模型变体中获取对应的变体
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax:
# 如果模型来自 Flax,设置标志
loading_kwargs["from_flax"] = True
# 以下内容可以在 `transformers` 的最低版本高于 4.27 时删除
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
# 如果版本不符合要求,抛出导入错误
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
# 如果变体为空,从加载参数中移除变体
loading_kwargs.pop("variant")
# 如果来自 Flax 并且模型是变换器模型,无法使用 `low_cpu_mem_usage` 加载
if not (from_flax and is_transformers_model):
# 设置低 CPU 内存使用标志
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
else:
# 否则将标志设置为 False
loading_kwargs["low_cpu_mem_usage"] = False
# 检查模块是否在子目录中
if os.path.isdir(os.path.join(cached_folder, name)):
# 从子目录加载模型
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# 否则从根目录加载模型
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
# 移除模型中的钩子
remove_hook_from_module(loaded_sub_model, recurse=True)
# 检查是否需要将模型卸载到 CPU
needs_offloading_to_cpu = device_map[""] == "cpu"
if needs_offloading_to_cpu:
# 如果需要卸载,将模型分发到 CPU
dispatch_model(
loaded_sub_model,
state_dict=loaded_sub_model.state_dict(),
device_map=device_map,
force_hooks=True,
main_device=0,
)
else:
# 否则正常分发模型
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
# 返回加载的子模型
return loaded_sub_model
# 根据传入模块获取类库名称和类名的元组
def _fetch_class_library_tuple(module):
# 在这里导入模块以避免循环导入问题
diffusers_module = importlib.import_module(__name__.split(".")[0])
# 获取 diffusers_module 中的 pipelines 属性
pipelines = getattr(diffusers_module, "pipelines")
# 从原始模块注册配置,而不是从动态编译的模块
not_compiled_module = _unwrap_model(module)
# 获取模块的库名称
library = not_compiled_module.__module__.split(".")[0]
# 检查该模块是否为管道模块
module_path_items = not_compiled_module.__module__.split(".")
# 获取模块路径倒数第二项,作为管道目录
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
# 获取模块路径,并检查是否为管道模块
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# 如果库名称不在 LOADABLE_CLASSES 中,则认为它是自定义模块
# 或者如果是管道模块,则库名称设为模块名称
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# 获取类名
class_name = not_compiled_module.__class__.__name__
# 返回库名称和类名的元组
return (library, class_name)