diffusers-源码解析-四十五-
diffusers 源码解析(四十五)
.\diffusers\pipelines\stable_diffusion\pipeline_flax_stable_diffusion_img2img.py
# 版权声明,表明此文件的版权归 HuggingFace 团队所有
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行授权,用户必须遵守该许可证使用此文件
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 用户可以在以下链接获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用的法律要求或书面同意,否则此文件按“原样”提供,没有任何明示或暗示的担保
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言管辖权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings # 导入警告模块,用于发出警告消息
from functools import partial # 从 functools 导入 partial 函数,用于部分应用
from typing import Dict, List, Optional, Union # 导入类型提示相关的模块
import jax # 导入 JAX 库,用于加速数值计算
import jax.numpy as jnp # 导入 JAX 的 NumPy 模块,作为 jnp 使用
import numpy as np # 导入 NumPy 库,作为 np 使用
from flax.core.frozen_dict import FrozenDict # 从 flax 导入 FrozenDict,用于创建不可变字典
from flax.jax_utils import unreplicate # 从 flax 导入 unreplicate 函数,用于去除 JAX 复制
from flax.training.common_utils import shard # 从 flax 导入 shard 函数,用于数据切分
from PIL import Image # 从 PIL 导入 Image 模块,用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel # 导入 transformers 中的处理器和模型
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel # 导入模型
from ...schedulers import ( # 从调度器模块导入各种调度器
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring # 导入工具函数和日志模块
from ..pipeline_flax_utils import FlaxDiffusionPipeline # 从 pipeline_flax_utils 导入 FlaxDiffusionPipeline
from .pipeline_output import FlaxStableDiffusionPipelineOutput # 从 pipeline_output 导入 FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker # 从 safety_checker_flax 导入安全检查器
logger = logging.get_logger(__name__) # 创建一个日志记录器,使用当前模块的名称
# 设置为 True 时使用 Python 的 for 循环,而非 jax.fori_loop,以便于调试
DEBUG = False
# 示例文档字符串的模板
EXAMPLE_DOC_STRING = """
# 示例代码块,用于演示如何使用库和函数
Examples:
```py
# 导入 JAX 库
>>> import jax
# 导入 NumPy 库
>>> import numpy as np
# 导入 JAX 的 NumPy 实现
>>> import jax.numpy as jnp
# 从 flax.jax_utils 导入复制函数
>>> from flax.jax_utils import replicate
# 从 flax.training.common_utils 导入分片函数
>>> from flax.training.common_utils import shard
# 导入 requests 库用于发送 HTTP 请求
>>> import requests
# 从 io 模块导入 BytesIO 用于处理字节流
>>> from io import BytesIO
# 从 PIL 库导入 Image 类用于图像处理
>>> from PIL import Image
# 从 diffusers 导入 FlaxStableDiffusionImg2ImgPipeline 类
>>> from diffusers import FlaxStableDiffusionImg2ImgPipeline
# 定义一个创建随机数种子的函数
>>> def create_key(seed=0):
... # 返回一个基于给定种子的 JAX 随机数生成器密钥
... return jax.random.PRNGKey(seed)
# 使用种子 0 创建随机数生成器密钥
>>> rng = create_key(0)
# 定义要下载的图像 URL
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
# 发送 GET 请求以获取图像
>>> response = requests.get(url)
# 从响应内容中读取图像,并转换为 RGB 模式
>>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
# 调整图像大小为 768x512 像素
>>> init_img = init_img.resize((768, 512))
# 定义提示词
>>> prompts = "A fantasy landscape, trending on artstation"
# 从预训练模型中加载图像到图像生成管道
>>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", # 模型名称
... revision="flax", # 版本标识
... dtype=jnp.bfloat16, # 数据类型
... )
# 获取设备数量以生成样本
>>> num_samples = jax.device_count()
# 根据设备数量拆分随机数生成器的密钥
>>> rng = jax.random.split(rng, jax.device_count())
# 准备输入的提示词和图像,复制 num_samples 次
>>> prompt_ids, processed_image = pipeline.prepare_inputs(
... prompt=[prompts] * num_samples, # 创建提示词列表
... image=[init_img] * num_samples # 创建图像列表
... )
# 复制参数以便在多个设备上使用
>>> p_params = replicate(params)
# 将提示词 ID 分片以适应设备
>>> prompt_ids = shard(prompt_ids)
# 将处理后的图像分片以适应设备
>>> processed_image = shard(processed_image)
# 调用管道生成图像
>>> output = pipeline(
... prompt_ids=prompt_ids, # 提示词 ID
... image=processed_image, # 处理后的图像
... params=p_params, # 复制的参数
... prng_seed=rng, # 随机数种子
... strength=0.75, # 强度参数
... num_inference_steps=50, # 推理步骤数
... jit=True, # 启用 JIT 编译
... height=512, # 输出图像高度
... width=768, # 输出图像宽度
... ).images # 获取生成的图像
# 将输出的图像转换为 PIL 格式以便展示
>>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```py
# 定义一个基于 Flax 的文本引导图像生成管道类,用于图像到图像的生成
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
r"""
基于 Flax 的管道,用于使用 Stable Diffusion 进行文本引导的图像到图像生成。
该模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道的通用方法的文档(下载、保存、在特定设备上运行等),请查看超类文档。
参数:
vae ([`FlaxAutoencoderKL`]):
用于对图像进行编码和解码的变分自编码器(VAE)模型,将图像转换为潜在表示。
text_encoder ([`~transformers.FlaxCLIPTextModel`]):
冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
tokenizer ([`~transformers.CLIPTokenizer`]):
用于对文本进行标记化的 `CLIPTokenizer`。
unet ([`FlaxUNet2DConditionModel`]):
用于对编码图像潜在空间进行去噪的 `FlaxUNet2DConditionModel`。
scheduler ([`SchedulerMixin`]):
与 `unet` 结合使用的调度器,用于去噪编码的图像潜在空间。可以是
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
[`FlaxDPMSolverMultistepScheduler`] 之一。
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
分类模块,用于评估生成的图像是否可能被认为是冒犯性或有害的。
有关模型潜在危害的更多详细信息,请参阅 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5)。
feature_extractor ([`~transformers.CLIPImageProcessor`]):
用于从生成图像中提取特征的 `CLIPImageProcessor`;作为输入用于 `safety_checker`。
"""
# 初始化方法,设置管道的各个组件
def __init__(
self,
# 变分自编码器模型
vae: FlaxAutoencoderKL,
# 文本编码器模型
text_encoder: FlaxCLIPTextModel,
# 文本标记器
tokenizer: CLIPTokenizer,
# 去噪模型
unet: FlaxUNet2DConditionModel,
# 调度器,用于去噪处理
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
# 安全检查模块
safety_checker: FlaxStableDiffusionSafetyChecker,
# 特征提取器
feature_extractor: CLIPImageProcessor,
# 数据类型,默认为 float32
dtype: jnp.dtype = jnp.float32,
):
# 调用父类的构造函数
super().__init__()
# 设置数据类型属性
self.dtype = dtype
# 检查安全检查器是否为 None
if safety_checker is None:
# 记录警告,提醒用户禁用了安全检查器
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# 注册模块,将各个组件进行初始化
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# 计算 VAE 的缩放因子,基于其配置的输出通道数
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 准备输入,接受文本提示和图像
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
# 检查 prompt 类型是否为字符串或列表
if not isinstance(prompt, (str, list)):
# 如果不符合类型,抛出错误
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查 image 类型是否为 PIL 图像或列表
if not isinstance(image, (Image.Image, list)):
# 如果不符合类型,抛出错误
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
# 如果 image 是单个图像,则转换为列表
if isinstance(image, Image.Image):
image = [image]
# 预处理图像,并将它们拼接为一个数组
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
# 将文本提示编码为模型输入格式
text_input = self.tokenizer(
prompt,
padding="max_length", # 填充到最大长度
max_length=self.tokenizer.model_max_length, # 最大长度设置
truncation=True, # 超出最大长度时截断
return_tensors="np", # 返回 NumPy 格式的张量
)
# 返回文本输入 ID 和处理后的图像
return text_input.input_ids, processed_images
# 获取是否包含不适宜内容的概念
def _get_has_nsfw_concepts(self, features, params):
# 使用安全检查器检查特征是否包含不适宜内容
has_nsfw_concepts = self.safety_checker(features, params)
# 返回检查结果
return has_nsfw_concepts
# 定义一个安全检查器的运行方法,处理输入的图像
def _run_safety_checker(self, images, safety_model_params, jit=False):
# 当 jit 为 True 时,安全模型参数应已被复制
pil_images = [Image.fromarray(image) for image in images] # 将 NumPy 数组转换为 PIL 图像
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values # 提取图像特征并返回像素值
if jit: # 如果启用 JIT 编译
features = shard(features) # 将特征分片以优化性能
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) # 检查是否存在 NSFW 概念
has_nsfw_concepts = unshard(has_nsfw_concepts) # 将结果反分片
safety_model_params = unreplicate(safety_model_params) # 反复制安全模型参数
else: # 如果没有启用 JIT 编译
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) # 获取 NSFW 概念的存在性
images_was_copied = False # 标记图像是否已被复制
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): # 遍历 NSFW 概念的列表
if has_nsfw_concept: # 如果检测到 NSFW 概念
if not images_was_copied: # 如果尚未复制图像
images_was_copied = True # 标记为已复制
images = images.copy() # 复制图像数组
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # 用黑色图像替换原图像
if any(has_nsfw_concepts): # 如果任一图像有 NSFW 概念
warnings.warn( # 发出警告
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead. Try again with a different prompt and/or seed."
)
return images, has_nsfw_concepts # 返回处理后的图像和 NSFW 概念的存在性
# 定义获取开始时间步的方法
def get_timestep_start(self, num_inference_steps, strength):
# 使用初始时间步计算原始时间步
init_timestep = min(int(num_inference_steps * strength), num_inference_steps) # 计算初始时间步,确保不超出总步骤
t_start = max(num_inference_steps - init_timestep, 0) # 计算开始时间步,确保不为负
return t_start # 返回开始时间步
# 定义生成方法
def _generate(
self,
prompt_ids: jnp.ndarray, # 输入提示的 ID
image: jnp.ndarray, # 输入图像
params: Union[Dict, FrozenDict], # 模型参数
prng_seed: jax.Array, # 随机种子
start_timestep: int, # 开始时间步
num_inference_steps: int, # 推理步骤数量
height: int, # 生成图像的高度
width: int, # 生成图像的宽度
guidance_scale: float, # 引导比例
noise: Optional[jnp.ndarray] = None, # 噪声选项
neg_prompt_ids: Optional[jnp.ndarray] = None, # 负提示 ID 选项
@replace_example_docstring(EXAMPLE_DOC_STRING) # 用示例文档字符串替换
def __call__( # 定义可调用方法
self,
prompt_ids: jnp.ndarray, # 输入提示的 ID
image: jnp.ndarray, # 输入图像
params: Union[Dict, FrozenDict], # 模型参数
prng_seed: jax.Array, # 随机种子
strength: float = 0.8, # 强度参数,默认为 0.8
num_inference_steps: int = 50, # 推理步骤数量,默认为 50
height: Optional[int] = None, # 生成图像的高度,默认为 None
width: Optional[int] = None, # 生成图像的宽度,默认为 None
guidance_scale: Union[float, jnp.ndarray] = 7.5, # 引导比例,默认为 7.5
noise: jnp.ndarray = None, # 噪声,默认为 None
neg_prompt_ids: jnp.ndarray = None, # 负提示 ID,默认为 None
return_dict: bool = True, # 是否返回字典,默认为 True
jit: bool = False, # 是否启用 JIT 编译,默认为 False
# 静态参数为 pipe, start_timestep, num_inference_steps, height, width。任何更改都会触发重新编译。
# 非静态参数为 (sharded) 输入张量,按其第一维映射 (因此为 `0`)。
@partial(
jax.pmap, # 使用 JAX 的 pmap 函数进行并行映射
in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), # 指定输入参数的维度
static_broadcasted_argnums=(0, 5, 6, 7, 8), # 静态广播的参数索引
)
def _p_generate(
pipe, # 生成管道对象
prompt_ids, # 输入的提示 ID
image, # 输入的图像数据
params, # 其他参数
prng_seed, # 随机数种子
start_timestep, # 开始的时间步
num_inference_steps, # 推理的步骤数
height, # 图像的高度
width, # 图像的宽度
guidance_scale, # 引导尺度
noise, # 噪声数据
neg_prompt_ids, # 负提示 ID
):
# 调用管道的生成方法,传递所有必要的参数
return pipe._generate(
prompt_ids, # 提示 ID
image, # 图像数据
params, # 其他参数
prng_seed, # 随机数种子
start_timestep, # 开始时间步
num_inference_steps, # 推理步骤数
height, # 图像高度
width, # 图像宽度
guidance_scale, # 引导尺度
noise, # 噪声数据
neg_prompt_ids, # 负提示 ID
)
@partial(jax.pmap, static_broadcasted_argnums=(0,)) # 使用 JAX 的 pmap 函数进行并行映射
def _p_get_has_nsfw_concepts(pipe, features, params):
# 调用管道的方法以获取是否包含 NSFW 概念的特征
return pipe._get_has_nsfw_concepts(features, params)
def unshard(x: jnp.ndarray):
# 将输入张量 x 重组为适合的形状,合并设备和批次维度
num_devices, batch_size = x.shape[:2] # 获取设备数量和批次大小
rest = x.shape[2:] # 获取剩余维度
# 重新调整形状为 (num_devices * batch_size, 剩余维度)
return x.reshape(num_devices * batch_size, *rest)
def preprocess(image, dtype):
w, h = image.size # 获取图像的宽度和高度
# 调整宽度和高度为 32 的整数倍
w, h = (x - x % 32 for x in (w, h))
# 重新调整图像大小,使用 Lanczos 插值法
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
# 将图像转换为 NumPy 数组并归一化到 [0, 1] 范围
image = jnp.array(image).astype(dtype) / 255.0
# 调整数组维度为 (1, 通道数, 高度, 宽度)
image = image[None].transpose(0, 3, 1, 2)
# 将图像值范围转换为 [-1, 1]
return 2.0 * image - 1.0
.\diffusers\pipelines\stable_diffusion\pipeline_flax_stable_diffusion_inpaint.py
# 版权声明,指明该文件的版权信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证第 2.0 版许可使用本文件
# Licensed under the Apache License, Version 2.0 (the "License");
# 除非遵循许可证,否则不得使用本文件
# you may not use this file except in compliance with the License.
# 可以通过以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议规定,否则根据许可证分发的软件
# Unless required by applicable law or agreed to in writing, software
# 是按“原样”基础分发的,不提供任何形式的担保或条件
# distributed under the License is distributed on an "AS IS" BASIS,
# 不论是明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 参见许可证以了解适用权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings # 导入 warnings 模块以处理警告
from functools import partial # 从 functools 导入 partial,用于部分应用函数
from typing import Dict, List, Optional, Union # 导入类型注解工具
import jax # 导入 jax 库,用于高性能数值计算
import jax.numpy as jnp # 导入 jax 的 numpy 作为 jnp
import numpy as np # 导入 numpy 库以进行数组操作
from flax.core.frozen_dict import FrozenDict # 从 flax 导入 FrozenDict 用于不可变字典
from flax.jax_utils import unreplicate # 从 flax 导入 unreplicate,用于去除复制
from flax.training.common_utils import shard # 从 flax 导入 shard,用于数据分片
from packaging import version # 导入 version 用于版本比较
from PIL import Image # 从 PIL 导入 Image 用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel # 导入 transformers 库的相关组件
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel # 导入自定义模型
from ...schedulers import ( # 从自定义调度器导入各类调度器
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring # 导入工具函数
from ..pipeline_flax_utils import FlaxDiffusionPipeline # 导入 FlaxDiffusionPipeline 类
from .pipeline_output import FlaxStableDiffusionPipelineOutput # 导入输出类
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker # 导入安全检查器类
logger = logging.get_logger(__name__) # 创建日志记录器,使用当前模块名称
# 设置为 True 时使用 Python 的 for 循环而不是 jax.fori_loop,以便于调试
DEBUG = False
EXAMPLE_DOC_STRING = """ # 定义示例文档字符串,通常用于文档生成
# 示例代码块,用于展示如何使用 JAX 和 Flax 进行图像处理
Examples:
```py
# 导入必要的库
>>> import jax
>>> import numpy as np
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> import PIL
>>> import requests
>>> from io import BytesIO
>>> from diffusers import FlaxStableDiffusionInpaintPipeline
# 定义一个函数,用于下载图像并转换为 RGB 格式
>>> def download_image(url):
... # 发送 GET 请求以获取图像内容
... response = requests.get(url)
... # 打开下载的内容并转换为 RGB 图像
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
# 定义图像和掩码的 URL
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
# 下载并调整初始图像和掩码图像的大小
>>> init_image = download_image(img_url).resize((512, 512))
>>> mask_image = download_image(mask_url).resize((512, 512))
# 从预训练模型加载管道和参数
>>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
... "xvjiarui/stable-diffusion-2-inpainting"
... )
# 定义处理图像时使用的提示
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
# 初始化随机种子
>>> prng_seed = jax.random.PRNGKey(0)
# 定义推理步骤的数量
>>> num_inference_steps = 50
# 获取设备数量以便并行处理
>>> num_samples = jax.device_count()
# 将提示、初始图像和掩码图像扩展为设备数量的列表
>>> prompt = num_samples * [prompt]
>>> init_image = num_samples * [init_image]
>>> mask_image = num_samples * [mask_image]
# 准备输入,得到提示 ID 和处理后的图像
>>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
... prompt, init_image, mask_image
... )
# 分割输入和随机数生成器
# 复制参数以适应每个设备
>>> params = replicate(params)
# 根据设备数量分割随机种子
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
# 将提示 ID 和处理后的图像分割以适应每个设备
>>> prompt_ids = shard(prompt_ids)
>>> processed_masked_images = shard(processed_masked_images)
>>> processed_masks = shard(processed_masks)
# 运行管道以生成图像
>>> images = pipeline(
... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
... ).images
# 将生成的图像数组转换为 PIL 图像格式
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
FlaxStableDiffusionInpaintPipeline 类定义,继承自 FlaxDiffusionPipeline
class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r"""
Flax 基于 Stable Diffusion 的文本引导图像修补的管道。
<Tip warning={true}>
🧪 这是一个实验性功能!
</Tip>
该模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道通用方法(下载、保存、在特定设备上运行等)的实现,请查看父类文档。
参数:
vae ([`FlaxAutoencoderKL`]):
用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
text_encoder ([`~transformers.FlaxCLIPTextModel`]):
冻结的文本编码器 ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
tokenizer ([`~transformers.CLIPTokenizer`]):
用于标记化文本的 `CLIPTokenizer`。
unet ([`FlaxUNet2DConditionModel`]):
用于去噪编码图像潜在表示的 `FlaxUNet2DConditionModel`。
scheduler ([`SchedulerMixin`]):
与 `unet` 结合使用以去噪编码图像潜在表示的调度器。可以是以下之一
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
[`FlaxDPMSolverMultistepScheduler`]。
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
估计生成图像是否可能被认为是冒犯性或有害的分类模块。
请参考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以获取有关模型潜在危害的更多详细信息。
feature_extractor ([`~transformers.CLIPImageProcessor`]):
从生成图像中提取特征的 `CLIPImageProcessor`;用作 `safety_checker` 的输入。
"""
# 构造函数初始化
def __init__(
# 变分自编码器(VAE)模型实例
vae: FlaxAutoencoderKL,
# 文本编码器模型实例
text_encoder: FlaxCLIPTextModel,
# 标记器实例
tokenizer: CLIPTokenizer,
# 去噪模型实例
unet: FlaxUNet2DConditionModel,
# 调度器实例,指定可用的调度器类型
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
# 安全检查模块实例
safety_checker: FlaxStableDiffusionSafetyChecker,
# 特征提取器实例
feature_extractor: CLIPImageProcessor,
# 数据类型,默认为 float32
dtype: jnp.dtype = jnp.float32,
# 定义初始化方法,接收多个参数
):
# 调用父类的初始化方法
super().__init__()
# 设置数据类型属性
self.dtype = dtype
# 检查安全检查器是否为 None
if safety_checker is None:
# 记录警告信息,提醒用户禁用安全检查器的风险
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# 检查 UNet 版本是否小于 0.9.0
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
# 检查 UNet 的样本大小是否小于 64
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
# 如果满足两个条件,构造弃用警告信息
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
# 调用弃用函数,传递警告信息
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
# 创建新配置字典,并更新样本大小为 64
new_config = dict(unet.config)
new_config["sample_size"] = 64
# 将新配置赋值给 UNet 的内部字典
unet._internal_dict = FrozenDict(new_config)
# 注册多个模块以供使用
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# 计算 VAE 的缩放因子
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# 定义准备输入的方法,接收多个参数
def prepare_inputs(
self,
# 输入提示,可以是字符串或字符串列表
prompt: Union[str, List[str]],
# 输入图像,可以是单张图像或图像列表
image: Union[Image.Image, List[Image.Image]],
# 输入掩码,可以是单张掩码或掩码列表
mask: Union[Image.Image, List[Image.Image]],
):
# 检查 prompt 是否为字符串或列表类型,不符合则抛出异常
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查 image 是否为 PIL 图像或列表类型,不符合则抛出异常
if not isinstance(image, (Image.Image, list)):
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
# 如果 image 是单个 PIL 图像,则将其转为列表
if isinstance(image, Image.Image):
image = [image]
# 检查 mask 是否为 PIL 图像或列表类型,不符合则抛出异常
if not isinstance(mask, (Image.Image, list)):
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
# 如果 mask 是单个 PIL 图像,则将其转为列表
if isinstance(mask, Image.Image):
mask = [mask]
# 对图像进行预处理,并合并为一个数组
processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
# 对掩膜进行预处理,并合并为一个数组
processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
# 将处理后的掩膜中小于0.5的值设为0
processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
# 将处理后的掩膜中大于等于0.5的值设为1
processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)
# 根据掩膜对图像进行遮罩处理
processed_masked_images = processed_images * (processed_masks < 0.5)
# 将 prompt 进行编码,并设置最大长度、填充和截断
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
# 返回编码后的输入 ID、处理后的图像和掩膜
return text_input.input_ids, processed_masked_images, processed_masks
def _get_has_nsfw_concepts(self, features, params):
# 使用安全检查器检查特征中是否存在 NSFW 概念
has_nsfw_concepts = self.safety_checker(features, params)
# 返回 NSFW 概念的检测结果
return has_nsfw_concepts
def _run_safety_checker(self, images, safety_model_params, jit=False):
# 将传入的图像数组转换为 PIL 图像
pil_images = [Image.fromarray(image) for image in images]
# 提取图像特征并返回张量形式的像素值
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
# 如果开启 JIT 优化,则对特征进行分片
if jit:
features = shard(features)
# 使用 NSFW 概念检测函数获取结果
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
# 对结果进行反分片处理
has_nsfw_concepts = unshard(has_nsfw_concepts)
safety_model_params = unreplicate(safety_model_params)
else:
# 否则直接调用获取 NSFW 概念的函数
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
images_was_copied = False
# 遍历每个 NSFW 概念的检测结果
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
# 如果发现 NSFW 概念且尚未复制图像,则进行复制
if not images_was_copied:
images_was_copied = True
images = images.copy()
# 将对应图像替换为黑色图像
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
# 如果检测到任何 NSFW 概念,则发出警告
if any(has_nsfw_concepts):
warnings.warn(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead. Try again with a different prompt and/or seed."
)
# 返回处理后的图像和 NSFW 概念的检测结果
return images, has_nsfw_concepts
# 定义一个生成函数,处理图像生成的相关操作
def _generate(
# 输入的提示ID数组,通常用于模型输入
self,
prompt_ids: jnp.ndarray,
# 输入的掩码数组,指示哪些部分需要处理
mask: jnp.ndarray,
# 被掩码的图像数组,作为生成过程的基础
masked_image: jnp.ndarray,
# 模型参数,可以是字典或冻结字典类型
params: Union[Dict, FrozenDict],
# 随机数种子,用于生成可重复的结果
prng_seed: jax.Array,
# 推理步骤的数量,控制生成的细致程度
num_inference_steps: int,
# 生成图像的高度
height: int,
# 生成图像的宽度
width: int,
# 指导比例,用于调整生成图像与提示的相关性
guidance_scale: float,
# 可选的潜在表示,用于进一步控制生成过程
latents: Optional[jnp.ndarray] = None,
# 可选的负提示ID数组,用于增强生成效果
neg_prompt_ids: Optional[jnp.ndarray] = None,
# 使用装饰器替换示例文档字符串,提供函数的文档说明
@replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义调用函数,进行图像生成操作
def __call__(
# 输入的提示ID数组
self,
prompt_ids: jnp.ndarray,
# 输入的掩码数组
mask: jnp.ndarray,
# 被掩码的图像数组
masked_image: jnp.ndarray,
# 模型参数
params: Union[Dict, FrozenDict],
# 随机数种子
prng_seed: jax.Array,
# 推理步骤的数量,默认为50
num_inference_steps: int = 50,
# 生成图像的高度,默认为None(可选)
height: Optional[int] = None,
# 生成图像的宽度,默认为None(可选)
width: Optional[int] = None,
# 指导比例,默认为7.5
guidance_scale: Union[float, jnp.ndarray] = 7.5,
# 可选的潜在表示,默认为None
latents: jnp.ndarray = None,
# 可选的负提示ID数组,默认为None
neg_prompt_ids: jnp.ndarray = None,
# 返回字典格式的结果,默认为True
return_dict: bool = True,
# 是否使用JIT编译,默认为False
jit: bool = False,
静态参数为管道、推理步骤数、高度和宽度。更改会触发重新编译。
非静态参数为在其第一维度(因此为0)映射的(分片)输入张量。
@partial(
jax.pmap, # 使用 JAX 的并行映射功能
in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), # 指定输入张量的维度映射
static_broadcasted_argnums=(0, 6, 7, 8), # 静态广播参数的索引
)
def _p_generate(
pipe, # 管道对象
prompt_ids, # 提示 ID
mask, # 掩码
masked_image, # 被掩码的图像
params, # 参数
prng_seed, # 随机种子
num_inference_steps, # 推理步骤数
height, # 图像高度
width, # 图像宽度
guidance_scale, # 引导比例
latents, # 潜在表示
neg_prompt_ids, # 负提示 ID
):
return pipe._generate( # 调用管道的生成方法
prompt_ids, # 提示 ID
mask, # 掩码
masked_image, # 被掩码的图像
params, # 参数
prng_seed, # 随机种子
num_inference_steps, # 推理步骤数
height, # 图像高度
width, # 图像宽度
guidance_scale, # 引导比例
latents, # 潜在表示
neg_prompt_ids, # 负提示 ID
)
@partial(jax.pmap, static_broadcasted_argnums=(0,)) # 使用 JAX 的并行映射功能
def _p_get_has_nsfw_concepts(pipe, features, params): # 检查特征是否包含 NSFW 概念
return pipe._get_has_nsfw_concepts(features, params) # 调用管道的方法
def unshard(x: jnp.ndarray): # 定义 unshard 函数,接受一个 ndarray
# einops.rearrange(x, 'd b ... -> (d b) ...') # 用于调整张量的形状
num_devices, batch_size = x.shape[:2] # 获取设备数量和批次大小
rest = x.shape[2:] # 获取其余维度
return x.reshape(num_devices * batch_size, rest) # 重新调整形状为 (db, ...)
def preprocess_image(image, dtype): # 定义预处理图像的函数
w, h = image.size # 获取图像的宽度和高度
w, h = (x - x % 32 for x in (w, h)) # 调整宽度和高度为 32 的整数倍
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) # 按新大小调整图像
image = jnp.array(image).astype(dtype) / 255.0 # 转换为 ndarray 并归一化
image = image[None].transpose(0, 3, 1, 2) # 调整维度顺序
return 2.0 * image - 1.0 # 将图像值范围调整到 [-1, 1]
def preprocess_mask(mask, dtype): # 定义预处理掩码的函数
w, h = mask.size # 获取掩码的宽度和高度
w, h = (x - x % 32 for x in (w, h)) # 调整宽度和高度为 32 的整数倍
mask = mask.resize((w, h)) # 按新大小调整掩码
mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 # 转换为灰度并归一化
mask = jnp.expand_dims(mask, axis=(0, 1)) # 扩展维度以适应模型输入
return mask # 返回处理后的掩码
# `.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion.py`
```py
# 版权声明,表明该代码的版权所有者及相关条款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件在 "按原样" 基础上分发,
# 不提供任何明示或暗示的担保或条件。
# 请参见许可证以获取有关特定语言治理权限和
# 限制的更多信息。
# 导入 inspect 模块以进行获取对象的文档字符串和源代码
import inspect
# 从 typing 模块导入类型提示所需的工具
from typing import Callable, List, Optional, Union
# 导入 numpy 库用于数值计算
import numpy as np
# 导入 torch 库用于深度学习模型的构建和训练
import torch
# 从 transformers 库导入 CLIP 图像处理器和 CLIP 分词器
from transformers import CLIPImageProcessor, CLIPTokenizer
# 从配置工具导入 FrozenDict 用于处理不可变字典
from ...configuration_utils import FrozenDict
# 从调度器导入不同类型的调度器
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
# 从工具模块导入去deprecated功能和日志记录
from ...utils import deprecate, logging
# 从 onnx_utils 导入 ONNX 相关的类型和模型
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
# 从 pipeline_utils 导入扩散管道
from ..pipeline_utils import DiffusionPipeline
# 导入 StableDiffusionPipelineOutput 模块
from . import StableDiffusionPipelineOutput
# 创建一个日志记录器,用于记录该模块的日志信息
logger = logging.get_logger(__name__)
# 定义 OnnxStableDiffusionPipeline 类,继承自 DiffusionPipeline
class OnnxStableDiffusionPipeline(DiffusionPipeline):
# 声明类的各个成员变量,表示使用的模型组件
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPImageProcessor
# 定义可选组件的列表,包括安全检查器和特征提取器
_optional_components = ["safety_checker", "feature_extractor"]
# 标记该管道为 ONNX 格式
_is_onnx = True
# 初始化函数,设置各个组件的参数
def __init__(
self,
vae_encoder: OnnxRuntimeModel, # VAE 编码器模型
vae_decoder: OnnxRuntimeModel, # VAE 解码器模型
text_encoder: OnnxRuntimeModel, # 文本编码器模型
tokenizer: CLIPTokenizer, # CLIP 分词器
unet: OnnxRuntimeModel, # U-Net 模型
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], # 调度器
safety_checker: OnnxRuntimeModel, # 安全检查器模型
feature_extractor: CLIPImageProcessor, # 特征提取器
requires_safety_checker: bool = True, # 是否需要安全检查器
):
# 定义用于编码提示的私有方法
def _encode_prompt(
self,
prompt: Union[str, List[str]], # 输入的提示文本,可以是字符串或字符串列表
num_images_per_prompt: Optional[int], # 每个提示生成的图像数量
do_classifier_free_guidance: bool, # 是否使用无分类器引导
negative_prompt: Optional[str], # 负面提示文本
prompt_embeds: Optional[np.ndarray] = None, # 提示的嵌入表示
negative_prompt_embeds: Optional[np.ndarray] = None, # 负面提示的嵌入表示
):
# 定义检查输入有效性的私有方法
def check_inputs(
self,
prompt: Union[str, List[str]], # 输入的提示文本
height: Optional[int], # 图像高度
width: Optional[int], # 图像宽度
callback_steps: int, # 回调步骤数量
negative_prompt: Optional[str] = None, # 负面提示文本
prompt_embeds: Optional[np.ndarray] = None, # 提示的嵌入表示
negative_prompt_embeds: Optional[np.ndarray] = None, # 负面提示的嵌入表示
# 进行一系列参数检查,以确保输入值的有效性
):
# 检查高度和宽度是否都能被8整除
if height % 8 != 0 or width % 8 != 0:
# 如果不能整除,抛出值错误异常,提示当前高度和宽度
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 检查回调步骤是否为正整数
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
# 如果条件不满足,抛出值错误异常,提示当前回调步骤的类型和值
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 检查同时传入 prompt 和 prompt_embeds
if prompt is not None and prompt_embeds is not None:
# 如果同时传入,抛出值错误异常,提示只能传入其中一个
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查 prompt 和 prompt_embeds 是否均未提供
elif prompt is None and prompt_embeds is None:
# 抛出值错误异常,提示必须提供至少一个
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查 prompt 的类型
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 如果类型不匹配,抛出值错误异常,提示类型不符合
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查同时传入 negative_prompt 和 negative_prompt_embeds
if negative_prompt is not None and negative_prompt_embeds is not None:
# 如果同时传入,抛出值错误异常,提示只能传入其中一个
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查 prompt_embeds 和 negative_prompt_embeds 是否都提供
if prompt_embeds is not None and negative_prompt_embeds is not None:
# 确保它们的形状相同
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 如果形状不匹配,抛出值错误异常,提示它们的形状
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 定义调用方法,接受多个参数
def __call__(
# 提供的提示,类型为字符串或字符串列表
prompt: Union[str, List[str]] = None,
# 图像高度,默认为512
height: Optional[int] = 512,
# 图像宽度,默认为512
width: Optional[int] = 512,
# 推理步骤的数量,默认为50
num_inference_steps: Optional[int] = 50,
# 指导尺度,默认为7.5
guidance_scale: Optional[float] = 7.5,
# 负提示,类型为字符串或字符串列表
negative_prompt: Optional[Union[str, List[str]]] = None,
# 每个提示生成的图像数量,默认为1
num_images_per_prompt: Optional[int] = 1,
# 额外的随机因素,默认为0.0
eta: Optional[float] = 0.0,
# 随机生成器,默认为None
generator: Optional[np.random.RandomState] = None,
# 潜在变量,默认为None
latents: Optional[np.ndarray] = None,
# 提示的嵌入表示,默认为None
prompt_embeds: Optional[np.ndarray] = None,
# 负提示的嵌入表示,默认为None
negative_prompt_embeds: Optional[np.ndarray] = None,
# 输出类型,默认为"pil"
output_type: Optional[str] = "pil",
# 是否返回字典,默认为True
return_dict: bool = True,
# 回调函数,默认为None
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
# 回调步骤,默认为1
callback_steps: int = 1,
# 定义一个名为 StableDiffusionOnnxPipeline 的类,继承自 OnnxStableDiffusionPipeline
class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
# 初始化方法,接受多个模型和处理器作为参数
def __init__(
self,
vae_encoder: OnnxRuntimeModel, # VAE 编码器模型
vae_decoder: OnnxRuntimeModel, # VAE 解码器模型
text_encoder: OnnxRuntimeModel, # 文本编码器模型
tokenizer: CLIPTokenizer, # 分词器
unet: OnnxRuntimeModel, # U-Net 模型
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], # 调度器,可以是多种类型
safety_checker: OnnxRuntimeModel, # 安全检查模型
feature_extractor: CLIPImageProcessor, # 特征提取器
):
# 定义弃用消息,提醒用户使用替代类
deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
# 调用弃用函数,记录弃用警告
deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
# 调用父类的初始化方法,传入所有参数
super().__init__(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_img2img.py
# 版权信息,表示该代码的所有权归 HuggingFace 团队所有
# 许可信息,表明该文件在 Apache 2.0 许可下分发
# 除非遵守该许可,否则不得使用此文件
# 提供许可的获取地址
# 除非适用法律或书面同意,否则按照 "按现状" 基础分发软件,没有任何明示或暗示的担保
# 详细信息见许可中关于权限和限制的部分
import inspect # 导入 inspect 模块,用于获取对象的信息
from typing import Callable, List, Optional, Union # 导入类型提示,定义函数参数和返回值类型
import numpy as np # 导入 numpy,用于数组和矩阵操作
import PIL.Image # 导入 PIL.Image,用于图像处理
import torch # 导入 PyTorch,支持深度学习计算
from transformers import CLIPImageProcessor, CLIPTokenizer # 从 transformers 库导入 CLIP 图像处理器和分词器
from ...configuration_utils import FrozenDict # 从配置工具导入 FrozenDict,用于不可变字典
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler # 导入调度器类,用于模型训练调度
from ...utils import PIL_INTERPOLATION, deprecate, logging # 导入工具类,处理 PIL 插值、弃用警告和日志记录
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel # 从 ONNX 工具导入转换类型和模型类
from ..pipeline_utils import DiffusionPipeline # 导入 DiffusionPipeline,基础管道类
from . import StableDiffusionPipelineOutput # 导入稳定扩散管道的输出类
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 中复制的 preprocess 函数,调整尺寸从 8 变为 64
def preprocess(image):
# 弃用消息,通知用户该方法将在未来版本中删除
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
# 触发弃用警告,提醒用户使用替代方法
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
# 如果输入是 PyTorch 张量,则直接返回
if isinstance(image, torch.Tensor):
return image
# 如果输入是 PIL 图像,将其封装到列表中
elif isinstance(image, PIL.Image.Image):
image = [image]
# 如果输入的第一个元素是 PIL 图像
if isinstance(image[0], PIL.Image.Image):
# 获取图像的宽度和高度
w, h = image[0].size
# 将宽和高调整为64的整数倍
w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
# 调整图像大小并转换为 NumPy 数组
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
# 沿着第0维连接图像数组
image = np.concatenate(image, axis=0)
# 将图像数据类型转换为浮点型并归一化到 [0, 1]
image = np.array(image).astype(np.float32) / 255.0
# 调整数组的维度顺序
image = image.transpose(0, 3, 1, 2)
# 将图像数据缩放到 [-1, 1]
image = 2.0 * image - 1.0
# 将 NumPy 数组转换为 PyTorch 张量
image = torch.from_numpy(image)
# 如果输入的第一个元素是 PyTorch 张量
elif isinstance(image[0], torch.Tensor):
# 沿着第0维连接多个张量
image = torch.cat(image, dim=0)
# 返回处理后的图像
return image
# 定义一个用于文本引导的图像到图像生成的管道类,使用稳定扩散模型
class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
r"""
用于文本引导的图像到图像生成的管道,基于稳定扩散模型。
该模型继承自 [`DiffusionPipeline`]。查看超类文档,了解库为所有管道实现的通用方法
(例如下载或保存、在特定设备上运行等)。
# 参数说明
Args:
vae ([`AutoencoderKL`]): # 变分自编码器模型,用于将图像编码和解码为潜在表示。
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]): # 冻结的文本编码器,Stable Diffusion 使用 CLIP 的文本部分。
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),具体是
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 变体。
tokenizer (`CLIPTokenizer`): # CLIPTokenizer 类的分词器。
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): # 条件 U-Net 结构,用于去噪编码的图像潜在表示。
Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]): # 与 `unet` 结合使用的调度器,用于去噪编码的图像潜在表示。
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]): # 分类模块,估计生成的图像是否可能被视为冒犯或有害。
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]): # 从生成的图像中提取特征,以作为 `safety_checker` 的输入。
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
# 定义变量类型
vae_encoder: OnnxRuntimeModel # VAE 编码器的类型
vae_decoder: OnnxRuntimeModel # VAE 解码器的类型
text_encoder: OnnxRuntimeModel # 文本编码器的类型
tokenizer: CLIPTokenizer # 分词器的类型
unet: OnnxRuntimeModel # U-Net 模型的类型
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] # 调度器的类型
safety_checker: OnnxRuntimeModel # 安全检查器的类型
feature_extractor: CLIPImageProcessor # 特征提取器的类型
# 可选组件列表
_optional_components = ["safety_checker", "feature_extractor"] # 包含可选组件的列表
_is_onnx = True # 指示是否使用 ONNX 模型
# 构造函数初始化各个组件
def __init__( # 初始化方法
self,
vae_encoder: OnnxRuntimeModel, # 传入的 VAE 编码器
vae_decoder: OnnxRuntimeModel, # 传入的 VAE 解码器
text_encoder: OnnxRuntimeModel, # 传入的文本编码器
tokenizer: CLIPTokenizer, # 传入的分词器
unet: OnnxRuntimeModel, # 传入的 U-Net 模型
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], # 传入的调度器
safety_checker: OnnxRuntimeModel, # 传入的安全检查器
feature_extractor: CLIPImageProcessor, # 传入的特征提取器
requires_safety_checker: bool = True, # 是否需要安全检查器的标志
# 从 diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt 复制
def _encode_prompt( # 编码提示的方法
self,
prompt: Union[str, List[str]], # 提示文本,可以是字符串或字符串列表
num_images_per_prompt: Optional[int], # 每个提示生成的图像数量
do_classifier_free_guidance: bool, # 是否进行无分类器引导
negative_prompt: Optional[str], # 可选的负面提示文本
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入
# 定义一个检查输入参数的函数
def check_inputs(
self, # 类实例自身
prompt: Union[str, List[str]], # 提示信息,可以是字符串或字符串列表
callback_steps: int, # 回调步骤的整数值
negative_prompt: Optional[Union[str, List[str]]] = None, # 可选的负面提示,字符串或列表
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入,NumPy 数组
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入,NumPy 数组
):
# 检查回调步骤是否为 None 或非正整数
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
# 如果回调步骤不符合要求,则抛出值错误
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 检查是否同时提供了提示和提示嵌入
if prompt is not None and prompt_embeds is not None:
# 如果同时提供了,则抛出值错误
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查提示和提示嵌入是否都未提供
elif prompt is None and prompt_embeds is None:
# 如果都未提供,则抛出值错误
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查提示类型是否为字符串或列表
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 如果类型不符合,则抛出值错误
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查是否同时提供了负面提示和负面提示嵌入
if negative_prompt is not None and negative_prompt_embeds is not None:
# 如果同时提供了,则抛出值错误
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查提示嵌入和负面提示嵌入的形状是否一致
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 如果形状不一致,则抛出值错误
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 定义可调用的方法,处理提示和生成图像
def __call__(
self, # 类实例自身
prompt: Union[str, List[str]], # 提示信息,可以是字符串或字符串列表
image: Union[np.ndarray, PIL.Image.Image] = None, # 可选的图像输入,可以是 NumPy 数组或 PIL 图像
strength: float = 0.8, # 图像强度的浮点值,默认为 0.8
num_inference_steps: Optional[int] = 50, # 可选的推理步骤数,默认为 50
guidance_scale: Optional[float] = 7.5, # 可选的引导尺度,默认为 7.5
negative_prompt: Optional[Union[str, List[str]]] = None, # 可选的负面提示,字符串或列表
num_images_per_prompt: Optional[int] = 1, # 每个提示生成的图像数量,默认为 1
eta: Optional[float] = 0.0, # 可选的 eta 值,默认为 0.0
generator: Optional[np.random.RandomState] = None, # 可选的随机数生成器
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入,NumPy 数组
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入,NumPy 数组
output_type: Optional[str] = "pil", # 输出类型,默认为 'pil'
return_dict: bool = True, # 是否返回字典格式,默认为 True
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, # 可选的回调函数
callback_steps: int = 1, # 回调步骤的整数值,默认为 1
.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_inpaint.py
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件按“现状”基础进行分发,
# 不附有任何明示或暗示的担保或条件。
# 请参见许可证以获取管理权限的特定语言和
# 限制条款。
import inspect # 导入 inspect 模块,用于检查对象的属性和方法
from typing import Callable, List, Optional, Union # 导入类型注释,便于类型检查
import numpy as np # 导入 NumPy 库,用于数组和矩阵操作
import PIL.Image # 导入 PIL 图像处理库
import torch # 导入 PyTorch 库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTokenizer # 导入 Transformers 库中的图像处理和标记器
from ...configuration_utils import FrozenDict # 导入 FrozenDict,用于不可变字典
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler # 导入调度器类
from ...utils import PIL_INTERPOLATION, deprecate, logging # 导入工具函数和日志模块
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel # 导入 ONNX 相关工具
from ..pipeline_utils import DiffusionPipeline # 导入扩散管道基类
from . import StableDiffusionPipelineOutput # 导入稳定扩散管道输出类
logger = logging.get_logger(__name__) # 创建一个记录器,使用当前模块名称进行日志记录
NUM_UNET_INPUT_CHANNELS = 9 # 定义 UNet 输入通道的数量
NUM_LATENT_CHANNELS = 4 # 定义潜在通道的数量
def prepare_mask_and_masked_image(image, mask, latents_shape): # 定义准备掩模和掩模图像的函数
# 将输入图像转换为 RGB 格式,并调整大小以适应潜在形状
image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
# 调整数组形状以适配深度学习模型的输入要求
image = image[None].transpose(0, 3, 1, 2)
# 将图像数据类型转换为 float32,并归一化到 [-1, 1] 范围
image = image.astype(np.float32) / 127.5 - 1.0
# 将掩模图像转换为灰度并调整大小
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
# 应用掩模到图像,得到掩模图像
masked_image = image * (image_mask < 127.5)
# 调整掩模大小以匹配潜在形状,并转换为灰度格式
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
mask = np.array(mask.convert("L"))
# 将掩模数据类型转换为 float32,并归一化到 [0, 1] 范围
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None] # 添加维度以匹配模型输入要求
# 将小于 0.5 的值设为 0
mask[mask < 0.5] = 0
# 将大于等于 0.5 的值设为 1
mask[mask >= 0.5] = 1
return mask, masked_image # 返回处理后的掩模和掩模图像
class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): # 定义用于图像修补的扩散管道类
r"""
使用稳定扩散进行文本引导的图像修补管道。*这是一个实验特性*。
此模型继承自 [`DiffusionPipeline`]。请查看超类文档,以获取库为所有管道实现的通用方法
(例如下载或保存,在特定设备上运行等)。
# 文档字符串,定义类的参数和它们的类型
Args:
vae ([`AutoencoderKL`]): # 变分自编码器模型,用于编码和解码图像及其潜在表示
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]): # 冻结的文本编码器,用于处理文本输入
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`): # 处理文本的标记器
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): # 条件 U-Net 架构,用于去噪编码的图像潜在表示
Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]): # 调度器,与 `unet` 一起用于去噪图像潜在表示
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]): # 分类模块,评估生成图像是否可能被视为冒犯或有害
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]): # 从生成图像中提取特征的模型,用于 `safety_checker` 的输入
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
# 定义多个模型的类型,用于保存各个组件的实例
vae_encoder: OnnxRuntimeModel # 编码器模型的类型
vae_decoder: OnnxRuntimeModel # 解码器模型的类型
text_encoder: OnnxRuntimeModel # 文本编码器模型的类型
tokenizer: CLIPTokenizer # 文本标记器的类型
unet: OnnxRuntimeModel # U-Net 模型的类型
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] # 调度器的类型,可以是多种类型之一
safety_checker: OnnxRuntimeModel # 安全检查器模型的类型
feature_extractor: CLIPImageProcessor # 特征提取器模型的类型
_optional_components = ["safety_checker", "feature_extractor"] # 可选组件的列表
_is_onnx = True # 指示当前模型是否为 ONNX 格式
# 构造函数,用于初始化类的实例
def __init__(
self,
vae_encoder: OnnxRuntimeModel, # 传入编码器模型实例
vae_decoder: OnnxRuntimeModel, # 传入解码器模型实例
text_encoder: OnnxRuntimeModel, # 传入文本编码器模型实例
tokenizer: CLIPTokenizer, # 传入文本标记器实例
unet: OnnxRuntimeModel, # 传入 U-Net 模型实例
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], # 传入调度器实例
safety_checker: OnnxRuntimeModel, # 传入安全检查器模型实例
feature_extractor: CLIPImageProcessor, # 传入特征提取器模型实例
requires_safety_checker: bool = True, # 指示是否需要安全检查器的布尔参数
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
# 编码提示的函数
def _encode_prompt(
self,
prompt: Union[str, List[str]], # 输入的提示,可以是字符串或字符串列表
num_images_per_prompt: Optional[int], # 每个提示生成的图像数量,默认为可选
do_classifier_free_guidance: bool, # 是否执行无分类器自由引导的布尔参数
negative_prompt: Optional[str], # 可选的负面提示
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入,默认为 None
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入,默认为 None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
# 定义一个检查输入参数的函数,确保所有输入都符合预期
def check_inputs(
self, # 指向类实例的引用
prompt: Union[str, List[str]], # 输入的提示,类型为字符串或字符串列表
height: Optional[int], # 图像高度,类型为可选整数
width: Optional[int], # 图像宽度,类型为可选整数
callback_steps: int, # 回调步骤数,类型为整数
negative_prompt: Optional[str] = None, # 负提示,类型为可选字符串
prompt_embeds: Optional[np.ndarray] = None, # 提示的嵌入表示,类型为可选numpy数组
negative_prompt_embeds: Optional[np.ndarray] = None, # 负提示的嵌入表示,类型为可选numpy数组
):
# 检查高度和宽度是否都能被8整除
if height % 8 != 0 or width % 8 != 0:
# 如果不能整除,抛出值错误
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# 检查回调步骤是否为正整数
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
# 如果不是正整数,抛出值错误
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# 检查提示和提示嵌入是否同时存在
if prompt is not None and prompt_embeds is not None:
# 如果两者都存在,抛出值错误
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
# 检查提示和提示嵌入是否同时为None
elif prompt is None and prompt_embeds is None:
# 如果都是None,抛出值错误
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
# 检查提示的类型是否为字符串或列表
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
# 如果不是,抛出值错误
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 检查负提示和负提示嵌入是否同时存在
if negative_prompt is not None and negative_prompt_embeds is not None:
# 如果同时存在,抛出值错误
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
# 检查提示嵌入和负提示嵌入的形状是否一致
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
# 如果形状不一致,抛出值错误
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# 在不计算梯度的情况下进行操作,节省内存和计算资源
@torch.no_grad()
# 定义一个可调用的类方法,用于处理图像生成
def __call__(
self,
# 用户输入的提示,可以是字符串或字符串列表
prompt: Union[str, List[str]],
# 输入的图像,类型为 PIL.Image.Image
image: PIL.Image.Image,
# 掩模图像,类型为 PIL.Image.Image
mask_image: PIL.Image.Image,
# 输出图像的高度,默认为 512
height: Optional[int] = 512,
# 输出图像的宽度,默认为 512
width: Optional[int] = 512,
# 推理步骤的数量,默认为 50
num_inference_steps: int = 50,
# 指导尺度,默认为 7.5
guidance_scale: float = 7.5,
# 负提示,可以是字符串或字符串列表,默认为 None
negative_prompt: Optional[Union[str, List[str]]] = None,
# 每个提示生成的图像数量,默认为 1
num_images_per_prompt: Optional[int] = 1,
# 噪声的 eta 值,默认为 0.0
eta: float = 0.0,
# 随机数生成器,默认为 None
generator: Optional[np.random.RandomState] = None,
# 潜在表示,默认为 None
latents: Optional[np.ndarray] = None,
# 提示嵌入,默认为 None
prompt_embeds: Optional[np.ndarray] = None,
# 负提示嵌入,默认为 None
negative_prompt_embeds: Optional[np.ndarray] = None,
# 输出类型,默认为 "pil"
output_type: Optional[str] = "pil",
# 是否返回字典,默认为 True
return_dict: bool = True,
# 回调函数,默认为 None
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
# 回调步骤的间隔,默认为 1
callback_steps: int = 1,
.\diffusers\pipelines\stable_diffusion\pipeline_onnx_stable_diffusion_upscale.py
# 版权声明,标明版权和许可信息
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证版本 2.0("许可证")授权;
# 除非遵循该许可证,否则不可使用此文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非根据适用法律或书面协议另有约定,
# 否则根据许可证分发的软件是基于“原样”提供的,
# 不附带任何明示或暗示的担保或条件。
# 有关许可证具体条款的信息,见许可证。
import inspect # 导入inspect模块,用于获取对象的信息
from typing import Any, Callable, List, Optional, Union # 导入类型注解
import numpy as np # 导入numpy库,用于数值计算
import PIL.Image # 导入PIL库,用于图像处理
import torch # 导入PyTorch库,用于深度学习
from transformers import CLIPImageProcessor, CLIPTokenizer # 从transformers库导入图像处理和标记器
from ...configuration_utils import FrozenDict # 导入FrozenDict,用于配置管理
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers # 导入调度器,用于模型训练
from ...utils import deprecate, logging # 导入工具模块,用于日志记录和弃用警告
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel # 导入ONNX相关工具和模型类
from ..pipeline_utils import DiffusionPipeline # 导入扩散管道类
from . import StableDiffusionPipelineOutput # 导入稳定扩散管道输出类
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
def preprocess(image): # 定义预处理函数,接收图像作为输入
if isinstance(image, torch.Tensor): # 检查图像是否为PyTorch张量
return image # 如果是,则直接返回
elif isinstance(image, PIL.Image.Image): # 检查图像是否为PIL图像
image = [image] # 将其封装为列表
if isinstance(image[0], PIL.Image.Image): # 检查列表中的第一个元素是否为PIL图像
w, h = image[0].size # 获取图像的宽度和高度
w, h = (x - x % 64 for x in (w, h)) # 调整宽高,使其为64的整数倍
image = [np.array(i.resize((w, h)))[None, :] for i in image] # 调整所有图像大小并转为数组
image = np.concatenate(image, axis=0) # 将所有图像数组沿第0轴合并
image = np.array(image).astype(np.float32) / 255.0 # 转换为浮点数并归一化到[0, 1]
image = image.transpose(0, 3, 1, 2) # 变换数组维度为[batch, channels, height, width]
image = 2.0 * image - 1.0 # 将值归一化到[-1, 1]
image = torch.from_numpy(image) # 转换为PyTorch张量
elif isinstance(image[0], torch.Tensor): # 如果列表中的第一个元素是PyTorch张量
image = torch.cat(image, dim=0) # 在第0维连接所有张量
return image # 返回处理后的图像
class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): # 定义ONNX稳定扩散上采样管道类,继承自DiffusionPipeline
vae: OnnxRuntimeModel # 定义变分自编码器模型
text_encoder: OnnxRuntimeModel # 定义文本编码器模型
tokenizer: CLIPTokenizer # 定义CLIP标记器
unet: OnnxRuntimeModel # 定义U-Net模型
low_res_scheduler: DDPMScheduler # 定义低分辨率调度器
scheduler: KarrasDiffusionSchedulers # 定义Karras扩散调度器
safety_checker: OnnxRuntimeModel # 定义安全检查模型
feature_extractor: CLIPImageProcessor # 定义特征提取器
_optional_components = ["safety_checker", "feature_extractor"] # 可选组件列表
_is_onnx = True # 指示该类是否为ONNX格式
def __init__( # 定义构造函数
self,
vae: OnnxRuntimeModel, # 变分自编码器模型
text_encoder: OnnxRuntimeModel, # 文本编码器模型
tokenizer: Any, # 任意类型的标记器
unet: OnnxRuntimeModel, # U-Net模型
low_res_scheduler: DDPMScheduler, # 低分辨率调度器
scheduler: KarrasDiffusionSchedulers, # Karras调度器
safety_checker: Optional[OnnxRuntimeModel] = None, # 可选的安全检查模型
feature_extractor: Optional[CLIPImageProcessor] = None, # 可选的特征提取器
max_noise_level: int = 350, # 最大噪声级别
num_latent_channels=4, # 潜在通道数量
num_unet_input_channels=7, # U-Net输入通道数量
requires_safety_checker: bool = True, # 是否需要安全检查器
# 定义一个检查输入参数的函数,确保输入有效
def check_inputs(
self, # 表示该方法属于某个类
prompt: Union[str, List[str]], # 输入的提示,支持字符串或字符串列表
image, # 输入的图像,类型不固定
noise_level, # 噪声级别,通常用于控制生成图像的噪声程度
callback_steps, # 回调步数,用于更新或监控生成过程
negative_prompt=None, # 可选的负面提示,控制生成内容的方向
prompt_embeds=None, # 可选的提示嵌入,直接传入嵌入向量
negative_prompt_embeds=None, # 可选的负面提示嵌入,直接传入嵌入向量
# 定义一个准备潜在变量的函数,用于生成图像的潜在表示
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
# 定义潜在变量的形状,根据批大小、通道数、高度和宽度
shape = (batch_size, num_channels_latents, height, width)
# 如果没有提供潜在变量,则生成新的随机潜在变量
if latents is None:
latents = generator.randn(*shape).astype(dtype) # 从生成器中生成随机潜在变量并转换为指定数据类型
# 如果提供的潜在变量形状不符合预期,则引发错误
elif latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
return latents # 返回准备好的潜在变量
# 定义一个解码潜在变量的函数,将潜在表示转换为图像
def decode_latents(self, latents):
# 调整潜在变量的尺度,以匹配解码器的输入要求
latents = 1 / 0.08333 * latents
# 使用变分自编码器(VAE)解码潜在变量,获取生成的图像
image = self.vae(latent_sample=latents)[0]
# 将图像值缩放到 [0, 1] 范围内,并进行裁剪
image = np.clip(image / 2 + 0.5, 0, 1)
# 调整图像的维度顺序,从 (N, C, H, W) 转换为 (N, H, W, C)
image = image.transpose((0, 2, 3, 1))
return image # 返回解码后的图像
# 定义一个编码提示的函数,将文本提示转换为嵌入向量
def _encode_prompt(
self,
prompt: Union[str, List[str]], # 输入的提示,支持字符串或字符串列表
num_images_per_prompt: Optional[int], # 每个提示生成的图像数量
do_classifier_free_guidance: bool, # 是否进行无分类器引导
negative_prompt: Optional[str], # 可选的负面提示
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入
# 定义一个调用函数,用于生成图像
def __call__(
self,
prompt: Union[str, List[str]], # 输入的提示,支持字符串或字符串列表
image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]], # 输入的图像,可以是 ndarray 或 PIL 图像
num_inference_steps: int = 75, # 推理步骤的数量,默认设置为 75
guidance_scale: float = 9.0, # 引导的缩放因子,控制生成图像的质量
noise_level: int = 20, # 噪声级别,影响生成图像的随机性
negative_prompt: Optional[Union[str, List[str]]] = None, # 可选的负面提示
num_images_per_prompt: Optional[int] = 1, # 每个提示生成的图像数量,默认设置为 1
eta: float = 0.0, # 控制随机性和确定性的超参数
generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None, # 随机数生成器
latents: Optional[np.ndarray] = None, # 可选的潜在变量
prompt_embeds: Optional[np.ndarray] = None, # 可选的提示嵌入
negative_prompt_embeds: Optional[np.ndarray] = None, # 可选的负面提示嵌入
output_type: Optional[str] = "pil", # 输出类型,默认设置为 PIL 图像
return_dict: bool = True, # 是否以字典形式返回结果
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, # 可选的回调函数,用于处理生成过程中的状态
callback_steps: Optional[int] = 1, # 回调的步数,控制回调的频率
.\diffusers\pipelines\stable_diffusion\pipeline_output.py
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入 List, Optional, Union 类型注解
from typing import List, Optional, Union
# 导入 numpy 库并重命名为 np
import numpy as np
# 导入 PIL.Image 模块
import PIL.Image
# 从上层模块导入 BaseOutput 和 is_flax_available 函数
from ...utils import BaseOutput, is_flax_available
# 定义一个数据类,作为 Stable Diffusion 管道的输出
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
"""
Stable Diffusion 管道的输出类。
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
包含去噪后的 PIL 图像列表,长度为 `batch_size`,或形状为 `(batch_size, height, width,
num_channels)` 的 NumPy 数组。
nsfw_content_detected (`List[bool]`)
指示对应生成图像是否包含“不可安全观看” (nsfw) 内容的列表,若无法进行安全检查则为 `None`。
"""
# 存储图像,类型为 PIL 图像列表或 NumPy 数组
images: Union[List[PIL.Image.Image], np.ndarray]
# 存储 nsfw 内容检测结果的可选列表
nsfw_content_detected: Optional[List[bool]]
# 检查是否可用 Flax 库
if is_flax_available():
# 导入 flax 库
import flax
# 定义一个数据类,作为 Flax 基于 Stable Diffusion 管道的输出
@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Flax 基于 Stable Diffusion 管道的输出类。
Args:
images (`np.ndarray`):
形状为 `(batch_size, height, width, num_channels)` 的去噪图像数组。
nsfw_content_detected (`List[bool]`):
指示对应生成图像是否包含“不可安全观看” (nsfw) 内容的列表,
或 `None` 如果无法进行安全检查。
"""
# 存储图像,类型为 NumPy 数组
images: np.ndarray
# 存储 nsfw 内容检测结果的列表
nsfw_content_detected: List[bool]
.\diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion.py
# 版权声明,表明此文件的版权归 HuggingFace 团队所有
#
# 根据 Apache License 2.0 许可协议进行授权;
# 除非遵循此许可协议,否则不得使用此文件。
# 可以通过以下网址获取许可的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,根据该许可协议分发的软件在“按原样”基础上分发,
# 不提供任何明示或暗示的担保或条件。
# 详细信息请参见许可协议中关于权限和限制的具体条款。
import inspect # 导入用于检查对象的模块
from typing import Any, Callable, Dict, List, Optional, Union # 导入类型注解工具
import torch # 导入 PyTorch 库
from packaging import version # 导入版本管理工具
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection # 导入特定的转换器模型
from ...callbacks import MultiPipelineCallbacks, PipelineCallback # 导入多管道回调相关类
from ...configuration_utils import FrozenDict # 导入不可变字典工具
from ...image_processor import PipelineImageInput, VaeImageProcessor # 导入图像处理相关类
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin # 导入加载器混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel # 导入不同模型
from ...models.lora import adjust_lora_scale_text_encoder # 导入用于调整文本编码器的 LoRA 函数
from ...schedulers import KarrasDiffusionSchedulers # 导入 Karras 扩散调度器
from ...utils import ( # 导入多个实用工具
USE_PEFT_BACKEND, # 指定使用 PEFT 后端的标志
deprecate, # 导入弃用装饰器
logging, # 导入日志工具
replace_example_docstring, # 导入替换示例文档字符串的工具
scale_lora_layers, # 导入缩放 LoRA 层的工具
unscale_lora_layers, # 导入反缩放 LoRA 层的工具
)
from ...utils.torch_utils import randn_tensor # 导入生成随机张量的工具
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin # 导入扩散管道相关类
from .pipeline_output import StableDiffusionPipelineOutput # 导入稳定扩散管道输出类
from .safety_checker import StableDiffusionSafetyChecker # 导入稳定扩散安全检查器类
logger = logging.get_logger(__name__) # 创建当前模块的日志记录器,禁止 pylint 检查命名
EXAMPLE_DOC_STRING = """ # 示例文档字符串,展示使用示例
Examples:
```py
>>> import torch # 导入 PyTorch 库
>>> from diffusers import StableDiffusionPipeline # 从 diffusers 导入稳定扩散管道
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) # 从预训练模型加载管道并设置数据类型
>>> pipe = pipe.to("cuda") # 将管道转移到 GPU
>>> prompt = "a photo of an astronaut riding a horse on mars" # 定义文本提示
>>> image = pipe(prompt).images[0] # 生成图像并提取第一张图像
```py
"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # 定义噪声配置重标定函数
"""
根据 `guidance_rescale` 对 `noise_cfg` 进行重标定。基于论文[Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)的发现。参见第 3.4 节
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) # 计算文本噪声的标准差,保持维度
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # 计算噪声配置的标准差,保持维度
# 使用文本标准差调整噪声配置,以修正曝光过度的问题
noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # 进行重标定
# 按照指导比例将重标定的噪声与原始噪声混合,避免生成“平面”图像
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg # 更新噪声配置
# 返回噪声配置对象
return noise_cfg
# 定义一个函数用于检索时间步,接受多个参数
def retrieve_timesteps(
# 调度器实例,用于获取时间步
scheduler,
# 可选的推理步骤数量,默认为 None
num_inference_steps: Optional[int] = None,
# 可选的设备参数,可以是字符串或 torch.device,默认为 None
device: Optional[Union[str, torch.device]] = None,
# 可选的自定义时间步,默认为 None
timesteps: Optional[List[int]] = None,
# 可选的自定义 sigma 值,默认为 None
sigmas: Optional[List[float]] = None,
# 其他可选参数,将传递给调度器的 set_timesteps 方法
**kwargs,
):
"""
调用调度器的 `set_timesteps` 方法,并在调用后从调度器获取时间步。处理自定义时间步。
所有 kwargs 将传递给 `scheduler.set_timesteps`。
参数:
scheduler (`SchedulerMixin`):
用于获取时间步的调度器。
num_inference_steps (`int`):
生成样本时使用的扩散步骤数量。如果使用此参数,则 `timesteps` 必须为 `None`。
device (`str` 或 `torch.device`, *可选*):
时间步应移动到的设备。如果为 `None`,则不移动时间步。
timesteps (`List[int]`, *可选*):
用于覆盖调度器时间步间距策略的自定义时间步。如果传入 `timesteps`,
`num_inference_steps` 和 `sigmas` 必须为 `None`。
sigmas (`List[float]`, *可选*):
用于覆盖调度器时间步间距策略的自定义 sigma 值。如果传入 `sigmas`,
`num_inference_steps` 和 `timesteps` 必须为 `None`。
返回:
`Tuple[torch.Tensor, int]`: 一个元组,第一个元素是调度器的时间步调度,
第二个元素是推理步骤的数量。
"""
# 检查是否同时传入了自定义时间步和 sigma 值,如果是则抛出异常
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
# 如果传入了自定义时间步
if timesteps is not None:
# 检查调度器的 set_timesteps 方法是否接受时间步参数
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
# 如果不接受,抛出异常
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
# 调用调度器的 set_timesteps 方法设置自定义时间步
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
# 获取调度器中的时间步
timesteps = scheduler.timesteps
# 计算推理步骤数量
num_inference_steps = len(timesteps)
# 如果传入了自定义 sigma 值
elif sigmas is not None:
# 检查调度器的 set_timesteps 方法是否接受 sigma 参数
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
# 如果不接受,抛出异常
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
# 调用调度器的 set_timesteps 方法设置自定义 sigma 值
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
# 获取调度器中的时间步
timesteps = scheduler.timesteps
# 计算推理步骤数量
num_inference_steps = len(timesteps)
# 如果没有传入自定义时间步或 sigma
else:
# 调用调度器的 set_timesteps 方法,使用推理步骤数量
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
# 获取调度器中的时间步
timesteps = scheduler.timesteps
# 返回时间步长和推理步骤的数量
return timesteps, num_inference_steps
# 定义一个名为 StableDiffusionPipeline 的类,继承多个混入类
class StableDiffusionPipeline(
# 继承 DiffusionPipeline 基础功能
DiffusionPipeline,
# 继承稳定扩散特有的功能
StableDiffusionMixin,
# 继承文本反演加载功能
TextualInversionLoaderMixin,
# 继承 LoRA 加载功能
StableDiffusionLoraLoaderMixin,
# 继承 IP 适配器功能
IPAdapterMixin,
# 继承从单一文件加载功能
FromSingleFileMixin,
):
# 文档字符串,描述该类用于文本到图像生成
r"""
Pipeline for text-to-image generation using Stable Diffusion.
# 说明此模型继承自 DiffusionPipeline,并指出可以查看超类文档获取通用方法
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
# 说明该管道也继承了多种加载方法
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
# 参数说明,定义构造函数需要的各类参数及其类型
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
# 定义一个字符串,表示模型在 CPU 上的卸载顺序
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
# 定义可选组件的列表
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
# 定义不包含在 CPU 卸载中的组件
_exclude_from_cpu_offload = ["safety_checker"]
# 定义回调张量输入的列表
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# 初始化方法,构造类的实例并接受多个参数
def __init__(
self,
vae: AutoencoderKL, # 变分自编码器,用于图像生成
text_encoder: CLIPTextModel, # 文本编码器,用于将文本转换为嵌入
tokenizer: CLIPTokenizer, # 分词器,用于处理文本数据
unet: UNet2DConditionModel, # UNet模型,用于条件生成
scheduler: KarrasDiffusionSchedulers, # 调度器,控制生成过程中的步伐
safety_checker: StableDiffusionSafetyChecker, # 安全检查器,确保生成内容符合安全标准
feature_extractor: CLIPImageProcessor, # 特征提取器,用于处理图像
image_encoder: CLIPVisionModelWithProjection = None, # 可选图像编码器,用于图像嵌入
requires_safety_checker: bool = True, # 是否需要安全检查器,默认为True
def _encode_prompt(
self,
prompt, # 输入的提示文本
device, # 设备信息,指定运行的硬件
num_images_per_prompt, # 每个提示生成的图像数量
do_classifier_free_guidance, # 是否使用无分类器的引导
negative_prompt=None, # 可选的负面提示文本
prompt_embeds: Optional[torch.Tensor] = None, # 可选的提示嵌入
negative_prompt_embeds: Optional[torch.Tensor] = None, # 可选的负面提示嵌入
lora_scale: Optional[float] = None, # 可选的Lora缩放因子
**kwargs, # 其他可选参数
):
# 生成弃用警告信息,提示用户该方法将被移除
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
# 发出弃用警告,通知版本号和信息
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
# 调用新的编码方法,获取提示嵌入的元组
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt, # 输入提示
device=device, # 设备信息
num_images_per_prompt=num_images_per_prompt, # 图像数量
do_classifier_free_guidance=do_classifier_free_guidance, # 引导选项
negative_prompt=negative_prompt, # 负面提示
prompt_embeds=prompt_embeds, # 提示嵌入
negative_prompt_embeds=negative_prompt_embeds, # 负面提示嵌入
lora_scale=lora_scale, # Lora缩放因子
**kwargs, # 其他参数
)
# 将元组中的提示嵌入连接为一个张量,便于后续处理
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
# 返回连接后的提示嵌入
return prompt_embeds
# 新的编码提示方法,接受多个参数以生成提示嵌入
def encode_prompt(
self,
prompt, # 输入的提示文本
device, # 设备信息,指定运行的硬件
num_images_per_prompt, # 每个提示生成的图像数量
do_classifier_free_guidance, # 是否使用无分类器的引导
negative_prompt=None, # 可选的负面提示文本
prompt_embeds: Optional[torch.Tensor] = None, # 可选的提示嵌入
negative_prompt_embeds: Optional[torch.Tensor] = None, # 可选的负面提示嵌入
lora_scale: Optional[float] = None, # 可选的Lora缩放因子
clip_skip: Optional[int] = None, # 可选的跳过参数,用于调节处理流程
# 定义一个方法用于编码图像,接受图像、设备、每个提示的图像数量及可选的隐藏状态输出
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
# 获取图像编码器参数的数据类型
dtype = next(self.image_encoder.parameters()).dtype
# 如果输入的图像不是张量,则通过特征提取器转换为张量
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
# 将图像移动到指定设备,并转换为正确的数据类型
image = image.to(device=device, dtype=dtype)
# 如果需要输出隐藏状态
if output_hidden_states:
# 通过图像编码器处理图像,获取倒数第二层的隐藏状态
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
# 将隐藏状态重复,以适应每个提示的图像数量
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
# 创建一个与输入图像大小相同的零张量,通过图像编码器获取未条件化的隐藏状态
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
# 将未条件化的隐藏状态重复,以适应每个提示的图像数量
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
# 返回编码后的图像隐藏状态和未条件化图像的隐藏状态
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
# 通过图像编码器处理图像,获取图像嵌入
image_embeds = self.image_encoder(image).image_embeds
# 将图像嵌入重复,以适应每个提示的图像数量
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
# 创建一个与图像嵌入相同形状的零张量作为未条件化的图像嵌入
uncond_image_embeds = torch.zeros_like(image_embeds)
# 返回图像嵌入和未条件化图像嵌入
return image_embeds, uncond_image_embeds
# 定义一个方法用于准备 IP 适配器的图像嵌入,接受 IP 适配器图像、图像嵌入、设备、每个提示的图像数量及是否进行分类自由引导
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
# 初始化图像嵌入列表
image_embeds = []
# 如果启用无分类器自由引导,则初始化负图像嵌入列表
if do_classifier_free_guidance:
negative_image_embeds = []
# 如果输入适配器图像嵌入为 None
if ip_adapter_image_embeds is None:
# 如果输入适配器图像不是列表,则将其转换为列表
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
# 检查输入适配器图像的长度是否与 IP 适配器数量相同
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
# 抛出错误,说明输入适配器图像长度不匹配
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
# 遍历输入适配器图像和对应的图像投影层
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
# 判断是否输出隐藏状态
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
# 编码单个图像,返回图像嵌入和负图像嵌入
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
# 将图像嵌入添加到列表中
image_embeds.append(single_image_embeds[None, :])
# 如果启用无分类器自由引导,则将负图像嵌入添加到列表中
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
# 遍历已存在的输入适配器图像嵌入
for single_image_embeds in ip_adapter_image_embeds:
# 如果启用无分类器自由引导,分离负图像嵌入和图像嵌入
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
# 将图像嵌入添加到列表中
image_embeds.append(single_image_embeds)
# 初始化适配器图像嵌入列表
ip_adapter_image_embeds = []
# 遍历图像嵌入列表
for i, single_image_embeds in enumerate(image_embeds):
# 将单个图像嵌入复制指定次数以适应每个提示的图像数量
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
# 如果启用无分类器自由引导,复制负图像嵌入
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
# 将负图像嵌入和图像嵌入拼接
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
# 将图像嵌入移动到指定设备
single_image_embeds = single_image_embeds.to(device=device)
# 将处理后的图像嵌入添加到列表中
ip_adapter_image_embeds.append(single_image_embeds)
# 返回适配器图像嵌入列表
return ip_adapter_image_embeds
# 安全检查器的执行方法
def run_safety_checker(self, image, device, dtype):
# 如果没有安全检查器,初始化不适合的概念为 None
if self.safety_checker is None:
has_nsfw_concept = None
else:
# 如果输入图像是张量,则进行后处理
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
# 如果输入图像是 numpy 数组,则转换为 PIL 格式
feature_extractor_input = self.image_processor.numpy_to_pil(image)
# 将处理后的输入图像提取特征并移动到设备
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
# 运行安全检查器,获取处理后的图像和不适合的概念标志
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
# 返回处理后的图像和不适合的概念标志
return image, has_nsfw_concept
# 解码潜在表示
def decode_latents(self, latents):
# 定义弃用提示信息
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
# 调用弃用函数,提示用户
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
# 根据缩放因子调整潜在表示
latents = 1 / self.vae.config.scaling_factor * latents
# 解码潜在表示,返回图像
image = self.vae.decode(latents, return_dict=False)[0]
# 将图像值缩放到[0, 1]范围内
image = (image / 2 + 0.5).clamp(0, 1)
# 转换图像为 float32 类型以确保兼容性
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# 返回最终图像
return image
# 准备额外的步骤参数
def prepare_extra_step_kwargs(self, generator, eta):
# 为调度器步骤准备额外的关键字参数,因不同调度器的参数签名不同
# eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
# eta 对应 DDIM 论文中的 η,应在 [0, 1] 之间
# 检查调度器是否接受 eta 参数
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
# 如果接受,则将 eta 添加到额外参数中
extra_step_kwargs["eta"] = eta
# 检查调度器是否接受 generator 参数
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
# 如果接受,则将 generator 添加到额外参数中
extra_step_kwargs["generator"] = generator
# 返回准备好的额外参数
return extra_step_kwargs
# 检查输入参数的有效性
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
# 方法体未提供,此处无进一步操作
# 准备潜在表示
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
# 根据批大小和图像尺寸定义潜在表示的形状
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
# 如果生成器列表长度与批大小不匹配,抛出错误
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# 如果未提供潜在表示,则随机生成
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# 将已提供的潜在表示转移到指定设备
latents = latents.to(device)
# 按调度器要求的标准差缩放初始噪声
latents = latents * self.scheduler.init_noise_sigma
# 返回准备好的潜在表示
return latents
# 从 latent_consistency_models 获取指导尺度嵌入的方法复制
# 定义生成指导缩放嵌入的函数,接受张量 w 和其他参数
def get_guidance_scale_embedding(
# 输入参数 w,为一维的张量
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
"""
参考链接,提供生成嵌入向量的信息
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
用指定的指导缩放生成嵌入向量,以此丰富时间步嵌入。
embedding_dim (`int`, *optional*, defaults to 512):
要生成的嵌入的维度。
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
生成的嵌入的数据类型。
Returns:
`torch.Tensor`: 嵌入向量,形状为 `(len(w), embedding_dim)`。
"""
# 确保输入张量 w 是一维的
assert len(w.shape) == 1
# 将 w 的值乘以 1000.0
w = w * 1000.0
# 计算嵌入的半维度
half_dim = embedding_dim // 2
# 计算常量 emb,用于后续的指数计算
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
# 生成半维度的指数衰减嵌入
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
# 将 w 转换为指定数据类型,并与 emb 进行广播相乘
emb = w.to(dtype)[:, None] * emb[None, :]
# 将正弦和余弦嵌入拼接在一起
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# 如果嵌入维度是奇数,则在最后添加零填充
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
# 确保最终嵌入的形状符合预期
assert emb.shape == (w.shape[0], embedding_dim)
# 返回生成的嵌入
return emb
# 定义属性,返回指导缩放值
@property
def guidance_scale(self):
return self._guidance_scale
# 定义属性,返回指导重缩放值
@property
def guidance_rescale(self):
return self._guidance_rescale
# 定义属性,返回跨注意力的值
@property
def clip_skip(self):
return self._clip_skip
# 定义属性,判断是否进行无分类器引导
# 这里 `guidance_scale` 是类似于 Imagen 论文中方程 (2) 的指导权重 `w`
# 当 `guidance_scale = 1` 时,相当于不进行分类器无引导
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
# 定义属性,返回跨注意力的关键字参数
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
# 定义属性,返回时间步数
@property
def num_timesteps(self):
return self._num_timesteps
# 定义属性,返回中断标志
@property
def interrupt(self):
return self._interrupt
# 指定在此上下文中不计算梯度
@torch.no_grad()
# 替换示例文档字符串
@replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义一个可调用的类方法
def __call__(
# 输入提示,可以是字符串或字符串列表,默认为 None
self,
prompt: Union[str, List[str]] = None,
# 输出图像的高度,默认为 None
height: Optional[int] = None,
# 输出图像的宽度,默认为 None
width: Optional[int] = None,
# 进行推理的步骤数,默认为 50
num_inference_steps: int = 50,
# 指定的时间步列表,默认为 None
timesteps: List[int] = None,
# 指定的 sigma 值列表,默认为 None
sigmas: List[float] = None,
# 指导尺度,默认为 7.5
guidance_scale: float = 7.5,
# 负面提示,可以是字符串或字符串列表,默认为 None
negative_prompt: Optional[Union[str, List[str]]] = None,
# 每个提示生成的图像数量,默认为 1
num_images_per_prompt: Optional[int] = 1,
# 控制生成过程中的随机性,默认为 0.0
eta: float = 0.0,
# 随机数生成器,可以是单个或多个 PyTorch 生成器,默认为 None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
# 潜在变量张量,默认为 None
latents: Optional[torch.Tensor] = None,
# 提示嵌入的张量,默认为 None
prompt_embeds: Optional[torch.Tensor] = None,
# 负面提示嵌入的张量,默认为 None
negative_prompt_embeds: Optional[torch.Tensor] = None,
# 输入适配器图像,默认为 None
ip_adapter_image: Optional[PipelineImageInput] = None,
# 输入适配器图像的嵌入列表,默认为 None
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
# 输出类型,默认为 "pil"(Python Imaging Library)
output_type: Optional[str] = "pil",
# 是否返回字典,默认为 True
return_dict: bool = True,
# 交叉注意力的额外参数,默认为 None
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 指导重缩放因子,默认为 0.0
guidance_rescale: float = 0.0,
# 跳过的剪辑层,默认为 None
clip_skip: Optional[int] = None,
# 步骤结束时的回调函数,可以是多种类型,默认为 None
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
# 步骤结束时的张量输入列表,默认为 ["latents"]
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# 额外的关键字参数
**kwargs,


浙公网安备 33010602011771号