diffusers Source Code Analysis (44)
.\diffusers\pipelines\stable_cascade\__init__.py
# Import the constant used for static type checking
from typing import TYPE_CHECKING

# Import utility functions and constants from the parent package
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # flag that forces eager (slow) imports
    OptionalDependencyNotAvailable,  # exception raised when an optional dependency is missing
    _LazyModule,  # helper for lazy module loading
    get_objects_from_module,  # collects the public objects of a module
    is_torch_available,  # checks whether PyTorch is installed
    is_transformers_available,  # checks whether Transformers is installed
)

# Dictionary holding dummy placeholder objects
_dummy_objects = {}
# Dictionary describing the lazy import structure
_import_structure = {}

# Probe whether the required dependencies are available
try:
    # Raise if either Transformers or PyTorch is unavailable
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# Catch the missing-optional-dependency exception
except OptionalDependencyNotAvailable:
    # Fall back to the dummy placeholder objects
    from ...utils import dummy_torch_and_transformers_objects

    # Register the placeholders
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # Dependencies are available: register the real pipelines in the import structure
    _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
    _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
    _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]

# During type checking, or when slow imports are forced, import everything eagerly
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # Raise if either Transformers or PyTorch is unavailable
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Import the placeholder objects so attribute access does not fail
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        # Import the real pipeline classes
        from .pipeline_stable_cascade import StableCascadeDecoderPipeline
        from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
        from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
else:
    # Otherwise install a lazy module so submodules are only imported on first access
    import sys

    # Replace the current module with a lazily loading proxy
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Attach the dummy placeholder objects to the module
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
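# --- Illustrative sketch (not part of the file above): what the lazy module buys.
# Importing the subpackage is cheap because the pipeline modules are not loaded;
# only the first attribute access triggers the real import. Assumes diffusers,
# torch, and transformers are installed in the environment.
import diffusers.pipelines.stable_cascade as stable_cascade  # fast: lazy proxy only

pipe_cls = stable_cascade.StableCascadeDecoderPipeline  # first access imports pipeline_stable_cascade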
.\diffusers\pipelines\stable_diffusion\clip_image_project_model.py
# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import neural-network building blocks from PyTorch
from torch import nn

# Import ConfigMixin and the config-registration decorator
from ...configuration_utils import ConfigMixin, register_to_config
# Import ModelMixin for model-related functionality
from ...models.modeling_utils import ModelMixin


# A CLIP image projection module, inheriting from ModelMixin and ConfigMixin
class CLIPImageProjection(ModelMixin, ConfigMixin):
    # Register the constructor arguments in the model config
    @register_to_config
    def __init__(self, hidden_size: int = 768):
        # Initialize the parent classes
        super().__init__()
        # Hidden size of the projection (default 768)
        self.hidden_size = hidden_size
        # Linear projection with equal input/output width and no bias
        self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    # Forward pass: project the input and return the result
    def forward(self, x):
        return self.project(x)
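# --- Illustrative sketch (not part of the file above): the module is a plain
# bias-free linear map over the last dimension. The batch size and the use of
# pooled CLIP image embeddings as input are assumptions.
import torch

proj = CLIPImageProjection(hidden_size=768)
image_embeds = torch.randn(2, 768)  # e.g. pooled CLIP image embeddings
assert proj(image_embeds).shape == (2, 768)  # the shape is preserved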
.\diffusers\pipelines\stable_diffusion\convert_from_ckpt.py
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""

# Regular expressions
import re
# Null context manager (used when accelerate is unavailable)
from contextlib import nullcontext
# In-memory byte streams (for configs fetched over HTTP)
from io import BytesIO
# Type annotations
from typing import Dict, Optional, Union

# HTTP requests
import requests
# PyTorch
import torch
# YAML parsing
import yaml

# Classes and helpers from transformers
from transformers import (
    AutoFeatureExtractor,  # automatic feature extractor
    BertTokenizerFast,  # fast BERT tokenizer
    CLIPImageProcessor,  # CLIP image processor
    CLIPTextConfig,  # CLIP text config
    CLIPTextModel,  # CLIP text model
    CLIPTextModelWithProjection,  # CLIP text model with projection head
    CLIPTokenizer,  # CLIP tokenizer
    CLIPVisionConfig,  # CLIP vision config
    CLIPVisionModelWithProjection,  # CLIP vision model with projection head
)

# Local model classes
from ...models import (
    AutoencoderKL,  # KL autoencoder (VAE)
    ControlNetModel,  # ControlNet model
    PriorTransformer,  # prior transformer
    UNet2DConditionModel,  # 2D conditional U-Net
)
# Scheduler classes
from ...schedulers import (
    DDIMScheduler,  # DDIM scheduler
    DDPMScheduler,  # DDPM scheduler
    DPMSolverMultistepScheduler,  # multistep DPM-Solver scheduler
    EulerAncestralDiscreteScheduler,  # Euler ancestral discrete scheduler
    EulerDiscreteScheduler,  # Euler discrete scheduler
    HeunDiscreteScheduler,  # Heun discrete scheduler
    LMSDiscreteScheduler,  # LMS discrete scheduler
    PNDMScheduler,  # PNDM scheduler
    UnCLIPScheduler,  # UnCLIP scheduler
)
# Utility helpers
from ...utils import is_accelerate_available, logging
# Latent-diffusion BERT pieces
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
# Paint-by-example image encoder
from ..paint_by_example import PaintByExampleImageEncoder
# Pipeline base class
from ..pipeline_utils import DiffusionPipeline
# Safety checker
from .safety_checker import StableDiffusionSafetyChecker
# Image normalizer for stable unCLIP
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

# If accelerate is available, import the related helpers
if is_accelerate_available():
    # Initialize modules with empty (meta-device) weights
    from accelerate import init_empty_weights
    # Assign a tensor to a module parameter on a given device
    from accelerate.utils import set_module_tensor_to_device

# Module logger
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Remove leading or trailing segments from a dotted path
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    # Non-negative: drop the first n_shave_prefix_segments segments
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        # Negative: drop the last |n_shave_prefix_segments| segments
        return ".".join(path.split(".")[:n_shave_prefix_segments])
# Update ResNet paths to the new naming scheme (local renaming)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # List of {"old": ..., "new": ...} mappings
    mapping = []
    for old_item in old_list:
        # Rename the LDM layer names to their diffusers equivalents
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")
        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")
        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        # Optionally shave leading segments
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        # Record the old -> new mapping
        mapping.append({"old": old_item, "new": new_item})

    return mapping
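# --- Illustrative sketch (not part of the file): a typical renaming produced by
# renew_resnet_paths. The input key is made up but representative.
example = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
# -> [{"old": "input_blocks.1.0.in_layers.0.weight",
#      "new": "input_blocks.1.0.norm1.weight"}]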
# Update VAE ResNet paths to the new naming scheme (local renaming)
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    # List of {"old": ..., "new": ...} mappings
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # Rename the VAE shortcut convolution
        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        # Optionally shave leading segments
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        # Record the old -> new mapping
        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Update attention paths to the new naming scheme (local renaming)
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    # List of {"old": ..., "new": ...} mappings
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # The renames below are intentionally disabled for UNet attention blocks:
        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')
        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')

        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        # Record the (unchanged) old -> new mapping
        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Update VAE attention paths to the new naming scheme (local renaming)
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    # List of {"old": ..., "new": ...} mappings
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # Rename the group norm and the q/k/v/out projections to diffusers names
        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")

        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")

        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        # Optionally shave leading segments
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        # Record the old -> new mapping
        mapping.append({"old": old_item, "new": new_item})

    return mapping
# Assign the converted weights to the new checkpoint, applying global renaming
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    # Make sure paths is a list of dicts with 'old' and 'new' keys
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Split fused attention layers into query, key, and value
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            # Fused tensor in the old checkpoint
            old_tensor = old_checkpoint[path]
            # The fused tensor stacks q, k, v, so each takes a third of the channels
            channels = old_tensor.shape[0] // 3

            # Target shape depends on whether the tensor is a 1D conv weight
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            # Number of attention heads
            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Reshape to (num_heads, 3 * channels / num_heads, ...) and split into q, k, v
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            # Store the reshaped q, k, v in the new checkpoint
            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # Skip paths that were already handled by the attention split above
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming of the middle block
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        # Apply any additional replacement rules
        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # Attention weights stored as 1x1 convs must be converted to linear weights
        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
        shape = old_checkpoint[path["old"]].shape
        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif is_attn_weight and len(shape) == 4:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
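# --- Illustrative sketch (not part of the file): how the fused qkv split above
# works. A fused projection of shape (3*C, ...) is reshaped to
# (num_heads, 3*C/num_heads, ...) and split along dim=1 into query, key, and
# value of C channels each. The sizes below are assumptions.
import torch

C, num_heads = 8, 2
fused = torch.randn(3 * C, C, 1)  # like an old 1D-conv "qkv" attention weight
reshaped = fused.reshape((num_heads, 3 * C // num_heads) + fused.shape[1:])
q, k, v = reshaped.split(C // num_heads, dim=1)
assert q.reshape(-1, C).shape == (C, C)  # per-projection weight after flattening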
# Convert conv-style attention weights in a checkpoint to linear weights
def conv_attn_to_linear(checkpoint):
    # All keys currently in the checkpoint
    keys = list(checkpoint.keys())
    # Weight names that belong to attention layers
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        # q/k/v weights: drop both trailing spatial dims of the 1x1 conv
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        # Output projection weights: drop the single trailing spatial dim
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]
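# --- Illustrative sketch (not part of the file): a 1x1 conv weight of shape
# (out, in, 1, 1) carries the same parameters as a linear weight of shape
# (out, in), so slicing off the spatial dims turns conv attention into linear
# attention. The key name and sizes below are assumptions.
import torch

ckpt = {"mid.attn_1.query.weight": torch.randn(8, 8, 1, 1)}
conv_attn_to_linear(ckpt)
assert ckpt["mid.attn_1.query.weight"].shape == (8, 8)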
# Create a UNet config for diffusers
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # For ControlNet, read the parameters from the control stage config
    if controlnet:
        unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
    else:
        # Otherwise prefer the UNet config if present, else the network config
        if (
            "unet_config" in original_config["model"]["params"]
            and original_config["model"]["params"]["unet_config"] is not None
        ):
            unet_params = original_config["model"]["params"]["unet_config"]["params"]
        else:
            unet_params = original_config["model"]["params"]["network_config"]["params"]

    # VAE parameters (needed for the scale factor below)
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

    # Output channels of each block
    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

    # Down block types, chosen by whether the current resolution uses attention
    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
        down_block_types.append(block_type)
        # Double the resolution after every block except the last
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # Up block types, walking the resolutions back down
    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    # Transformer depth may be a single int or a per-block list
    if unet_params["transformer_depth"] is not None:
        transformer_layers_per_block = (
            unet_params["transformer_depth"]
            if isinstance(unet_params["transformer_depth"], int)
            else list(unet_params["transformer_depth"])
        )
    else:
        # Default to one transformer layer per block
        transformer_layers_per_block = 1

    # Spatial downscaling factor of the VAE
    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    # Head dimension, if given
    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
    # Whether the transformer blocks use a linear input projection
    use_linear_projection = (
        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # Derive per-block head dims for stable diffusion 2 style models
        if head_dim is None:
            head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
            head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

    # Extra embedding types and dimensions, filled in below where applicable
    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    # Context (cross-attention) dimension, which may be an int or a list
    if unet_params["context_dim"] is not None:
        context_dim = (
            unet_params["context_dim"]
            if isinstance(unet_params["context_dim"], int)
            else unet_params["context_dim"][0]
        )

    if "num_classes" in unet_params:
        if unet_params["num_classes"] == "sequential":
            # SDXL-style models use a text-time additional embedding
            if context_dim in [2048, 1280]:
                addition_embed_type = "text_time"
                addition_time_embed_dim = 256
            else:
                # Other models use a projection class embedding
                class_embed_type = "projection"
                assert "adm_in_channels" in unet_params
                projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

    # Assemble the diffusers config
    config = {
        "sample_size": image_size // vae_scale_factor,  # latent sample size
        "in_channels": unet_params["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    # Optionally disable self-attention in favor of cross-attention only
    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params["disable_self_attentions"]

    # Integer num_classes means a plain class-embedding table
    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
        config["num_class_embeds"] = unet_params["num_classes"]

    if controlnet:
        # ControlNet needs the conditioning channel count
        config["conditioning_channels"] = unet_params["hint_channels"]
    else:
        # Regular UNets need output channels and up block types
        config["out_channels"] = unet_params["out_channels"]
        config["up_block_types"] = tuple(up_block_types)

    return config
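# --- Illustrative sketch (not part of the file): the block layout this function
# derives from typical SD v1 UNet params. The values below are the well-known
# v1 settings; treat them as assumptions about the YAML contents.
model_channels, channel_mult = 320, [1, 2, 4, 4]
attention_resolutions = [4, 2, 1]

block_out_channels = [model_channels * m for m in channel_mult]  # [320, 640, 1280, 1280]
down_block_types, resolution = [], 1
for i in range(len(block_out_channels)):
    down_block_types.append(
        "CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D"
    )
    if i != len(block_out_channels) - 1:
        resolution *= 2
# -> ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]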
# Create a VAE config for diffusers based on the config of the LDM model
def create_vae_diffusers_config(original_config, image_size: int):
    # Extract the VAE parameters from the original config
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # The embedding dim is read but unused
    _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]

    # Output channels of each block
    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
    # One encoder down block and one decoder up block per channel multiplier
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    # Assemble the diffusers VAE config
    config = {
        "sample_size": image_size,  # sample size
        "in_channels": vae_params["in_channels"],  # input channels
        "out_channels": vae_params["out_ch"],  # output channels
        "down_block_types": tuple(down_block_types),  # down block types
        "up_block_types": tuple(up_block_types),  # up block types
        "block_out_channels": tuple(block_out_channels),  # block output channels
        "latent_channels": vae_params["z_channels"],  # latent channels
        "layers_per_block": vae_params["num_res_blocks"],  # layers per block
    }
    return config
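# --- Illustrative sketch (not part of the file): VAE channel layout for the
# standard SD v1 autoencoder params (ch=128, ch_mult=[1, 2, 4, 4]; assumptions
# about the YAML contents). The same ch_mult also drives the vae_scale_factor
# computed in create_unet_diffusers_config above.
ch, ch_mult = 128, [1, 2, 4, 4]
block_out_channels = [ch * m for m in ch_mult]  # [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(ch_mult) - 1)      # 8: a 512px image -> 64px latents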
# Create a scheduler from the original config (the upstream function name keeps
# its "schedular" spelling)
def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config["model"]["params"]["timesteps"],  # number of training timesteps
        beta_start=original_config["model"]["params"]["linear_start"],  # beta at the start of the schedule
        beta_end=original_config["model"]["params"]["linear_end"],  # beta at the end of the schedule
        beta_schedule="scaled_linear",  # beta schedule type
    )
    return schedular


# Create the LDM BERT config
def create_ldm_bert_config(original_config):
    # Extract the BERT parameters from the original config
    bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
    config = LDMBertConfig(
        d_model=bert_params.n_embed,  # embedding dimension
        encoder_layers=bert_params.n_layer,  # number of encoder layers
        encoder_ffn_dim=bert_params.n_embed * 4,  # feed-forward dimension
    )
    return config
# Convert an LDM UNet checkpoint
def convert_ldm_unet_checkpoint(
    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    if skip_extract_state_dict:
        # The checkpoint already is the UNet state dict
        unet_state_dict = checkpoint
    else:
        # Extract the UNet's state dict from the full checkpoint
        unet_state_dict = {}
        keys = list(checkpoint.keys())

        if controlnet:
            unet_key = "control_model."  # key prefix used by ControlNet checkpoints
        else:
            unet_key = "model.diffusion_model."  # default UNet key prefix

        # Heuristic: more than 100 `model_ema` parameters means EMA weights are present
        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
            logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
            logger.warning(
                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
            )
            for key in keys:
                if key.startswith("model.diffusion_model"):
                    # Build the flattened EMA key and move that weight into the UNet state dict
                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
        else:
            # EMA weights exist but were not requested: warn, then extract the non-EMA weights
            if sum(k.startswith("model_ema") for k in keys) > 100:
                logger.warning(
                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
                )
            for key in keys:
                if key.startswith(unet_key):
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    # The converted checkpoint
    new_checkpoint = {}

    # Time embedding weights and biases
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        # Class embedding weights and biases
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # Additional text-time embedding (SDXL-style models)
    if config["addition_embed_type"] == "text_time":
        new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

    # Relevant to StableDiffusionUpscalePipeline
    if "num_class_embeds" in config:
        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

    # Input convolution
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        # Output norm and convolution (ControlNet has no output head)
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieve the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieve the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieve the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }
    # Walk the input blocks, skipping block 0 (the input convolution handled above)
    for i in range(1, num_input_blocks):
        # Map the flat LDM block index onto (down block, layer in block)
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # ResNet keys of this input block, excluding the downsampler op
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        # Attention keys of this input block
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # Move the downsampler convolution, if this block has one
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        # Rename and assign the ResNet weights
        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        # Rename and assign the attention weights, if present
        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: resnet, attention, resnet
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    # Rename and assign the first middle ResNet
    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    # Rename and assign the second middle ResNet
    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    # Rename and assign the middle attention block
    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )
    # Walk the output blocks
    for i in range(num_output_blocks):
        # Map the flat LDM block index onto (up block, layer in block)
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        # Shave the "output_blocks.N" prefix off every layer name
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group layer names by their sub-block id
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            # ResNet and attention keys of this output block
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # Rename and assign the ResNet weights
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # Move the upsampler convolution, if this block has one
            output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above
                if len(attentions) == 2:
                    attentions = []

            # Rename and assign the attention weights, if present
            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single-entry block: rename and copy the ResNet paths directly
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]
    if controlnet:
        # Running index into the original input_hint_block layers
        orig_index = 0

        # Input convolution of the conditioning embedding
        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # Skip the activation between layers
        orig_index += 2

        # Six intermediate conditioning-embedding blocks
        diffusers_index = 0
        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            # Advance both indices (skipping the activation in the original layout)
            diffusers_index += 1
            orig_index += 2

        # Output convolution of the conditioning embedding
        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # Zero convolutions of the down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # Zero convolution of the middle block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
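# --- Illustrative sketch (not part of the file): how the flat LDM "input_blocks"
# index used above maps onto diffusers (down_block, layer) coordinates for
# layers_per_block=2 (the SD v1 value; an assumption about the config).
layers_per_block = 2
coords = [((i - 1) // (layers_per_block + 1), (i - 1) % (layers_per_block + 1)) for i in range(1, 7)]
# -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
# index 2 within a block is the downsampler slot rather than a third resnet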
# Convert a VAE checkpoint to the diffusers format
def convert_ldm_vae_checkpoint(checkpoint, config):
    # Extract the VAE's state dict from the full checkpoint
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    # Standalone VAE checkpoints may lack the "first_stage_model." prefix
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    # The converted checkpoint
    new_checkpoint = {}

    # Encoder convolutions and output norm
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Decoder convolutions and output norm
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    # Quantization convolutions
    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieve the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieve the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }
    # Encoder down blocks
    for i in range(num_down_blocks):
        # ResNet keys of this down block, excluding the downsampler
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        # Move the downsampler convolution, if this block has one
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        # Rename and assign the ResNet weights
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder middle ResNets (there are always two)
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        # Rename and assign the ResNet weights
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder middle attention
    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Convert the conv-style attention weights to linear weights
    conv_attn_to_linear(new_checkpoint)
    # Decoder up blocks (stored in reverse order in the original checkpoint)
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        # ResNet keys of this up block, excluding the upsampler
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        # Move the upsampler convolution, if this block has one
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        # Rename and assign the ResNet weights
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder middle ResNets (there are always two)
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        # Rename and assign the ResNet weights
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder middle attention
    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Convert the conv-style attention weights to linear weights
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
# Convert an LDM BERT checkpoint into a Hugging Face model
def convert_ldm_bert_checkpoint(checkpoint, config):
    # Copy the attention layer weights and biases
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        # q, k, v projection weights
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        # Output projection weight and bias
        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    # Copy a linear layer's weight and bias
    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    # Copy one full transformer layer
    def _copy_layer(hf_layer, pt_layer):
        # Layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # Attention
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # MLP: first and second linear layers
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    # Copy all transformer layers
    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            # Each HF layer spans two entries in the PT layer list, so double the index
            if i != 0:
                i += i
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    # Instantiate the LDM BERT model in eval mode
    hf_model = LDMBertModel(config).eval()

    # Token and position embeddings
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # Final layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # Hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    # Output projection
    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model
# Convert an LDM CLIP text-encoder checkpoint into a Hugging Face model
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    if text_encoder is None:
        # Default CLIP config
        config_name = "openai/clip-vit-large-patch14"
        try:
            # Load the config from the pretrained model
            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
            )

        # Initialize on the meta device when accelerate is available
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            text_model = CLIPTextModel(config)
    else:
        # Use the provided text encoder
        text_model = text_encoder

    keys = list(checkpoint.keys())

    # State dict for the text model
    text_model_dict = {}

    # Prefixes used by the different checkpoint layouts
    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                # Store the weight under the key with the prefix stripped
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    if is_accelerate_available():
        # Materialize each parameter tensor on the CPU
        for param_name, param in text_model_dict.items():
            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
    else:
        # Drop position_ids if the model does not have that buffer
        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
            text_model_dict.pop("text_model.embeddings.position_ids", None)

        text_model.load_state_dict(text_model_dict)

    return text_model
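# --- Illustrative sketch (not part of the file): how the text-encoder prefix is
# stripped above. The key is made up but representative of an LDM checkpoint key.
key = "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight"
prefix = "cond_stage_model.transformer"
assert key[len(prefix + ".") :] == "text_model.encoder.layers.0.mlp.fc1.weight"
# the stripped key matches CLIPTextModel's own state-dict layout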
# (source name, target name) pairs for the top-level text-encoder keys
textenc_conversion_lst = [
    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("ln_final.weight", "text_model.final_layer_norm.weight"),
    ("ln_final.bias", "text_model.final_layer_norm.bias"),
    ("text_projection", "text_projection.weight"),
]
# Mapping from source name to target name
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

# Name-fragment pairs for the transformer sub-keys
textenc_transformer_conversion_lst = [
    # (stable diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
# Regex-escaped mapping, plus a pattern matching any of the source names
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
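# --- Illustrative sketch (not part of the file): applying the protected-pattern
# substitution to a sample open-CLIP key (the key itself is made up).
sample = "resblocks.9.ln_1.weight"
converted = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], sample)
assert converted == "text_model.encoder.layers.9.layer_norm1.weight"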
# Convert a paint-by-example checkpoint
def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
    # Load the CLIP vision config and build the image encoder
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    # State dict for the CLIP vision model
    text_model_dict = {}

    # Collect the transformer weights, stripping the checkpoint prefix
    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # Load the CLIP vision model
    model.model.load_state_dict(text_model_dict)

    # Load the mapper: collect its weights, stripping the prefix
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    # Mapping from original layer names to diffusers ones; fused layers map
    # onto several target layers
    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        # Split the key into block prefix, weight suffix, and layer name
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        # Fused weights are sliced evenly across their target layers
        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    # Load the mapper weights
    model.mapper.load_state_dict(mapped_weights)

    # Load the final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # Load the final projection
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # Load the unconditional (learnable) vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model
# Convert an OpenCLIP text-encoder checkpoint
def convert_open_clip_checkpoint(
    checkpoint,  # checkpoint state dict
    config_name,  # name of the CLIP text config to load
    prefix="cond_stage_model.model.",  # key prefix used in the checkpoint
    has_projection=False,  # whether to build a model with a projection head
    local_files_only=False,  # whether to load configs from local files only
    **config_kwargs,  # extra config overrides
):
    # Load the CLIP text config; raise a helpful error if it is unavailable
    try:
        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
    except Exception:
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
        )

    # Initialize on the meta device when accelerate is available
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        # Build a model with or without the projection head
        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

    keys = list(checkpoint.keys())

    # Keys to skip during conversion
    keys_to_ignore = []
    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
        # Make sure to remove all keys > 22, plus the text projection
        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
        keys_to_ignore += ["cond_stage_model.model.text_projection"]

    # State dict for the text model
    text_model_dict = {}

    # Infer d_model from the text projection if present, else default to 1024
    if prefix + "text_projection" in checkpoint:
        d_model = int(checkpoint[prefix + "text_projection"].shape[0])
    else:
        d_model = 1024

    # Position ids come from the freshly initialized model
    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in keys:
        # Skip ignored keys
        if key in keys_to_ignore:
            continue

        # Top-level keys handled by the conversion map
        if key[len(prefix) :] in textenc_conversion_map:
            if key.endswith("text_projection"):
                # The projection matrix must be transposed (and made contiguous)
                value = checkpoint[key].T.contiguous()
            else:
                value = checkpoint[key]

            text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value

        # Transformer keys
        if key.startswith(prefix + "transformer."):
            new_key = key[len(prefix + "transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                # Split the fused q/k/v projection weight into three matrices
                new_key = new_key[: -len(".in_proj_weight")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                # Split the fused q/k/v projection bias into three vectors
                new_key = new_key[: -len(".in_proj_bias")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                # Plain rename via the regex substitution
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)

                text_model_dict[new_key] = checkpoint[key]

    if is_accelerate_available():
        # Materialize each parameter tensor on the CPU
        for param_name, param in text_model_dict.items():
            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
    else:
        # Drop position_ids if the model does not have that buffer
        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
            text_model_dict.pop("text_model.embeddings.position_ids", None)

        text_model.load_state_dict(text_model_dict)

    return text_model
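# --- Illustrative sketch (not part of the file): open-CLIP stores q, k, v as one
# fused "in_proj_weight" of shape (3*d_model, d_model); the converter slices it
# into the three separate projections transformers expects. Sizes are assumptions.
import torch

d_model = 4
in_proj_weight = torch.randn(3 * d_model, d_model)
q_proj = in_proj_weight[:d_model, :]
k_proj = in_proj_weight[d_model : d_model * 2, :]
v_proj = in_proj_weight[d_model * 2 :, :]
assert q_proj.shape == k_proj.shape == v_proj.shape == (d_model, d_model)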
# Return the image processor and CLIP image encoder for the img2img unclip pipeline
def stable_unclip_image_encoder(original_config, local_files_only=False):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """

    # Embedder config from the original config
    image_embedder_config = original_config["model"]["params"]["embedder_config"]

    # Keep only the class-name part of the target path
    sd_clip_image_embedder_class = image_embedder_config["target"]
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        # ViT-L/14 maps onto the OpenAI CLIP checkpoint
        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-large-patch14", local_files_only=local_files_only
            )
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        # OpenCLIP maps onto the LAION checkpoint
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only
        )
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder
# Return the noising components for the img2img and txt2img unclip pipelines
def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    # Noise augmentor config from the original config
    noise_aug_config = original_config["model"]["params"]["noise_aug_config"]

    # Keep only the class-name part of the target path
    noise_aug_class = noise_aug_config["target"]
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        # Embedding dimension, maximum noise level, and beta schedule
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        # Image normalizer holding the CLIP stats, and the DDPM noise scheduler
        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            # Load the CLIP mean and std and add a leading batch dimension
            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }

            # Load the CLIP stats into the image normalizer
            image_normalizer.load_state_dict(clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler
# Convert a ControlNet checkpoint
def convert_controlnet_checkpoint(
    checkpoint,  # checkpoint state dict
    original_config,  # original LDM config
    checkpoint_path,  # path the checkpoint was loaded from
    image_size,  # image size used to derive the sample size
    upcast_attention,  # whether to upcast attention to float32
    extract_ema,  # whether to extract the EMA weights
    use_linear_projection=None,  # optional linear-projection override
    cross_attention_dim=None,  # optional cross-attention dimension override
):
    # Build the UNet config in ControlNet mode
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    # ControlNet takes no sample_size argument
    ctrlnet_config.pop("sample_size")

    # Apply the optional overrides
    if use_linear_projection is not None:
        ctrlnet_config["use_linear_projection"] = use_linear_projection

    if cross_attention_dim is not None:
        ctrlnet_config["cross_attention_dim"] = cross_attention_dim

    # Initialize on the meta device when accelerate is available
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        controlnet = ControlNetModel(**ctrlnet_config)

    # Some ControlNet checkpoints are distributed standalone, already stripped of
    # the model prefix; detect that and skip the extraction step
    if "time_embed.0.weight" in checkpoint:
        skip_extract_state_dict = True
    else:
        skip_extract_state_dict = False

    # Convert the weights with the UNet converter in ControlNet mode
    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint,
        ctrlnet_config,
        path=checkpoint_path,
        extract_ema=extract_ema,
        controlnet=True,
        skip_extract_state_dict=skip_extract_state_dict,
    )

    if is_accelerate_available():
        # Materialize each parameter tensor on the CPU
        for param_name, param in converted_ctrl_checkpoint.items():
            set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
    else:
        # Load the state dict directly
        controlnet.load_state_dict(converted_ctrl_checkpoint)

    return controlnet
# Load a pipeline from an original Stable Diffusion checkpoint
def download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],  # checkpoint path or state dict
    original_config_file: str = None,  # path to the original YAML config
    image_size: Optional[int] = None,  # image size
    prediction_type: str = None,  # prediction type
    model_type: str = None,  # model type
    extract_ema: bool = False,  # whether to extract the EMA weights
    scheduler_type: str = "pndm",  # scheduler type
    num_in_channels: Optional[int] = None,  # number of input channels
    upcast_attention: Optional[bool] = None,  # whether to upcast attention
    device: str = None,  # device to load the checkpoint on
    from_safetensors: bool = False,  # whether to load from safetensors
    stable_unclip: Optional[str] = None,  # optional stable unclip variant
    stable_unclip_prior: Optional[str] = None,  # optional stable unclip prior variant
    clip_stats_path: Optional[str] = None,  # optional path to the CLIP stats
    controlnet: Optional[bool] = None,  # whether this is a ControlNet checkpoint
    adapter: Optional[bool] = None,  # whether this is an adapter checkpoint
    load_safety_checker: bool = True,  # whether to load the safety checker
    safety_checker: Optional[StableDiffusionSafetyChecker] = None,  # optional safety checker instance
    feature_extractor: Optional[AutoFeatureExtractor] = None,  # optional feature extractor instance
    pipeline_class: DiffusionPipeline = None,  # optional pipeline class
    local_files_only=False,  # whether to load from local files only
    vae_path=None,  # optional VAE path
    vae=None,  # optional VAE instance
    text_encoder=None,  # optional text encoder
    text_encoder_2=None,  # optional second text encoder
    tokenizer=None,  # optional tokenizer
    tokenizer_2=None,  # optional second tokenizer
    config_files=None,  # optional mapping of config files
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.
    """
    # Import pipelines here to avoid circular import errors when using from_single_file
    from diffusers import (
        LDMTextToImagePipeline,  # LDM text-to-image pipeline
        PaintByExamplePipeline,  # paint-by-example pipeline
        StableDiffusionControlNetPipeline,  # Stable Diffusion ControlNet pipeline
        StableDiffusionInpaintPipeline,  # Stable Diffusion inpainting pipeline
        StableDiffusionPipeline,  # standard Stable Diffusion pipeline
        StableDiffusionUpscalePipeline,  # Stable Diffusion upscaling pipeline
        StableDiffusionXLControlNetInpaintPipeline,  # SDXL ControlNet inpainting pipeline
        StableDiffusionXLImg2ImgPipeline,  # SDXL image-to-image pipeline
        StableDiffusionXLInpaintPipeline,  # SDXL inpainting pipeline
        StableDiffusionXLPipeline,  # SDXL pipeline
        StableUnCLIPImg2ImgPipeline,  # stable unCLIP image-to-image pipeline
        StableUnCLIPPipeline,  # stable unCLIP pipeline
    )

    # Normalize the prediction type spelling
    if prediction_type == "v-prediction":
        prediction_type = "v_prediction"

    if isinstance(checkpoint_path_or_dict, str):
        if from_safetensors:
            # Load the checkpoint via safetensors onto the CPU
            from safetensors.torch import load_file as safe_load

            checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
        else:
            if device is None:
                # Default to CUDA when available, else CPU
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # Load the checkpoint onto the chosen device
                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
            else:
                # Load the checkpoint onto the given device
                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
    elif isinstance(checkpoint_path_or_dict, dict):
        # A state dict was passed in directly
        checkpoint = checkpoint_path_or_dict

    # Sometimes models don't have the global_step item
    if "global_step" in checkpoint:
        global_step = checkpoint["global_step"]
    else:
        logger.debug("global_step key not found in model")
        global_step = None

    # NOTE: this while loop isn't great, but this controlnet checkpoint has one
    # additional "state_dict" key: https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
# if no original config file was supplied, infer one from the checkpoint keys
if original_config_file is None:
# key that identifies v2.1-style models
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
# key that identifies the SD XL base model
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
# key that identifies the SD XL refiner model
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
# whether this is the upscaling pipeline
is_upscale = pipeline_class == StableDiffusionUpscalePipeline
# URL of the config to download, if any
config_url = None
# model_type = "v1"
# use the local v1 config file when one was provided
if config_files is not None and "v1" in config_files:
original_config_file = config_files["v1"]
else:
# otherwise fall back to the v1 inference YAML hosted upstream
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
# v2.1-style models carry this key with a last dimension of 1024
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2"
# use the local v2 config file when one was provided
if config_files is not None and "v2" in config_files:
original_config_file = config_files["v2"]
else:
# otherwise fall back to the v2 inference YAML hosted upstream
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
# a global_step of 110000 marks a v2.1 checkpoint
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
# the SD XL base key is only present in base XL checkpoints
elif key_name_sd_xl_base in checkpoint:
# only the base XL model has two text embedders
# use the local xl config file when one was provided
if config_files is not None and "xl" in config_files:
original_config_file = config_files["xl"]
else:
# otherwise fall back to the XL base YAML hosted upstream
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
# the SD XL refiner key is only present in refiner checkpoints
elif key_name_sd_xl_refiner in checkpoint:
# only the refiner XL model has an embedder and a single text embedder
# use the local xl_refiner config file when one was provided
if config_files is not None and "xl_refiner" in config_files:
original_config_file = config_files["xl_refiner"]
else:
# otherwise fall back to the XL refiner YAML hosted upstream
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
# the upscaling pipeline uses its own config
if is_upscale:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
# download the config when a URL was selected
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
else:
# otherwise read the local config file
with open(original_config_file, "r") as f:
original_config_file = f.read()
else:
# an original config file was supplied, so just read it
with open(original_config_file, "r") as f:
original_config_file = f.read()
# parse the YAML config
original_config = yaml.safe_load(original_config_file)
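# (At this point `original_config` is just a nested dict mirroring the YAML layout; for example, the UNet
# input-channel count written below lives at original_config["model"]["params"]["unet_config"]["params"]["in_channels"].)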
# Convert the text model.
if (
model_type is None
and "cond_stage_config" in original_config["model"]["params"]
and original_config["model"]["params"]["cond_stage_config"] is not None
):
# infer the model type from the target class name in the config
model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
# log the inferred model type
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
# no cond_stage_config, but a network_config: this is an SDXL-family model
elif model_type is None and original_config["model"]["params"]["network_config"] is not None:
# a context dimension of 2048 identifies SDXL; otherwise it is the refiner
if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048:
model_type = "SDXL"  # base SDXL model
else:
model_type = "SDXL-Refiner"  # SDXL refiner model
# SDXL models default to an image size of 1024
if image_size is None:
image_size = 1024
# pick a default pipeline class when none was given
if pipeline_class is None:
# check the current model type and initialize the default pipeline
if model_type not in ["SDXL", "SDXL-Refiner"]:
# choose between the plain and the ControlNet pipeline
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
else:
# choose the SDXL pipeline for the base model, the Img2Img pipeline for the refiner
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
# inpainting pipelines use 9 input channels (4 latent + 4 masked-image latent + 1 mask)
if num_in_channels is None and pipeline_class in [
StableDiffusionInpaintPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLControlNetInpaintPipeline,
]:
num_in_channels = 9  # 9 input channels for inpainting
# the upscaling pipeline uses 7 input channels
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7  # 7 input channels for upscaling
# everything else defaults to 4 input channels
elif num_in_channels is None:
num_in_channels = 4  # 4 input channels by default
# write the input-channel count back into the UNet config
if "unet_config" in original_config["model"]["params"]:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
# or into the network config, if that is what the model uses
elif "network_config" in original_config["model"]["params"]:
original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels
# v-parameterized models need special defaults
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
# if no prediction type was given, infer one
if prediction_type is None:
# NOTE: this relies on a brittle global-step check; 875000 identifies the 512 "base"
# checkpoint (epsilon), everything else is treated as a 768 v-prediction checkpoint
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
# if no image size was given, infer one
if image_size is None:
# NOTE: the same brittle global-step check picks the default image size
image_size = 512 if global_step == 875000 else 768
else:
# default to epsilon prediction
if prediction_type is None:
prediction_type = "epsilon"
# and a 512 image size
if image_size is None:
image_size = 512
# convert the ControlNet checkpoint when the config declares a control stage
if controlnet is None and "control_stage_config" in original_config["model"]["params"]:
# use the checkpoint path when one was given, otherwise an empty string
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
# convert the ControlNet checkpoint
controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, path, image_size, upcast_attention, extract_ema
)
# read the number of training timesteps from the config
if "timesteps" in original_config["model"]["params"]:
num_train_timesteps = original_config["model"]["params"]["timesteps"]
else:
# defaulting to 1000 when the config does not specify it
num_train_timesteps = 1000
# SDXL models always use an Euler scheduler with these settings
if model_type in ["SDXL", "SDXL-Refiner"]:
# scheduler parameters
scheduler_dict = {
"beta_schedule": "scaled_linear",  # scaled-linear beta schedule
"beta_start": 0.00085,  # starting beta value
"beta_end": 0.012,  # final beta value
"interpolation_type": "linear",  # linear interpolation
"num_train_timesteps": num_train_timesteps,  # training timesteps read above
"prediction_type": "epsilon",  # epsilon prediction
"sample_max_value": 1.0,  # maximum sample value
"set_alpha_to_one": False,  # do not force the final alpha to 1
"skip_prk_steps": True,  # skip the PRK steps
"steps_offset": 1,  # step offset
"timestep_spacing": "leading",  # leading timestep spacing
}
# build the scheduler from the config dict
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
# force the scheduler type to euler
scheduler_type = "euler"
else:
# for other model types, read beta_start from the config when present
if "linear_start" in original_config["model"]["params"]:
beta_start = original_config["model"]["params"]["linear_start"]
else:
# otherwise default beta_start to 0.02
beta_start = 0.02
# read beta_end from the config when present
if "linear_end" in original_config["model"]["params"]:
beta_end = original_config["model"]["params"]["linear_end"]
else:
# otherwise default beta_end to 0.085
beta_end = 0.085
# build a DDIM scheduler with these parameters
scheduler = DDIMScheduler(
beta_end=beta_end,  # final beta value
beta_schedule="scaled_linear",  # scaled-linear beta schedule
beta_start=beta_start,  # starting beta value
num_train_timesteps=num_train_timesteps,  # training timesteps
steps_offset=1,  # step offset
clip_sample=False,  # do not clip samples
set_alpha_to_one=False,  # do not force the final alpha to 1
prediction_type=prediction_type,  # prediction type chosen above
)
# make sure the scheduler works correctly with DDIM
scheduler.register_to_config(clip_sample=False)
# swap in the requested scheduler type, reusing the config built above
if scheduler_type == "pndm":
# build a PNDM scheduler from the existing scheduler config
config = dict(scheduler.config)
config["skip_prk_steps"] = True  # skip the PRK steps
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
# build an LMS scheduler from the config
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "heun":
# build a Heun scheduler from the config
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
# build an Euler scheduler from the config
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
# build an Euler Ancestral scheduler from the config
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
# build a DPM-Solver multistep scheduler from the config
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
# keep the DDIM scheduler built above
scheduler = scheduler
else:
# unknown scheduler type: raise an error
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
# the upscaling pipeline reads its image size directly from the UNet config
if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
# Convert the UNet2DConditionModel.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
# propagate the upcast_attention setting
unet_config["upcast_attention"] = upcast_attention
# use the checkpoint path when one was given, otherwise an empty string
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
# convert the LDM UNet checkpoint into diffusers format
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint,  # the raw checkpoint
unet_config,  # the UNet config
path=path,  # checkpoint path
extract_ema=extract_ema  # whether to extract the EMA weights
)
# instantiate with empty weights when accelerate is available
ctx = init_empty_weights if is_accelerate_available() else nullcontext
# build the UNet2DConditionModel inside the chosen context
with ctx():
unet = UNet2DConditionModel(**unet_config)
# with accelerate, copy the weights in tensor by tensor
if is_accelerate_available():
if model_type not in ["SDXL", "SDXL-Refiner"]:  # SBM Delay this.
# walk over the converted UNet checkpoint parameters
for param_name, param in converted_unet_checkpoint.items():
# place each parameter onto the CPU
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
else:
# without accelerate, load the state dict directly
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model.
if vae_path is None and vae is None:
# build the VAE config
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
# convert the LDM VAE checkpoint into diffusers format
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
# read the scale_factor from the config when present
if (
"model" in original_config
and "params" in original_config["model"]
and "scale_factor" in original_config["model"]["params"]
):
# take the VAE scaling factor from the config
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else:
# fall back to the default SD scaling factor
vae_scaling_factor = 0.18215  # default SD scaling factor
# write the scaling factor into the VAE config
vae_config["scaling_factor"] = vae_scaling_factor
# instantiate with empty weights when accelerate is available
ctx = init_empty_weights if is_accelerate_available() else nullcontext
# build the AutoencoderKL inside the chosen context
with ctx():
vae = AutoencoderKL(**vae_config)
# with accelerate, copy the weights in tensor by tensor
if is_accelerate_available():
# walk over the converted VAE checkpoint parameters
for param_name, param in converted_vae_checkpoint.items():
# place each parameter onto the CPU
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
else:
# without accelerate, load the state dict directly
vae.load_state_dict(converted_vae_checkpoint)
# a VAE path was given but no VAE instance: load it from the pretrained weights
elif vae is None:
vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only)
# PaintByExample uses a vision model instead of a text encoder
elif model_type == "PaintByExample":
# convert the PaintByExample checkpoint into a vision model
vision_model = convert_paint_by_example_checkpoint(checkpoint)
# try to load the CLIPTokenizer
try:
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", local_files_only=local_files_only
)
except Exception:
# the tokenizer must be available locally when local_files_only is set
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
# try to load the AutoFeatureExtractor
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
except Exception:
# the feature extractor must be available locally when local_files_only is set
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
)
# assemble the PaintByExamplePipeline
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
# FrozenCLIPEmbedder means a standard Stable Diffusion text encoder
elif model_type == "FrozenCLIPEmbedder":
# convert the LDM CLIP checkpoint into a text model
text_model = convert_ldm_clip_checkpoint(
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
)
# try to load the CLIP tokenizer
try:
tokenizer = (
# load the CLIP tokenizer from the pretrained weights when none was passed in
CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
if tokenizer is None
else tokenizer
)
# handle failures while loading the tokenizer
except Exception:
raise ValueError(
# the tokenizer must be available locally when local_files_only is set
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
# load the safety checker and feature extractor when requested
if load_safety_checker:
# load the Stable Diffusion safety checker from the pretrained weights
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
# load the feature extractor from the pretrained weights
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
# assemble the pipeline, with or without the ControlNet
if controlnet:
# build the pipeline with the ControlNet
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:
# build the pipeline without the ControlNet
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# every other model type falls back to the LDM BERT text encoder
else:
# build the LDM BERT config
text_config = create_ldm_bert_config(original_config)
# convert the LDM BERT checkpoint into a text model
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
# load the BERT tokenizer from the pretrained weights
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only)
# assemble the LDM text-to-image pipeline
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
# return the assembled pipeline
return pipe
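# A minimal usage sketch for the loader above (not part of the library source). It assumes the
# enclosing function is this module's `download_from_original_stable_diffusion_ckpt`, whose
# signature begins before this excerpt; the checkpoint path is hypothetical.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
"v1-5-pruned-emaonly.safetensors",  # hypothetical local checkpoint file
from_safetensors=True,  # the file above is a safetensors file
scheduler_type="euler",  # swap the default "pndm" scheduler for Euler
load_safety_checker=False,  # skip downloading the safety checker
)
pipe.save_pretrained("./sd15-diffusers")  # write out a diffusers-format folder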
# download and convert a ControlNet from an original checkpoint
def download_controlnet_from_original_ckpt(
# path to the checkpoint file
checkpoint_path: str,
# path to the original config file
original_config_file: str,
# image size, defaults to 512
image_size: int = 512,
# whether to extract the EMA weights, defaults to False
extract_ema: bool = False,
# number of input channels, defaults to None
num_in_channels: Optional[int] = None,
# whether to upcast attention, defaults to None
upcast_attention: Optional[bool] = None,
# device to load the checkpoint onto, defaults to None
device: str = None,
# whether to load from a safetensors file, defaults to False
from_safetensors: bool = False,
# whether to use a linear projection, defaults to None
use_linear_projection: Optional[bool] = None,
# cross-attention dimension, defaults to None
cross_attention_dim: Optional[bool] = None,
) -> DiffusionPipeline:
# load via safetensors when requested
if from_safetensors:
# import the safetensors reader
from safetensors import safe_open
# build up the checkpoint dict tensor by tensor
checkpoint = {}
# open the safetensors file with the PyTorch framework on the CPU
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
# iterate over every key in the file
for key in f.keys():
# fetch each tensor into the checkpoint dict
checkpoint[key] = f.get_tensor(key)
else:
# if no device was specified
if device is None:
# pick CUDA when available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the checkpoint onto the chosen device
checkpoint = torch.load(checkpoint_path, map_location=device)
else:
# load the checkpoint onto the requested device
checkpoint = torch.load(checkpoint_path, map_location=device)
# NOTE: this while loop unwraps the extra "state_dict" key that some ControlNet checkpoints use
while "state_dict" in checkpoint:
# unwrap the nested "state_dict"
checkpoint = checkpoint["state_dict"]
# open the original config file for reading
with open(original_config_file, "r") as f:
# read the file contents
original_config_file = f.read()
# parse the YAML config
original_config = yaml.safe_load(original_config_file)
# override the input-channel count when one was specified
if num_in_channels is not None:
# write the input-channel count into the UNet config
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
# a ControlNet config must declare a control stage
if "control_stage_config" not in original_config["model"]["params"]:
# raise an error when it is missing
raise ValueError("`control_stage_config` not present in original config")
# convert the checkpoint into a ControlNet model
controlnet = convert_controlnet_checkpoint(
checkpoint,
original_config,
checkpoint_path,
image_size,
upcast_attention,
extract_ema,
use_linear_projection=use_linear_projection,
cross_attention_dim=cross_attention_dim,
)
# return the converted ControlNet
return controlnet
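# A minimal usage sketch for the function above (not part of the library source; the file paths
# are hypothetical, and the import path matches the module documented here).
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_controlnet_from_original_ckpt,
)

controlnet = download_controlnet_from_original_ckpt(
checkpoint_path="control_sd15_canny.pth",  # hypothetical original ControlNet checkpoint
original_config_file="cldm_v15.yaml",  # the matching ControlNet YAML config
extract_ema=False,  # use the non-EMA weights
)
controlnet.save_pretrained("./controlnet-canny")  # save in diffusers format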
.\diffusers\pipelines\stable_diffusion\pipeline_flax_stable_diffusion.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings  # warnings module, used to surface warnings to the user
from functools import partial  # partial, for partial function application
from typing import Dict, List, Optional, Union  # type-hinting helpers
import jax  # jax, for efficient numerical computation
import jax.numpy as jnp  # jax's numpy module
import numpy as np  # numpy
from flax.core.frozen_dict import FrozenDict  # FrozenDict, an immutable dict
from flax.jax_utils import unreplicate  # unreplicate, gathers data from multiple devices back onto one
from flax.training.common_utils import shard  # shard, splits data across devices
from packaging import version  # version, for parsing version numbers
from PIL import Image  # Image class, for image handling
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # CLIP-related classes from Transformers
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel  # relative imports of the Flax VAE and UNet
from ...schedulers import (  # relative imports of the Flax schedulers
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import deprecate, logging, replace_example_docstring  # utility helpers
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # the Flax pipeline base class
from .pipeline_output import FlaxStableDiffusionPipelineOutput  # the pipeline output class
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker  # the Flax safety checker
logger = logging.get_logger(__name__)  # module-level logger for this file
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False  # DEBUG flag, defaults to False
EXAMPLE_DOC_STRING = """  # example docstring injected into the pipeline docs
Examples:
```py
>>> import jax  # import jax
>>> import numpy as np  # import numpy
>>> from flax.jax_utils import replicate  # replicate, copies params onto every device
>>> from flax.training.common_utils import shard  # shard, splits inputs across devices
>>> from diffusers import FlaxStableDiffusionPipeline  # import the Flax pipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(  # load the pipeline and params from pretrained weights
...     "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"  # the text prompt to generate from
>>> prng_seed = jax.random.PRNGKey(0)  # seed for the random-number generator
>>> num_inference_steps = 50  # number of inference steps
>>> num_samples = jax.device_count()  # number of available devices
>>> prompt = num_samples * [prompt]  # one copy of the prompt per device
>>> prompt_ids = pipeline.prepare_inputs(prompt)  # tokenize the prompts into input IDs
# shard inputs and rng
>>> params = replicate(params)  # replicate the params onto every device
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())  # split the seed, one key per device
>>> prompt_ids = shard(prompt_ids)  # shard the prompt IDs across devices
>>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images  # generate images
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))  # convert to PIL images
```
"""  # end of the example docstring
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):  # the Flax Stable Diffusion pipeline, derived from FlaxDiffusionPipeline
r"""
Flax-based pipeline for text-to-image generation using Stable Diffusion.

This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).

Args:
vae ([`FlaxAutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.FlaxCLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`FlaxUNet2DConditionModel`]):
A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
def __init__(  # initializer that builds an instance of the class
self,  # the instance itself
vae: FlaxAutoencoderKL,  # the variational autoencoder
text_encoder: FlaxCLIPTextModel,  # the text encoder
tokenizer: CLIPTokenizer,  # the tokenizer
unet: FlaxUNet2DConditionModel,  # the denoising UNet
scheduler: Union[  # the scheduler; several types are supported
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,  # the safety checker
feature_extractor: CLIPImageProcessor,  # the feature extractor
dtype: jnp.dtype = jnp.float32,  # parameter dtype, defaults to jnp.float32
# close the signature and call the parent constructor
):
super().__init__()
# store the dtype
self.dtype = dtype
# warn the user when the safety checker has been disabled
if safety_checker is None:
# log a warning reminding the user of the license conditions
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# check whether the UNet was saved with a diffusers version older than 0.9.0
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
# check whether the UNet sample size is smaller than 64
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
# if both hold, emit a deprecation warning
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
# deprecation message asking the user to update the configuration file
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
# record the deprecation warning
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
# build a new config dict with the sample size set to 64
new_config = dict(unet.config)
new_config["sample_size"] = 64
# update the UNet's internal config dict
unet._internal_dict = FrozenDict(new_config)
# register all components with the pipeline
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# compute the VAE scale factor from the number of block output channels
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
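# (Worked example, assuming the standard SD VAE config with block_out_channels = (128, 256, 512, 512):
# the scale factor is 2 ** (4 - 1) = 8, so a 512x512 image maps to a 64x64 latent grid.)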
# prepare inputs; accepts a string or a list of strings
def prepare_inputs(self, prompt: Union[str, List[str]]):
# raise if the prompt is neither a string nor a list
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# tokenize the prompt with padding and truncation
text_input = self.tokenizer(
prompt,
padding="max_length",  # pad to the maximum length
max_length=self.tokenizer.model_max_length,  # use the model's maximum length
truncation=True,  # enable truncation
return_tensors="np",  # return NumPy tensors
)
# return the tokenized input IDs
return text_input.input_ids
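# (For instance, assuming the standard CLIP tokenizer with model_max_length == 77:
# prepare_inputs("a photo of a cat") returns an integer array of shape (1, 77), and
# prepare_inputs(["a", "b"]) returns shape (2, 77), padded or truncated to length 77.)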
# ask the safety checker whether NSFW concepts are present
def _get_has_nsfw_concepts(self, features, params):
# run the safety checker over the features with the given params
has_nsfw_concepts = self.safety_checker(features, params)
# return the detection result
return has_nsfw_concepts
# run the safety checker over a batch of images
def _run_safety_checker(self, images, safety_model_params, jit=False):
# convert the input image arrays into PIL images
pil_images = [Image.fromarray(image) for image in images]
# extract pixel features as NumPy tensors
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
# when jitted, shard the features across devices
if jit:
features = shard(features)  # shard the features
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)  # pmapped NSFW check
has_nsfw_concepts = unshard(has_nsfw_concepts)  # merge the sharded results
safety_model_params = unreplicate(safety_model_params)  # pull the params back onto one device
else:
# otherwise run the check directly
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
images_was_copied = False  # whether the image array has been copied yet
# walk over the per-image NSFW detection results
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:  # an NSFW concept was detected
if not images_was_copied:  # copy the array lazily, only once
images_was_copied = True  # mark it as copied
images = images.copy()  # copy the image array
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # replace with a black image
# warn when any image was flagged
if any(has_nsfw_concepts):
# emit a warning that unsuitable content may have been detected
warnings.warn(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead. Try again with a different prompt and/or seed."
)
# return the (possibly blacked-out) images and the per-image flags
return images, has_nsfw_concepts
# core image-generation routine
def _generate(
self,
prompt_ids: jnp.array,  # tokenized prompt IDs
params: Union[Dict, FrozenDict],  # model parameters
prng_seed: jax.Array,  # random seed
num_inference_steps: int,  # number of denoising steps
height: int,  # output image height
width: int,  # output image width
guidance_scale: float,  # classifier-free guidance scale
latents: Optional[jnp.ndarray] = None,  # optional pre-sampled latents
neg_prompt_ids: Optional[jnp.ndarray] = None,  # optional negative prompt IDs
@replace_example_docstring(EXAMPLE_DOC_STRING)  # decorator that injects the example docstring
def __call__(  # public entry point of the pipeline
self,
prompt_ids: jnp.array,  # tokenized prompt IDs
params: Union[Dict, FrozenDict],  # model parameters
prng_seed: jax.Array,  # random seed
num_inference_steps: int = 50,  # number of denoising steps, defaults to 50
height: Optional[int] = None,  # optional image height
width: Optional[int] = None,  # optional image width
guidance_scale: Union[float, jnp.ndarray] = 7.5,  # guidance scale, defaults to 7.5
latents: jnp.ndarray = None,  # optional pre-sampled latents
neg_prompt_ids: jnp.ndarray = None,  # optional negative prompt IDs
return_dict: bool = True,  # whether to return a dict-like output, defaults to True
jit: bool = False,  # whether to run the pmapped (jitted) path
# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
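# (For reference, the in_axes tuple below lines up with _p_generate's positional arguments:
# pipe=None, prompt_ids=0, params=0, prng_seed=0, num_inference_steps=None, height=None,
# width=None, guidance_scale=0, latents=0, neg_prompt_ids=0; None means broadcast, while
# 0 means map over the leading device axis.)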
@partial(
jax.pmap,  # parallel-map the function across devices
in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),  # per-argument mapped axes
static_broadcasted_argnums=(0, 4, 5, 6),  # indices of the statically broadcast arguments
)
def _p_generate(
pipe,  # the pipeline instance
prompt_ids,  # prompt IDs
params,  # model parameters
prng_seed,  # random seed
num_inference_steps,  # number of denoising steps
height,  # output image height
width,  # output image width
guidance_scale,  # guidance scale controlling the generation
latents,  # latent tensors
neg_prompt_ids,  # negative prompt IDs
):
# delegate to the pipeline's generation routine
return pipe._generate(
prompt_ids,  # prompt IDs
params,  # model parameters
prng_seed,  # random seed
num_inference_steps,  # denoising steps
height,  # height
width,  # width
guidance_scale,  # guidance scale
latents,  # latents
neg_prompt_ids,  # negative prompt IDs
)
@partial(jax.pmap, static_broadcasted_argnums=(0,))  # pmapped, with the pipeline broadcast statically
def _p_get_has_nsfw_concepts(pipe, features, params):
# pmapped wrapper around the NSFW-concept check
return pipe._get_has_nsfw_concepts(features, params)
def unshard(x: jnp.ndarray):
# merge the device and batch dimensions; equivalent to einops.rearrange(x, "d b ... -> (d b) ...")
num_devices, batch_size = x.shape[:2]  # number of devices and per-device batch size
rest = x.shape[2:]  # the remaining dimensions
# reshape so the device and batch dimensions are fused
return x.reshape(num_devices * batch_size, *rest)
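# A minimal round-trip sketch for shard/unshard (not part of the library source; it assumes
# the batch size is divisible by jax.local_device_count(), which shard requires).
import jax.numpy as jnp
from flax.training.common_utils import shard

x = jnp.arange(8 * 3).reshape(8, 3)  # a batch of 8 rows
sharded = shard(x)  # shape (num_devices, 8 // num_devices, 3)
assert unshard(sharded).shape == (8, 3)  # the original batch shape is restored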