diffusers-源码解析-四十三-

diffusers 源码解析(四十三)

.\diffusers\pipelines\shap_e\pipeline_shap_e_img2img.py

# 版权声明,说明该文件的版权所有者及其权利
# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证,版本 2.0("许可证")进行许可;
# 除非遵守许可证,否则不得使用此文件。
# 可在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则根据许可证分发的软件
# 是在“按原样”基础上分发的,没有任何明示或暗示的担保或条件。
# 有关许可证的具体权限和限制,请参见许可证。
#
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入 List、Optional 和 Union 类型
from typing import List, Optional, Union

# 导入 numpy 库并简化为 np
import numpy as np
# 导入 PIL.Image 模块以处理图像
import PIL.Image
# 导入 torch 库用于深度学习
import torch
# 从 transformers 库中导入 CLIP 图像处理器和 CLIP 视觉模型
from transformers import CLIPImageProcessor, CLIPVisionModel

# 从本地模型导入 PriorTransformer 类
from ...models import PriorTransformer
# 从调度器模块导入 HeunDiscreteScheduler 类
from ...schedulers import HeunDiscreteScheduler
# 从 utils 模块导入多个实用工具
from ...utils import (
    BaseOutput,         # 基类输出
    logging,           # 日志记录
    replace_example_docstring,  # 替换示例文档字符串的函数
)
# 从 torch_utils 模块导入随机张量生成函数
from ...utils.torch_utils import randn_tensor
# 从 pipeline_utils 模块导入 DiffusionPipeline 类
from ..pipeline_utils import DiffusionPipeline
# 从 renderer 模块导入 ShapERenderer 类
from .renderer import ShapERenderer

# 创建一个日志记录器,用于记录当前模块的日志
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文档字符串,展示了如何使用 DiffusionPipeline
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from PIL import Image  # 导入 PIL 图像处理库
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import DiffusionPipeline  # 导入 DiffusionPipeline 类
        >>> from diffusers.utils import export_to_gif, load_image  # 导入实用函数

        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检测可用设备(CUDA或CPU)

        >>> repo = "openai/shap-e-img2img"  # 定义模型的仓库名称
        >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)  # 从预训练模型加载管道
        >>> pipe = pipe.to(device)  # 将管道转移到指定设备上

        >>> guidance_scale = 3.0  # 设置引导比例
        >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png"  # 定义图像 URL
        >>> image = load_image(image_url).convert("RGB")  # 加载并转换图像为 RGB 格式

        >>> images = pipe(  # 调用管道生成图像
        ...     image,  # 输入图像
        ...     guidance_scale=guidance_scale,  # 使用的引导比例
        ...     num_inference_steps=64,  # 推理步骤数量
        ...     frame_size=256,  # 设置帧大小
        ... ).images  # 获取生成的图像列表

        >>> gif_path = export_to_gif(images[0], "corgi_3d.gif")  # 将生成的第一张图像导出为 GIF
        ```py
"""

# 定义 ShapEPipelineOutput 类,继承自 BaseOutput,用于表示输出
@dataclass
class ShapEPipelineOutput(BaseOutput):
    """
    [`ShapEPipeline`] 和 [`ShapEImg2ImgPipeline`] 的输出类。

    Args:
        images (`torch.Tensor`)  # 图像的张量列表,用于 3D 渲染
            A list of images for 3D rendering.
    """

    # 定义输出的图像属性,可以是 PIL.Image.Image 或 np.ndarray 类型
    images: Union[PIL.Image.Image, np.ndarray]


# 定义 ShapEImg2ImgPipeline 类,继承自 DiffusionPipeline
class ShapEImg2ImgPipeline(DiffusionPipeline):
    """
    从图像生成 3D 资产的潜在表示并使用 NeRF 方法进行渲染的管道。

    此模型继承自 [`DiffusionPipeline`]。请查看超类文档以了解所有管道实现的通用方法
    (下载、保存、在特定设备上运行等)。
    # 参数说明
    Args:
        prior ([`PriorTransformer`]):
            用于近似文本嵌入生成图像嵌入的标准 unCLIP 先验。
        image_encoder ([`~transformers.CLIPVisionModel`]):
            冻结的图像编码器。
        image_processor ([`~transformers.CLIPImageProcessor`]):
             用于处理图像的 `CLIPImageProcessor`。
        scheduler ([`HeunDiscreteScheduler`]):
            与 `prior` 模型结合使用以生成图像嵌入的调度器。
        shap_e_renderer ([`ShapERenderer`]):
            Shap-E 渲染器将生成的潜在变量投影为 MLP 参数,以使用 NeRF 渲染方法创建 3D 对象。
    """

    # 定义模型在 CPU 上卸载的顺序
    model_cpu_offload_seq = "image_encoder->prior"
    # 定义在 CPU 卸载中排除的模块
    _exclude_from_cpu_offload = ["shap_e_renderer"]

    # 初始化方法
    def __init__(
        self,
        prior: PriorTransformer,  # 定义 prior 参数
        image_encoder: CLIPVisionModel,  # 定义图像编码器参数
        image_processor: CLIPImageProcessor,  # 定义图像处理器参数
        scheduler: HeunDiscreteScheduler,  # 定义调度器参数
        shap_e_renderer: ShapERenderer,  # 定义 Shap-E 渲染器参数
    ):
        super().__init__()  # 调用父类初始化方法

        # 注册模块以供后续使用
        self.register_modules(
            prior=prior,  # 注册 prior 模块
            image_encoder=image_encoder,  # 注册图像编码器模块
            image_processor=image_processor,  # 注册图像处理器模块
            scheduler=scheduler,  # 注册调度器模块
            shap_e_renderer=shap_e_renderer,  # 注册 Shap-E 渲染器模块
        )

    # 从 diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents 复制的准备潜在变量方法
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        # 检查潜在变量是否为 None
        if latents is None:
            # 生成随机潜在变量
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 检查潜在变量形状是否与预期一致
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            # 将潜在变量移动到指定设备
            latents = latents.to(device)

        # 将潜在变量与调度器的初始噪声标准差相乘
        latents = latents * scheduler.init_noise_sigma
        # 返回处理后的潜在变量
        return latents

    # 定义编码图像的方法
    def _encode_image(
        self,
        image,  # 输入图像
        device,  # 指定设备
        num_images_per_prompt,  # 每个提示的图像数量
        do_classifier_free_guidance,  # 是否进行无分类器引导
    ):
        # 检查输入的 image 是否是一个列表且列表的第一个元素是 torch.Tensor
        if isinstance(image, List) and isinstance(image[0], torch.Tensor):
            # 如果第一个元素的维度是4,使用 torch.cat 沿着第0维连接,否则使用 torch.stack 叠加
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

        # 如果 image 不是 torch.Tensor 类型,进行处理
        if not isinstance(image, torch.Tensor):
            # 使用图像处理器处理 image,将结果转换为 PyTorch 张量,并提取 pixel_values 的第一个元素
            image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)

        # 将 image 转换为指定的数据类型和设备(CPU或GPU)
        image = image.to(dtype=self.image_encoder.dtype, device=device)

        # 使用图像编码器对图像进行编码,获取最后隐藏状态
        image_embeds = self.image_encoder(image)["last_hidden_state"]
        # 取出最后隐藏状态的切片,忽略第一个维度,并确保内存连续
        image_embeds = image_embeds[:, 1:, :].contiguous()  # batch_size, dim, 256

        # 对图像嵌入进行重复扩展,以适应每个提示的图像数量
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # 如果需要进行无分类器引导
        if do_classifier_free_guidance:
            # 创建与 image_embeds 形状相同的零张量作为负样本嵌入
            negative_image_embeds = torch.zeros_like(image_embeds)

            # 对于无分类器引导,我们需要进行两次前向传播
            # 在这里将无条件和文本嵌入拼接成一个批次,以避免进行两次前向传播
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 返回最终的图像嵌入
        return image_embeds

    # 禁用梯度计算,以提高推理速度并节省内存
    @torch.no_grad()
    # 替换示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        # 定义可调用函数的输入参数,支持 PIL 图像或图像列表
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        # 每个提示生成的图像数量,默认为1
        num_images_per_prompt: int = 1,
        # 推理步骤的数量,默认为25
        num_inference_steps: int = 25,
        # 随机数生成器,可以是单个生成器或生成器列表
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潜在变量的张量,默认为 None
        latents: Optional[torch.Tensor] = None,
        # 引导比例,默认为4.0
        guidance_scale: float = 4.0,
        # 帧的大小,默认为64
        frame_size: int = 64,
        # 输出类型,可选:'pil'、'np'、'latent'、'mesh'
        output_type: Optional[str] = "pil",  # pil, np, latent, mesh
        # 是否返回字典格式的结果,默认为 True
        return_dict: bool = True,

.\diffusers\pipelines\shap_e\renderer.py

# 版权信息,声明此代码的版权归 Open AI 和 HuggingFace 团队所有
# 
# 根据 Apache License 2.0 版本(“许可证”)进行许可;
# 您不得在不遵守许可证的情况下使用此文件。
# 您可以在以下网址获得许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,根据许可证分发的软件是在“按现状”基础上分发的,
# 不附带任何明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和限制的具体说明。

# 导入数学模块
import math
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入字典、可选类型和元组
from typing import Dict, Optional, Tuple

# 导入 numpy 库
import numpy as np
# 导入 torch 库
import torch
# 导入 PyTorch 的功能模块
import torch.nn.functional as F
# 从 torch 库导入 nn 模块
from torch import nn

# 从配置工具导入 ConfigMixin 和 register_to_config
from ...configuration_utils import ConfigMixin, register_to_config
# 从模型模块导入 ModelMixin
from ...models import ModelMixin
# 从工具模块导入 BaseOutput
from ...utils import BaseOutput
# 从当前目录的 camera 模块导入 create_pan_cameras 函数
from .camera import create_pan_cameras


def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
    r"""
    从给定的离散概率分布中进行有替换的采样。

    假定第 i 个箱子的质量为 pmf[i]。

    参数:
        pmf: [batch_size, *shape, n_samples, 1],其中 (pmf.sum(dim=-2) == 1).all()
        n_samples: 采样的数量

    返回:
        用替换方式采样的索引
    """

    # 获取 pmf 的形状,并提取支持大小和最后一个维度
    *shape, support_size, last_dim = pmf.shape
    # 确保最后一个维度为 1
    assert last_dim == 1

    # 计算 pmf 的累积分布函数(CDF)
    cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
    # 在 CDF 中查找随机数的索引
    inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))

    # 返回形状调整后的索引,并限制在有效范围内
    return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)


def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
    """
    按照 NeRF 的方式将 x 及其位置编码进行连接。

    参考文献: https://arxiv.org/pdf/2210.04628.pdf
    """
    # 如果最小和最大角度相同,则直接返回 x
    if min_deg == max_deg:
        return x

    # 生成尺度,范围为 [min_deg, max_deg)
    scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
    # 获取 x 的形状和维度
    *shape, dim = x.shape
    # 将 x 重新形状并与尺度相乘,然后调整形状
    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
    # 确保 xb 的最后一个维度与预期相符
    assert xb.shape[-1] == dim * (max_deg - min_deg)
    # 计算正弦值并进行连接
    emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
    # 返回原始 x 和位置编码的连接
    return torch.cat([x, emb], dim=-1)


def encode_position(position):
    # 使用 posenc_nerf 函数对位置进行编码
    return posenc_nerf(position, min_deg=0, max_deg=15)


def encode_direction(position, direction=None):
    # 如果未提供方向,则返回与位置编码相同形状的零张量
    if direction is None:
        return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
    else:
        # 使用 posenc_nerf 函数对方向进行编码
        return posenc_nerf(direction, min_deg=0, max_deg=8)


def _sanitize_name(x: str) -> str:
    # 替换字符串中的点为双下划线
    return x.replace(".", "__")


def integrate_samples(volume_range, ts, density, channels):
    r"""
    集成模型输出的函数。

    参数:
        volume_range: 指定积分范围 [t0, t1]
        ts: 时间步
        density: torch.Tensor [batch_size, *shape, n_samples, 1]
        channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
    # 返回值说明
    returns:
        # channels: 集成的 RGB 输出权重,类型为 torch.Tensor,形状为 [batch_size, *shape, n_samples, 1] 
        # (density * transmittance)[i] 表示在 [...] 中每个 RGB 输出的权重 
        # transmittance: 表示此体积的透射率
    )

    # 1. 计算权重
    # 使用 volume_range 对象的 partition 方法划分 ts,得到三个输出值,前两个被忽略
    _, _, dt = volume_range.partition(ts)
    # 计算密度与时间间隔 dt 的乘积,得到每个体素的密度变化
    ddensity = density * dt

    # 对 ddensity 进行累加,计算质量随深度的变化
    mass = torch.cumsum(ddensity, dim=-2)
    # 计算体积的透射率,使用指数衰减公式
    transmittance = torch.exp(-mass[..., -1, :])

    # 计算 alpha 值,表示每个体素的透明度
    alphas = 1.0 - torch.exp(-ddensity)
    # 计算 T 值,表示光通过每个体素的概率,使用累积质量
    Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
    # 这是光在深度 [..., i, :] 上击中并反射的概率
    weights = alphas * Ts

    # 2. 集成通道
    # 计算每个通道的加权和,得到最终的 RGB 输出
    channels = torch.sum(channels * weights, dim=-2)

    # 返回计算得到的通道、权重和透射率
    return channels, weights, transmittance
# 定义一个函数,查询体积内的点坐标
def volume_query_points(volume, grid_size):
    # 创建一个张量,包含从0到grid_size^3-1的索引,设备为volume的最小边界设备
    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
    # 计算每个索引在grid_size维度上的z坐标
    zs = indices % grid_size
    # 计算每个索引在grid_size维度上的y坐标
    ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
    # 计算每个索引在grid_size维度上的x坐标
    xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
    # 将x, y, z坐标组合成一个张量,维度为(数量, 3)
    combined = torch.stack([xs, ys, zs], dim=1)
    # 归一化坐标并转换为体积的坐标范围
    return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min


# 定义一个函数,将sRGB颜色值转换为线性颜色值
def _convert_srgb_to_linear(u: torch.Tensor):
    # 使用条件语句,按照sRGB到线性空间的转换公式进行转换
    return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)


# 定义一个函数,创建平面边缘索引
def _create_flat_edge_indices(
    flat_cube_indices: torch.Tensor,
    grid_size: Tuple[int, int, int],
):
    # 计算在x方向上的索引数量
    num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
    # 计算y方向的偏移量
    y_offset = num_xs
    # 计算在y方向上的索引数量
    num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
    # 计算z方向的偏移量
    z_offset = num_xs + num_ys
    # 将一组张量堆叠成一个新的张量,指定最后一个维度
    return torch.stack(
        [
            # 表示跨越 x 轴的边
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]  # 计算 x 轴索引的基值
            + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
            + flat_cube_indices[:, 2],  # 加上 z 轴索引
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]  # 计算 x 轴索引的基值
            + (flat_cube_indices[:, 1] + 1) * grid_size[2]  # 加上 y 轴索引偏移(+1)
            + flat_cube_indices[:, 2],  # 加上 z 轴索引
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]  # 计算 x 轴索引的基值
            + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
            + flat_cube_indices[:, 2]  # 加上 z 轴索引
            + 1,  # 取下一个 z 轴的索引
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]  # 计算 x 轴索引的基值
            + (flat_cube_indices[:, 1] + 1) * grid_size[2]  # 加上 y 轴索引偏移(+1)
            + flat_cube_indices[:, 2]  # 加上 z 轴索引
            + 1,  # 取下一个 z 轴的索引
            # 表示跨越 y 轴的边
            (
                y_offset  # y 轴的偏移量
                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]  # 计算 x 轴索引的基值
                + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
            (
                y_offset  # y 轴的偏移量
                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]  # 计算 x 轴索引的基值(+1)
                + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
            (
                y_offset  # y 轴的偏移量
                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]  # 计算 x 轴索引的基值
                + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
                + 1  # 取下一个 z 轴的索引
            ),
            (
                y_offset  # y 轴的偏移量
                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]  # 计算 x 轴索引的基值(+1)
                + flat_cube_indices[:, 1] * grid_size[2]  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
                + 1  # 取下一个 z 轴的索引
            ),
            # 表示跨越 z 轴的边
            (
                z_offset  # z 轴的偏移量
                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)  # 计算 x 轴索引的基值
                + flat_cube_indices[:, 1] * (grid_size[2] - 1)  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
            (
                z_offset  # z 轴的偏移量
                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)  # 计算 x 轴索引的基值(+1)
                + flat_cube_indices[:, 1] * (grid_size[2] - 1)  # 加上 y 轴索引的偏移
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
            (
                z_offset  # z 轴的偏移量
                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)  # 计算 x 轴索引的基值
                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)  # 加上 y 轴索引的偏移(+1)
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
            (
                z_offset  # z 轴的偏移量
                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)  # 计算 x 轴索引的基值(+1)
                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)  # 加上 y 轴索引的偏移(+1)
                + flat_cube_indices[:, 2]  # 加上 z 轴索引
            ),
        ],
        dim=-1,  # 指定堆叠的维度
    )
# 定义一个名为 VoidNeRFModel 的类,继承自 nn.Module
class VoidNeRFModel(nn.Module):
    """
    实现默认的空空间模型,所有查询渲染为背景。
    """

    # 初始化方法,接收背景和通道缩放参数
    def __init__(self, background, channel_scale=255.0):
        # 调用父类的初始化方法
        super().__init__()
        # 将背景数据转换为张量并归一化
        background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
        # 注册背景为模型的缓冲区
        self.register_buffer("background", background)

    # 前向传播方法,接收位置参数
    def forward(self, position):
        # 将背景张量扩展到与输入位置相同的设备
        background = self.background[None].to(position.device)
        # 获取位置的形状,除去最后一维
        shape = position.shape[:-1]
        # 创建一个与 shape 维度相同的 ones 列表
        ones = [1] * (len(shape) - 1)
        # 获取背景的通道数
        n_channels = background.shape[-1]
        # 将背景张量广播到与位置相同的形状
        background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
        # 返回背景张量
        return background


@dataclass
# 定义一个数据类 VolumeRange,包含 t0、t1 和 intersected
class VolumeRange:
    t0: torch.Tensor
    t1: torch.Tensor
    intersected: torch.Tensor

    # 后置初始化方法,检查张量形状是否一致
    def __post_init__(self):
        assert self.t0.shape == self.t1.shape == self.intersected.shape

    # 分区方法,将 t0 和 t1 分成 n_samples 区间
    def partition(self, ts):
        """
        将 t0 和 t1 分区成 n_samples 区间。

        参数:
            ts: [batch_size, *shape, n_samples, 1]

        返回:
            lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
            *shape, n_samples, 1]

        其中
            ts \\in [lower, upper] deltas = upper - lower
        """

        # 计算 ts 的中间值
        mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
        # 将 t0 和中间值拼接形成 lower
        lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
        # 将中间值和 t1 拼接形成 upper
        upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
        # 计算 upper 和 lower 之间的差值
        delta = upper - lower
        # 确保 lower、upper 和 delta 的形状一致
        assert lower.shape == upper.shape == delta.shape == ts.shape
        # 返回 lower、upper 和 delta
        return lower, upper, delta


# 定义一个名为 BoundingBoxVolume 的类,继承自 nn.Module
class BoundingBoxVolume(nn.Module):
    """
    由两个对角点定义的轴对齐边界框。
    """

    # 初始化方法,接收边界框的最小和最大角点
    def __init__(
        self,
        *,
        bbox_min,
        bbox_max,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
    ):
        """
        参数:
            bbox_min: 边界框的左/底角点
            bbox_max: 边界框的另一角点
            min_dist: 所有光线应至少从该距离开始。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 保存最小距离和最小 t 范围
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        # 将最小和最大边界框角点转换为张量
        self.bbox_min = torch.tensor(bbox_min)
        self.bbox_max = torch.tensor(bbox_max)
        # 堆叠最小和最大角点形成边界框
        self.bbox = torch.stack([self.bbox_min, self.bbox_max])
        # 确保边界框形状正确
        assert self.bbox.shape == (2, 3)
        # 确保最小距离和最小 t 范围有效
        assert min_dist >= 0.0
        assert min_t_range > 0.0

    # 相交方法,接收光线的原点和方向
    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        epsilon=1e-6,
    # 定义文档字符串,描述函数参数和返回值的格式
        ):
            """
            Args:
                origin: [batch_size, *shape, 3] 原点坐标的张量,表示光线起始点
                direction: [batch_size, *shape, 3] 方向向量的张量,表示光线方向
                t0_lower: Optional [batch_size, *shape, 1] 可选参数,表示相交体积时 t0 的下界
                params: Optional meta parameters in case Volume is parametric 可选元参数,用于参数化体积
                epsilon: to stabilize calculations 计算时用于稳定的小常数
    
            Return:
                A tuple of (t0, t1, intersected) 返回一个元组,包含 t0, t1 和交集信息
            """
    
            # 获取 origin 张量的 batch_size 和形状,忽略最后一个维度
            batch_size, *shape, _ = origin.shape
            # 创建一个与 shape 长度相同的列表,填充 1
            ones = [1] * len(shape)
            # 将边界框转换为与 origin 设备相同的张量,形状为 [1, *ones, 2, 3]
            bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
    
            # 定义安全除法函数,避免除以零的情况
            def _safe_divide(a, b, epsilon=1e-6):
                return a / torch.where(b < 0, b - epsilon, b + epsilon)
    
            # 计算 t 的值,表示光线与边界框的交点
            ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
    
            # 考虑光线与边界框相交的不同情况
            #
            #   1. t1 <= t0: 光线未通过 AABB。
            #   2. t0 < t1 <= 0: 光线相交,但边界框在原点后面。
            #   3. t0 <= 0 <= t1: 光线从边界框内部开始。
            #   4. 0 <= t0 < t1: 光线不在内部并且与边界框相交两次。
            #
            # 情况 1 和 4 已通过 t0 < t1 处理。
            # 通过将 t0 至少设为 min_dist (>= 0) 处理情况 2 和 3。
            t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
            # 计算 t1,取 ts 的最大值
            t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
            # 断言 t0 和 t1 的形状相同
            assert t0.shape == t1.shape == (batch_size, *shape, 1)
            # 如果 t0_lower 不为空,则取 t0 和 t0_lower 的最大值
            if t0_lower is not None:
                assert t0.shape == t0_lower.shape
                t0 = torch.maximum(t0, t0_lower)
    
            # 计算光线是否与体积相交
            intersected = t0 + self.min_t_range < t1
            # 如果相交,保持 t0 否则设为零
            t0 = torch.where(intersected, t0, torch.zeros_like(t0))
            # 如果相交,保持 t1 否则设为一
            t1 = torch.where(intersected, t1, torch.ones_like(t1))
    
            # 返回一个包含 t0, t1 和相交信息的 VolumeRange 对象
            return VolumeRange(t0=t0, t1=t1, intersected=intersected)
# 定义一个分层射线采样器类,继承自 nn.Module
class StratifiedRaySampler(nn.Module):
    """
    在每个间隔内随机均匀地抽样,而不是使用固定的间隔。
    """

    # 初始化方法,接受一个深度模式参数,默认为线性
    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: 线性样本在深度上线性分布。谐波模式确保
            更靠近的点被更密集地采样。
        """
        # 保存深度模式参数
        self.depth_mode = depth_mode
        # 确保深度模式是允许的选项之一
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    # 定义采样方法
    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        Args:
            t0: 开始时间,形状为 [batch_size, *shape, 1]
            t1: 结束时间,形状为 [batch_size, *shape, 1]
            n_samples: 要采样的时间戳数量
        Return:
            采样的时间戳,形状为 [batch_size, *shape, n_samples, 1]
        """
        # 创建一个列表,长度为 t0 形状的维度减一,元素全为 1
        ones = [1] * (len(t0.shape) - 1)
        # 创建从 0 到 1 的线性间隔,并调整形状以适应 t0 的维度
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        # 根据深度模式计算时间戳
        if self.depth_mode == "linear":
            # 线性插值计算时间戳
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            # 对数插值计算时间戳,使用 clamp 限制最小值
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # 原始 NeRF 推荐的插值方案,适用于球形场景
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        # 计算中间时间戳
        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        # 创建上界和下界,分别为中间时间戳和结束时间、开始时间
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
        # yiyi 注释:这里添加一个随机种子以便于测试,记得在生产中移除
        torch.manual_seed(0)
        # 生成与 ts 形状相同的随机数
        t_rand = torch.rand_like(ts)

        # 根据随机数计算最终的时间戳
        ts = lower + (upper - lower) * t_rand
        # 返回增加了一个维度的时间戳
        return ts.unsqueeze(-1)


# 定义一个重要性射线采样器类,继承自 nn.Module
class ImportanceRaySampler(nn.Module):
    """
    根据初始密度估计,从预期有物体的区域/箱中进行更多采样。
    """

    # 初始化方法,接受多个参数
    def __init__(
        self,
        volume_range: VolumeRange,
        ts: torch.Tensor,
        weights: torch.Tensor,
        blur_pool: bool = False,
        alpha: float = 1e-5,
    ):
        """
        Args:
            volume_range: 射线与给定体积相交的范围。
            ts: 来自粗渲染步骤的早期采样
            weights: 密度 * 透射率的离散版本
            blur_pool: 如果为真,则使用来自 mip-NeRF 的 2-tap 最大 + 2-tap 模糊滤波器。
            alpha: 添加到权重的小值。
        """
        # 保存体积范围
        self.volume_range = volume_range
        # 克隆并分离传入的时间戳
        self.ts = ts.clone().detach()
        # 克隆并分离传入的权重
        self.weights = weights.clone().detach()
        # 保存是否使用模糊池的标志
        self.blur_pool = blur_pool
        # 保存 alpha 参数
        self.alpha = alpha

    # 标记方法为不需要计算梯度
    @torch.no_grad()
    # 定义一个名为 sample 的方法,接受两个张量 t0 和 t1 以及样本数量 n_samples
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        Args:
            t0: start time has shape [batch_size, *shape, 1]
            t1: finish time has shape [batch_size, *shape, 1]
            n_samples: number of ts to sample
        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        # 从 volume_range 获取 t 的范围,分割成 lower 和 upper
        lower, upper, _ = self.volume_range.partition(self.ts)

        # 获取输入张量 ts 的批大小、形状和粗样本数量
        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        # 获取权重
        weights = self.weights
        # 如果启用了模糊池,进行权重的处理
        if self.blur_pool:
            # 在权重的前后各添加一层,以便进行边界处理
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            # 计算相邻权重的最大值
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            # 更新权重为相邻最大值的平均
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        # 在权重上加上 alpha 值
        weights = weights + self.alpha
        # 计算权重的概率质量函数 (pmf)
        pmf = weights / weights.sum(dim=-2, keepdim=True)
        # 根据 pmf 进行采样,获取样本索引
        inds = sample_pmf(pmf, n_samples)
        # 确保索引的形状符合预期
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        # 确保索引在有效范围内
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        # 生成与索引形状相同的随机数
        t_rand = torch.rand(inds.shape, device=inds.device)
        # 根据索引从 lower 和 upper 中获取对应的值
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        # 根据随机数和上下限计算采样时间
        ts = lower_ + (upper_ - lower_) * t_rand
        # 对采样结果进行排序
        ts = torch.sort(ts, dim=-2).values
        # 返回采样后的时间序列
        return ts
# 定义一个数据类,用于存储三维三角网格及其可选的顶点和面数据
@dataclass
class MeshDecoderOutput(BaseOutput):
    """
    A 3D triangle mesh with optional data at the vertices and faces.

    Args:
        verts (`torch.Tensor` of shape `(N, 3)`):
            array of vertext coordinates
        faces (`torch.Tensor` of shape `(N, 3)`):
            array of triangles, pointing to indices in verts.
        vertext_channels (Dict):
            vertext coordinates for each color channel
    """

    # 顶点坐标的张量
    verts: torch.Tensor
    # 三角形面索引的张量
    faces: torch.Tensor
    # 每个颜色通道的顶点坐标字典
    vertex_channels: Dict[str, torch.Tensor]


# 定义一个神经网络模块,用于通过有符号距离函数构建网格
class MeshDecoder(nn.Module):
    """
    Construct meshes from Signed distance functions (SDFs) using marching cubes method
    """

    # 初始化方法,构建基本组件
    def __init__(self):
        super().__init__()
        # 创建一个大小为 (256, 5, 3) 的零张量,用于存储网格案例
        cases = torch.zeros(256, 5, 3, dtype=torch.long)
        # 创建一个大小为 (256, 5) 的零布尔张量,用于存储掩码
        masks = torch.zeros(256, 5, dtype=torch.bool)

        # 将案例和掩码注册为模块的缓冲区
        self.register_buffer("cases", cases)
        self.register_buffer("masks", masks)

# 定义一个数据类,用于存储MLP NeRF模型的输出
@dataclass
class MLPNeRFModelOutput(BaseOutput):
    # 存储密度的张量
    density: torch.Tensor
    # 存储有符号距离的张量
    signed_distance: torch.Tensor
    # 存储通道的张量
    channels: torch.Tensor
    # 存储时间步长的张量
    ts: torch.Tensor


# 定义一个混合模型类,用于构建MLP NeRF
class MLPNeRSTFModel(ModelMixin, ConfigMixin):
    @register_to_config
    # 初始化方法,接受多个参数以配置模型
    def __init__(
        self,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
    ):
        super().__init__()

        # 实例化MLP

        # 创建一个单位矩阵以找到编码位置和方向的维度
        dummy = torch.eye(1, 3)
        # 编码位置的维度
        d_posenc_pos = encode_position(position=dummy).shape[-1]
        # 编码方向的维度
        d_posenc_dir = encode_direction(position=dummy).shape[-1]

        # 根据隐藏层数量设置MLP宽度
        mlp_widths = [d_hidden] * n_hidden_layers
        # 输入宽度由编码位置的维度和隐藏层宽度组成
        input_widths = [d_posenc_pos] + mlp_widths
        # 输出宽度由隐藏层宽度和输出数量组成
        output_widths = mlp_widths + [n_output]

        # 如果需要,在输入宽度中插入方向的维度
        if insert_direction_at is not None:
            input_widths[insert_direction_at] += d_posenc_dir

        # 创建线性层的模块列表
        self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])

        # 根据激活函数选择设置
        if act_fn == "swish":
            # 定义Swish激活函数
            self.activation = lambda x: F.silu(x)
        else:
            # 如果激活函数不被支持,则抛出错误
            raise ValueError(f"Unsupported activation function {act_fn}")

        # 定义不同的激活函数
        self.sdf_activation = torch.tanh
        self.density_activation = torch.nn.functional.relu
        self.channel_activation = torch.sigmoid

    # 将索引映射到键的函数
    def map_indices_to_keys(self, output):
        # 创建一个映射表,将键映射到输出的切片
        h_map = {
            "sdf": (0, 1),
            "density_coarse": (1, 2),
            "density_fine": (2, 3),
            "stf": (3, 6),
            "nerf_coarse": (6, 9),
            "nerf_fine": (9, 12),
        }

        # 根据映射表生成新的输出字典
        mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}

        # 返回映射后的输出
        return mapped_output
    # 定义前向传播方法,接受位置、方向、时间戳等参数
    def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
        # 对输入位置进行编码
        h = encode_position(position)

        # 初始化激活值的预激活变量
        h_preact = h
        # 初始化无方向的激活值变量
        h_directionless = None
        # 遍历多层感知机中的每一层
        for i, layer in enumerate(self.mlp):
            # 检查当前层是否为插入方向的层
            if i == self.config.insert_direction_at:  # 4 in the config
                # 保存当前的预激活值作为无方向的激活值
                h_directionless = h_preact
                # 对位置和方向进行编码,得到方向编码
                h_direction = encode_direction(position, direction=direction)
                # 将位置编码和方向编码在最后一维进行拼接
                h = torch.cat([h, h_direction], dim=-1)

            # 将当前的激活值输入到当前层进行处理
            h = layer(h)

            # 更新预激活值为当前激活值
            h_preact = h

            # 如果不是最后一层,则应用激活函数
            if i < len(self.mlp) - 1:
                h = self.activation(h)

        # 将最后一层的激活值赋给最终激活值
        h_final = h
        # 如果无方向的激活值仍为 None,则赋值为当前的预激活值
        if h_directionless is None:
            h_directionless = h_preact

        # 将激活值映射到相应的键
        activation = self.map_indices_to_keys(h_final)

        # 根据 nerf_level 选择粗糙或细致的密度
        if nerf_level == "coarse":
            h_density = activation["density_coarse"]
        else:
            h_density = activation["density_fine"]

        # 根据渲染模式选择相应的通道
        if rendering_mode == "nerf":
            if nerf_level == "coarse":
                h_channels = activation["nerf_coarse"]
            else:
                h_channels = activation["nerf_fine"]

        # 如果渲染模式为 stf,选择相应的通道
        elif rendering_mode == "stf":
            h_channels = activation["stf"]

        # 对密度进行激活处理
        density = self.density_activation(h_density)
        # 对有符号距离进行激活处理
        signed_distance = self.sdf_activation(activation["sdf"])
        # 对通道进行激活处理
        channels = self.channel_activation(h_channels)

        # yiyi notes: I think signed_distance is not used
        # 返回 MLPNeRFModelOutput 对象,包含密度、有符号距离和通道信息
        return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
# 定义一个名为 ChannelsProj 的类,继承自 nn.Module
class ChannelsProj(nn.Module):
    # 初始化方法,接受参数以定义投影的特性
    def __init__(
        self,
        *,
        vectors: int,  # 设定向量的数量
        channels: int,  # 设定通道的数量
        d_latent: int,  # 设定潜在特征的维度
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 定义线性层,将 d_latent 映射到 vectors * channels
        self.proj = nn.Linear(d_latent, vectors * channels)
        # 定义层归一化,用于标准化每个通道
        self.norm = nn.LayerNorm(channels)
        # 保存潜在特征的维度
        self.d_latent = d_latent
        # 保存向量的数量
        self.vectors = vectors
        # 保存通道的数量
        self.channels = channels

    # 前向传播方法,定义输入张量的处理方式
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 将输入张量赋值给 x_bvd,便于后续使用
        x_bvd = x
        # 将权重重新形状为 (vectors, channels, d_latent)
        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        # 将偏置重新形状为 (1, vectors, channels)
        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
        # 计算爱因斯坦求和,将输入与权重相乘并累加
        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
        # 对计算结果进行层归一化
        h = self.norm(h)

        # 将偏置添加到归一化后的结果
        h = h + b_vc
        # 返回最终的输出
        return h


# 定义一个名为 ShapEParamsProjModel 的类,继承自 ModelMixin 和 ConfigMixin
class ShapEParamsProjModel(ModelMixin, ConfigMixin):
    """
    将 3D 资产的潜在表示投影,以获取多层感知器(MLP)的权重。

    更多细节见原始论文:
    """

    # 注册到配置中
    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (  # 定义参数名称的元组
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (  # 定义参数形状的元组
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,  # 设置潜在特征的维度,默认值为 1024
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 检查输入参数的有效性
        if len(param_names) != len(param_shapes):
            # 如果参数名称与形状数量不一致,抛出错误
            raise ValueError("Must provide same number of `param_names` as `param_shapes`")
        # 初始化一个空的模块字典,用于存储投影层
        self.projections = nn.ModuleDict({})
        # 遍历参数名称和形状,并为每一对创建一个 ChannelsProj 实例
        for k, (vectors, channels) in zip(param_names, param_shapes):
            self.projections[_sanitize_name(k)] = ChannelsProj(
                vectors=vectors,  # 设置向量数量
                channels=channels,  # 设置通道数量
                d_latent=d_latent,  # 设置潜在特征的维度
            )

    # 前向传播方法
    def forward(self, x: torch.Tensor):
        out = {}  # 初始化输出字典
        start = 0  # 初始化起始索引
        # 遍历参数名称和形状
        for k, shape in zip(self.config.param_names, self.config.param_shapes):
            vectors, _ = shape  # 获取当前参数的向量数量
            end = start + vectors  # 计算结束索引
            x_bvd = x[:, start:end]  # 从输入中切片提取相关部分
            # 将切片后的输入通过对应的投影层处理,并调整形状
            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
            start = end  # 更新起始索引为结束索引
        # 返回包含所有输出的字典
        return out


# 定义一个名为 ShapERenderer 的类,继承自 ModelMixin 和 ConfigMixin
class ShapERenderer(ModelMixin, ConfigMixin):
    # 注册到配置中
    @register_to_config
    # 初始化方法,用于创建类的实例
    def __init__(
        self,
        *,  # 指定后续参数为关键字参数
        param_names: Tuple[str] = (  # 定义参数名称的元组,默认值为特定的权重名称
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (  # 定义参数形状的元组,默认值为特定的形状
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,  # 定义潜在维度的整数,默认值为1024
        d_hidden: int = 256,  # 定义隐藏层维度的整数,默认值为256
        n_output: int = 12,  # 定义输出层的神经元数量,默认值为12
        n_hidden_layers: int = 6,  # 定义隐藏层的层数,默认值为6
        act_fn: str = "swish",  # 定义激活函数的名称,默认值为"swish"
        insert_direction_at: int = 4,  # 定义插入方向的索引,默认值为4
        background: Tuple[float] = (  # 定义背景颜色的元组,默认值为白色
            255.0,
            255.0,
            255.0,
        ),
    ):
        super().__init__()  # 调用父类的初始化方法

        # 创建参数投影模型,传入参数名称、形状和潜在维度
        self.params_proj = ShapEParamsProjModel(
            param_names=param_names,
            param_shapes=param_shapes,
            d_latent=d_latent,
        )
        # 创建多层感知机模型,传入隐藏层维度、输出层数量、隐藏层层数、激活函数和插入方向
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        # 创建空的神经辐射场模型,传入背景颜色和通道缩放
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        # 创建包围盒体积模型,定义最大和最小边界
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
        # 创建网格解码器
        self.mesh_decoder = MeshDecoder()

    @torch.no_grad()  # 禁用梯度计算,提高推理性能
    @torch.no_grad()  # 冗余的禁用梯度计算装饰器
    def decode_to_image(
        self,
        latents,  # 输入的潜在变量
        device,  # 指定计算设备(如CPU或GPU)
        size: int = 64,  # 输出图像的尺寸,默认值为64
        ray_batch_size: int = 4096,  # 每批次光线的数量,默认值为4096
        n_coarse_samples=64,  # 粗采样的数量,默认值为64
        n_fine_samples=128,  # 精细采样的数量,默认值为128
    ):
        # 从生成的潜在变量投影参数
        projected_params = self.params_proj(latents)

        # 更新渲染器的MLP层
        for name, param in self.mlp.state_dict().items():  # 遍历MLP模型的所有参数
            if f"nerstf.{name}" in projected_params.keys():  # 检查投影参数是否存在于MLP参数中
                param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))  # 更新MLP参数

        # 创建相机对象
        camera = create_pan_cameras(size)  # 生成全景相机
        rays = camera.camera_rays  # 获取相机射线
        rays = rays.to(device)  # 将射线移动到指定设备
        n_batches = rays.shape[1] // ray_batch_size  # 计算总批次数量

        coarse_sampler = StratifiedRaySampler()  # 创建粗采样器

        images = []  # 初始化图像列表

        for idx in range(n_batches):  # 遍历每个批次
            rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]  # 获取当前批次的射线

            # 使用粗糙的分层采样渲染射线
            _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
            # 然后使用附加的重要性加权射线样本进行渲染
            channels, _, _ = self.render_rays(
                rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
            )

            images.append(channels)  # 将渲染结果添加到图像列表

        images = torch.cat(images, dim=1)  # 在维度1上拼接所有图像
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)  # 调整图像的形状

        return images  # 返回渲染出的图像

    @torch.no_grad()  # 禁用梯度计算,提高推理性能
    def decode_to_mesh(
        self,
        latents,  # 输入的潜在变量
        device,  # 指定计算设备(如CPU或GPU)
        grid_size: int = 128,  # 网格大小,默认值为128
        query_batch_size: int = 4096,  # 每批次查询的数量,默认值为4096
        texture_channels: Tuple = ("R", "G", "B"),  # 纹理通道,默认值为RGB

.\diffusers\pipelines\shap_e\__init__.py

# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING

# 从 utils 模块导入所需的函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入慢速加载的标志
    OptionalDependencyNotAvailable,  # 导入可选依赖不可用的异常
    _LazyModule,  # 导入懒加载模块的类
    get_objects_from_module,  # 导入从模块获取对象的函数
    is_torch_available,  # 导入检查 PyTorch 是否可用的函数
    is_transformers_available,  # 导入检查 Transformers 是否可用的函数
)

# 定义一个空字典用于存储虚拟对象
_dummy_objects = {}
# 定义一个空字典用于存储模块的导入结构
_import_structure = {}

try:
    # 检查 Transformers 和 PyTorch 是否可用,如果不可用则抛出异常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 模块导入虚拟对象以避免错误
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新虚拟对象字典,包含从虚拟对象模块中获取的对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果依赖可用,设置导入结构以包含相关模块
    _import_structure["camera"] = ["create_pan_cameras"]
    _import_structure["pipeline_shap_e"] = ["ShapEPipeline"]
    _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"]
    _import_structure["renderer"] = [
        "BoundingBoxVolume",
        "ImportanceRaySampler",
        "MLPNeRFModelOutput",
        "MLPNeRSTFModel",
        "ShapEParamsProjModel",
        "ShapERenderer",
        "StratifiedRaySampler",
        "VoidNeRFModel",
    ]

# 如果在类型检查或慢速导入的情况下执行以下代码
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 再次检查依赖是否可用
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    # 捕获可选依赖不可用的异常
    except OptionalDependencyNotAvailable:
        # 从虚拟对象模块中导入所有对象
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 从各模块导入具体对象
        from .camera import create_pan_cameras
        from .pipeline_shap_e import ShapEPipeline
        from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline
        from .renderer import (
            BoundingBoxVolume,
            ImportanceRaySampler,
            MLPNeRFModelOutput,
            MLPNeRSTFModel,
            ShapEParamsProjModel,
            ShapERenderer,
            StratifiedRaySampler,
            VoidNeRFModel,
        )

else:
    # 如果不进行类型检查或慢速导入,执行懒加载模块
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名
        globals()["__file__"],  # 当前文件路径
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规格
    )

    # 将虚拟对象添加到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\stable_audio\modeling_stable_audio.py

# 版权所有 2024 Stability AI 和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定, 
# 根据许可证分发的软件按“原样”提供,
# 不提供任何明示或暗示的担保或条件。
# 请参阅许可证了解特定语言所规定的权限和
# 限制。

from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from math import pi  # 从 math 模块导入圆周率 π
from typing import Optional  # 从 typing 模块导入 Optional 类型

import torch  # 导入 PyTorch 库
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
import torch.utils.checkpoint  # 导入 PyTorch 的检查点功能

from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具导入配置混合和注册功能
from ...models.modeling_utils import ModelMixin  # 从模型工具导入模型混合
from ...utils import BaseOutput, logging  # 从工具模块导入基础输出和日志功能

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

class StableAudioPositionalEmbedding(nn.Module):  # 定义一个名为 StableAudioPositionalEmbedding 的类,继承自 nn.Module
    """用于连续时间的嵌入层"""

    def __init__(self, dim: int):  # 构造函数,接受一个维度参数
        super().__init__()  # 调用父类构造函数
        assert (dim % 2) == 0  # 确保维度为偶数
        half_dim = dim // 2  # 计算一半的维度
        self.weights = nn.Parameter(torch.randn(half_dim))  # 初始化权重为随机值并作为可训练参数

    def forward(self, times: torch.Tensor) -> torch.Tensor:  # 定义前向传播方法,接受一个时间张量
        times = times[..., None]  # 将时间张量扩展为最后一维
        freqs = times * self.weights[None] * 2 * pi  # 计算频率
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)  # 计算频率的正弦和余弦值
        fouriered = torch.cat((times, fouriered), dim=-1)  # 将时间和频率特征拼接
        return fouriered  # 返回处理后的张量

@dataclass
class StableAudioProjectionModelOutput(BaseOutput):  # 定义一个数据类,用于稳定音频投影模型的输出
    """
    参数:
    用于稳定音频投影层输出的类。
        text_hidden_states (`torch.Tensor`,形状为 `(batch_size, sequence_length, hidden_size)`,*可选*):
            通过线性投影获得的文本编码器的隐状态序列。
        seconds_start_hidden_states (`torch.Tensor`,形状为 `(batch_size, 1, hidden_size)`,*可选*):
            通过线性投影获得的音频开始隐状态序列。
        seconds_end_hidden_states (`torch.Tensor`,形状为 `(batch_size, 1, hidden_size)`,*可选*):
            通过线性投影获得的音频结束隐状态序列。
    """

    text_hidden_states: Optional[torch.Tensor] = None  # 可选的文本隐状态张量
    seconds_start_hidden_states: Optional[torch.Tensor] = None  # 可选的音频开始隐状态张量
    seconds_end_hidden_states: Optional[torch.Tensor] = None  # 可选的音频结束隐状态张量


class StableAudioNumberConditioner(nn.Module):  # 定义一个名为 StableAudioNumberConditioner 的类,继承自 nn.Module
    """
    一个简单的线性投影模型,将数字映射到潜在空间。
    # 参数说明
    Args:
        number_embedding_dim (`int`):
            # 数字嵌入的维度
            Dimensionality of the number embeddings.
        min_value (`int`):
            # 秒数条件模块的最小值
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            # 秒数条件模块的最大值
            The maximum value of the seconds number conditioning modules
        internal_dim (`int`):
            # 中间数字隐藏状态的维度
            Dimensionality of the intermediate number hidden states.
    """

    # 初始化方法
    def __init__(
        self,
        # 数字嵌入的维度
        number_embedding_dim,
        # 秒数条件模块的最小值
        min_value,
        # 秒数条件模块的最大值
        max_value,
        # 中间数字隐藏状态的维度,默认为256
        internal_dim: Optional[int] = 256,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 创建时间位置嵌入层,包含稳定音频位置嵌入和线性变换
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            # 从内部维度到数字嵌入维度的线性转换
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        # 保存数字嵌入的维度
        self.number_embedding_dim = number_embedding_dim
        # 保存最小值
        self.min_value = min_value
        # 保存最大值
        self.max_value = max_value

    # 前向传播方法
    def forward(
        self,
        # 输入的浮点张量
        floats: torch.Tensor,
    ):
        # 将浮点数限制在最小值和最大值之间
        floats = floats.clamp(self.min_value, self.max_value)

        # 将浮点数归一化到[0, 1]范围
        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)

        # 将浮点数转换为嵌入器相同的类型
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        # 通过时间位置嵌入层生成嵌入
        embedding = self.time_positional_embedding(normalized_floats)
        # 调整嵌入的形状
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        # 返回浮点数嵌入
        return float_embeds
# 定义一个稳定音频投影模型类,继承自 ModelMixin 和 ConfigMixin
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    一个简单的线性投影模型,用于将条件值映射到共享的潜在空间。

    参数:
        text_encoder_dim (`int`):
            文本编码器 (T5) 的文本嵌入维度。
        conditioning_dim (`int`):
            输出条件张量的维度。
        min_value (`int`):
            秒数条件模块的最小值。
        max_value (`int`):
            秒数条件模块的最大值。
    """

    # 注册构造函数到配置
    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        # 调用父类构造函数
        super().__init__()
        # 根据条件维度选择合适的投影方式
        self.text_projection = (
            nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
        )
        # 初始化开始时间条件模块
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        # 初始化结束时间条件模块
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    # 定义前向传播方法
    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        # 如果没有文本隐藏状态,则使用输入的文本隐藏状态,否则进行投影
        text_hidden_states = (
            text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
        )
        # 如果没有开始秒数,则使用输入的开始秒数,否则进行条件处理
        seconds_start_hidden_states = (
            start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
        )
        # 如果没有结束秒数,则使用输入的结束秒数,否则进行条件处理
        seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)

        # 返回一个稳定音频投影模型输出对象,包含所有隐藏状态
        return StableAudioProjectionModelOutput(
            text_hidden_states=text_hidden_states,
            seconds_start_hidden_states=seconds_start_hidden_states,
            seconds_end_hidden_states=seconds_end_hidden_states,
        )

.\diffusers\pipelines\stable_audio\pipeline_stable_audio.py

# 版权声明,指明版权归 Stability AI 和 HuggingFace 团队所有
# 
# 根据 Apache 2.0 许可证授权("许可证");
# 除非遵守许可证,否则不得使用此文件。
# 可在以下地址获得许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有规定,否则根据许可证分发的软件
# 是以“按现状”基础提供,不提供任何明示或暗示的保证或条件。
# 有关许可证的具体条款和权限限制,请参阅许可证文档。

import inspect  # 导入 inspect 模块以进行对象的检查
from typing import Callable, List, Optional, Union  # 从 typing 导入类型提示工具

import torch  # 导入 PyTorch 库
from transformers import (  # 从 transformers 导入相关模型和标记器
    T5EncoderModel,  # 导入 T5 编码器模型
    T5Tokenizer,  # 导入 T5 标记器
    T5TokenizerFast,  # 导入快速 T5 标记器
)

from ...models import AutoencoderOobleck, StableAudioDiTModel  # 从模型中导入特定类
from ...models.embeddings import get_1d_rotary_pos_embed  # 导入获取一维旋转位置嵌入的函数
from ...schedulers import EDMDPMSolverMultistepScheduler  # 导入多步调度器类
from ...utils import (  # 从 utils 导入通用工具
    logging,  # 导入日志记录模块
    replace_example_docstring,  # 导入替换示例文档字符串的函数
)
from ...utils.torch_utils import randn_tensor  # 从 PyTorch 工具中导入生成随机张量的函数
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline  # 导入音频管道输出和扩散管道
from .modeling_stable_audio import StableAudioProjectionModel  # 导入稳定音频投影模型

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁止 pylint 检查命名

EXAMPLE_DOC_STRING = """  # 定义示例文档字符串,包含代码示例
    Examples:  # 示例部分
        ```py  # 代码块开始
        >>> import scipy  # 导入 scipy 库
        >>> import torch  # 导入 PyTorch 库
        >>> import soundfile as sf  # 导入 soundfile 库以处理音频文件
        >>> from diffusers import StableAudioPipeline  # 从 diffusers 导入稳定音频管道

        >>> repo_id = "stabilityai/stable-audio-open-1.0"  # 定义模型的仓库 ID
        >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)  # 从预训练模型加载管道,设置数据类型为 float16
        >>> pipe = pipe.to("cuda")  # 将管道移动到 GPU

        >>> # 定义提示语
        >>> prompt = "The sound of a hammer hitting a wooden surface."  # 正面提示语
        >>> negative_prompt = "Low quality."  # 负面提示语

        >>> # 为生成器设置种子
        >>> generator = torch.Generator("cuda").manual_seed(0)  # 创建 GPU 上的随机数生成器并设置种子

        >>> # 执行生成
        >>> audio = pipe(  # 调用管道生成音频
        ...     prompt,  # 传入正面提示语
        ...     negative_prompt=negative_prompt,  # 传入负面提示语
        ...     num_inference_steps=200,  # 设置推理步骤数
        ...     audio_end_in_s=10.0,  # 设置音频结束时间为 10 秒
        ...     num_waveforms_per_prompt=3,  # 每个提示生成三个波形
        ...     generator=generator,  # 传入随机数生成器
        ... ).audios  # 获取生成的音频列表

        >>> output = audio[0].T.float().cpu().numpy()  # 转置第一个音频并转换为 NumPy 数组
        >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)  # 将输出音频写入文件
        ```py  # 代码块结束
"""
    # 文档字符串,描述参数的作用和类型
    Args:
        vae ([`AutoencoderOobleck`]):
            # 变分自编码器 (VAE) 模型,用于将图像编码为潜在表示并从潜在表示解码图像。
        text_encoder ([`~transformers.T5EncoderModel`]):
            # 冻结的文本编码器。StableAudio 使用
            # [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel) 的编码器,
            # 特别是 [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) 变体。
        projection_model ([`StableAudioProjectionModel`]):
            # 一个经过训练的模型,用于线性投影文本编码器模型的隐藏状态和开始及结束秒数。
            # 编码器的投影隐藏状态和条件秒数被连接,以作为变换器模型的输入。
        tokenizer ([`~transformers.T5Tokenizer`]):
            # 用于为冻结文本编码器进行文本标记化的分词器。
        transformer ([`StableAudioDiTModel`]):
            # 用于去噪编码音频潜在表示的 `StableAudioDiTModel`。
        scheduler ([`EDMDPMSolverMultistepScheduler`]):
            # 结合 `transformer` 使用的调度器,用于去噪编码的音频潜在表示。
    """

    # 定义模型组件的顺序,用于 CPU 内存卸载
    model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"

    # 初始化方法,接收多个模型和组件作为参数
    def __init__(
        self,
        vae: AutoencoderOobleck,
        text_encoder: T5EncoderModel,
        projection_model: StableAudioProjectionModel,
        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
        transformer: StableAudioDiTModel,
        scheduler: EDMDPMSolverMultistepScheduler,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 注册多个模块,以便后续使用
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            projection_model=projection_model,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # 计算旋转嵌入维度,注意力头维度的一半
        self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2

    # 从 diffusers.pipelines.pipeline_utils.StableDiffusionMixin 复制的方法,启用 VAE 切片
    def enable_vae_slicing(self):
        r"""
        # 启用切片 VAE 解码。当启用此选项时,VAE 将输入张量切片以
        # 进行分步解码。这有助于节省内存并允许更大的批处理大小。
        """
        # 调用 VAE 的启用切片方法
        self.vae.enable_slicing()

    # 从 diffusers.pipelines.pipeline_utils.StableDiffusionMixin 复制的方法,禁用 VAE 切片
    def disable_vae_slicing(self):
        r"""
        # 禁用切片 VAE 解码。如果之前启用了 `enable_vae_slicing`,
        # 此方法将返回到一次性解码。
        """
        # 调用 VAE 的禁用切片方法
        self.vae.disable_slicing()
    # 编码提示信息的函数定义
        def encode_prompt(
            self,
            prompt,  # 提示内容
            device,  # 设备(CPU或GPU)
            do_classifier_free_guidance,  # 是否使用无分类器自由引导
            negative_prompt=None,  # 可选的负面提示内容
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入
            attention_mask: Optional[torch.LongTensor] = None,  # 可选的注意力掩码
            negative_attention_mask: Optional[torch.LongTensor] = None,  # 可选的负面注意力掩码
        # 编码音频持续时间的函数定义
        def encode_duration(
            self,
            audio_start_in_s,  # 音频开始时间(秒)
            audio_end_in_s,  # 音频结束时间(秒)
            device,  # 设备(CPU或GPU)
            do_classifier_free_guidance,  # 是否使用无分类器自由引导
            batch_size,  # 批处理大小
        ):
            # 如果开始时间不是列表,则转换为列表
            audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
            # 如果结束时间不是列表,则转换为列表
            audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
    
            # 如果开始时间列表长度为1,则扩展为批处理大小
            if len(audio_start_in_s) == 1:
                audio_start_in_s = audio_start_in_s * batch_size
            # 如果结束时间列表长度为1,则扩展为批处理大小
            if len(audio_end_in_s) == 1:
                audio_end_in_s = audio_end_in_s * batch_size
    
            # 将开始时间转换为浮点数列表
            audio_start_in_s = [float(x) for x in audio_start_in_s]
            # 将开始时间转换为张量并移动到指定设备
            audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
    
            # 将结束时间转换为浮点数列表
            audio_end_in_s = [float(x) for x in audio_end_in_s]
            # 将结束时间转换为张量并移动到指定设备
            audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
    
            # 使用投影模型获取输出
            projection_output = self.projection_model(
                start_seconds=audio_start_in_s,  # 开始时间张量
                end_seconds=audio_end_in_s,  # 结束时间张量
            )
            # 获取开始时间的隐藏状态
            seconds_start_hidden_states = projection_output.seconds_start_hidden_states
            # 获取结束时间的隐藏状态
            seconds_end_hidden_states = projection_output.seconds_end_hidden_states
    
            # 如果使用无分类器自由引导,则需要进行两个前向传递
            # 这里复制音频隐藏状态以避免进行两个前向传递
            if do_classifier_free_guidance:
                # 在第一个维度上重复开始时间的隐藏状态
                seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
                # 在第一个维度上重复结束时间的隐藏状态
                seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
    
            # 返回开始和结束时间的隐藏状态
            return seconds_start_hidden_states, seconds_end_hidden_states
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的代码
    # 准备额外的参数用于调度器步骤,因为并非所有调度器都有相同的参数签名
    def prepare_extra_step_kwargs(self, generator, eta):
        # eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
        # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
        # eta 的取值应在 [0, 1] 之间

        # 检查调度器的 step 方法是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化一个空的字典,用于存放额外的步骤参数
        extra_step_kwargs = {}
        # 如果调度器接受 eta 参数,则将其添加到额外参数字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 检查调度器的 step 方法是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果调度器接受 generator 参数,则将其添加到额外参数字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回包含额外步骤参数的字典
        return extra_step_kwargs

    # 检查输入参数的有效性
    def check_inputs(
        self,
        prompt,  # 输入提示文本
        audio_start_in_s,  # 音频起始时间(秒)
        audio_end_in_s,  # 音频结束时间(秒)
        callback_steps,  # 回调步骤的间隔
        negative_prompt=None,  # 可选的负面提示文本
        prompt_embeds=None,  # 可选的提示嵌入向量
        negative_prompt_embeds=None,  # 可选的负面提示嵌入向量
        attention_mask=None,  # 可选的注意力掩码
        negative_attention_mask=None,  # 可选的负面注意力掩码
        initial_audio_waveforms=None,  # 初始音频波形(张量)
        initial_audio_sampling_rate=None,  # 初始音频采样率
    # 准备潜在变量
    def prepare_latents(
        self,
        batch_size,  # 批处理大小
        num_channels_vae,  # VAE 的通道数
        sample_size,  # 样本尺寸
        dtype,  # 数据类型
        device,  # 设备信息(如 CPU 或 GPU)
        generator,  # 随机数生成器
        latents=None,  # 可选的潜在变量
        initial_audio_waveforms=None,  # 初始音频波形(张量)
        num_waveforms_per_prompt=None,  # 每个提示的音频波形数量
        audio_channels=None,  # 音频通道数
    # 禁用梯度计算,以提高性能
    @torch.no_grad()
    # 替换示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 调用方法,执行推理过程
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,  # 输入的提示文本,可以是字符串或字符串列表
        audio_end_in_s: Optional[float] = None,  # 音频结束时间(秒),可选
        audio_start_in_s: Optional[float] = 0.0,  # 音频起始时间(秒),默认为0.0
        num_inference_steps: int = 100,  # 推理步骤数,默认为100
        guidance_scale: float = 7.0,  # 指导因子,默认为7.0
        negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负面提示文本
        num_waveforms_per_prompt: Optional[int] = 1,  # 每个提示的音频波形数量,默认为1
        eta: float = 0.0,  # eta 值,默认为0.0
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,  # 可选的随机数生成器
        latents: Optional[torch.Tensor] = None,  # 可选的潜在变量(张量)
        initial_audio_waveforms: Optional[torch.Tensor] = None,  # 可选的初始音频波形(张量)
        initial_audio_sampling_rate: Optional[torch.Tensor] = None,  # 可选的初始音频采样率
        prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入向量(张量)
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入向量(张量)
        attention_mask: Optional[torch.LongTensor] = None,  # 可选的注意力掩码(张量)
        negative_attention_mask: Optional[torch.LongTensor] = None,  # 可选的负面注意力掩码(张量)
        return_dict: bool = True,  # 是否返回字典格式的结果,默认为 True
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,  # 可选的回调函数
        callback_steps: Optional[int] = 1,  # 回调步骤的间隔,默认为1
        output_type: Optional[str] = "pt",  # 输出类型,默认为 "pt"

.\diffusers\pipelines\stable_audio\__init__.py

# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING

# 从上级模块的 utils 导入多个对象和函数
from ...utils import (
    # 导入 DIFFUSERS_SLOW_IMPORT 变量
    DIFFUSERS_SLOW_IMPORT,
    # 导入 OptionalDependencyNotAvailable 异常类
    OptionalDependencyNotAvailable,
    # 导入 LazyModule 类
    _LazyModule,
    # 导入获取模块对象的函数
    get_objects_from_module,
    # 导入判断是否可用的函数
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)

# 定义一个空字典,用于存储占位对象
_dummy_objects = {}
# 定义一个空字典,用于存储模块导入结构
_import_structure = {}

# 尝试检测可用性
try:
    # 检查 transformers 和 torch 是否可用,并且 transformers 版本是否满足要求
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        # 如果不满足条件,则抛出异常
        raise OptionalDependencyNotAvailable()
# 捕获 OptionalDependencyNotAvailable 异常
except OptionalDependencyNotAvailable:
    # 从 utils 导入 dummy_torch_and_transformers_objects 模块
    from ...utils import dummy_torch_and_transformers_objects

    # 更新 _dummy_objects 字典,填充占位对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果没有异常,执行以下代码
else:
    # 将模型导入结构添加到 _import_structure 字典中
    _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"]
    _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"]

# 检查类型或慢导入标志
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 尝试检测可用性
    try:
        # 再次检查 transformers 和 torch 是否可用,并且 transformers 版本是否满足要求
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            # 如果不满足条件,则抛出异常
            raise OptionalDependencyNotAvailable()
    # 捕获 OptionalDependencyNotAvailable 异常
    except OptionalDependencyNotAvailable:
        # 从 dummy_torch_and_transformers_objects 模块导入所有内容
        from ...utils.dummy_torch_and_transformers_objects import *

    # 如果没有异常,执行以下代码
    else:
        # 从 modeling_stable_audio 模块导入 StableAudioProjectionModel 类
        from .modeling_stable_audio import StableAudioProjectionModel
        # 从 pipeline_stable_audio 模块导入 StableAudioPipeline 类
        from .pipeline_stable_audio import StableAudioPipeline

# 如果不是类型检查或慢导入
else:
    # 导入 sys 模块
    import sys

    # 使用 _LazyModule 创建一个懒加载模块
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # 遍历 _dummy_objects 字典,将每个占位对象添加到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\stable_cascade\pipeline_stable_cascade.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版(“许可证”)许可;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,按照许可证分发的软件均按“原样”提供,
# 不附带任何形式的明示或暗示的担保或条件。
# 有关许可证下的特定语言和权限限制,请参见许可证。

from typing import Callable, Dict, List, Optional, Union  # 从 typing 模块导入类型提示所需的类

import torch  # 导入 PyTorch 库
from transformers import CLIPTextModel, CLIPTokenizer  # 从 transformers 导入 CLIP 文本模型和标记器

from ...models import StableCascadeUNet  # 从当前包导入 StableCascadeUNet 模型
from ...schedulers import DDPMWuerstchenScheduler  # 从当前包导入调度器 DDPMWuerstchenScheduler
from ...utils import is_torch_version, logging, replace_example_docstring  # 从 utils 导入工具函数
from ...utils.torch_utils import randn_tensor  # 从 torch_utils 导入 randn_tensor 函数
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput  # 从 pipeline_utils 导入 DiffusionPipeline 和 ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel  # 从 wuerstchen 模块导入 PaellaVQModel

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

EXAMPLE_DOC_STRING = """  # 示例文档字符串,提供使用示例
    Examples:
        ```py
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline  # 导入管道类

        >>> prior_pipe = StableCascadePriorPipeline.from_pretrained(  # 从预训练模型加载 StableCascadePriorPipeline
        ...     "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16  # 指定模型名称和数据类型
        ... ).to("cuda")  # 将管道移至 CUDA 设备
        >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain(  # 从预训练模型加载 StableCascadeDecoderPipeline
        ...     "stabilityai/stable-cascade", torch_dtype=torch.float16  # 指定模型名称和数据类型
        ... ).to("cuda")  # 将管道移至 CUDA 设备

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"  # 定义生成图像的提示
        >>> prior_output = pipe(prompt)  # 使用提示生成初步输出
        >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)  # 基于初步输出生成最终图像
        ```py  # 结束示例代码块
"""

class StableCascadeDecoderPipeline(DiffusionPipeline):  # 定义 StableCascadeDecoderPipeline 类,继承自 DiffusionPipeline
    """
    用于从 Stable Cascade 模型生成图像的管道。

    此模型继承自 [`DiffusionPipeline`]。请查看超类文档以了解库为所有管道实现的通用方法
    (如下载或保存,运行在特定设备等)。
    # 参数说明
    Args:
        tokenizer (`CLIPTokenizer`):
            # CLIP 的分词器,用于处理文本输入
            The CLIP tokenizer.
        text_encoder (`CLIPTextModel`):
            # CLIP 的文本编码器,负责将文本转化为向量表示
            The CLIP text encoder.
        decoder ([`StableCascadeUNet`]):
            # 稳定的级联解码器 UNet,用于生成图像
            The Stable Cascade decoder unet.
        vqgan ([`PaellaVQModel`]):
            # VQGAN 模型,用于图像生成
            The VQGAN model.
        scheduler ([`DDPMWuerstchenScheduler`]):
            # 调度器,与 `prior` 结合用于生成图像嵌入
            A scheduler to be used in combination with `prior` to generate image embedding.
        latent_dim_scale (float, `optional`, defaults to 10.67):
            # 用于从图像嵌入计算 VQ 潜在空间大小的倍数
            Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
            height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
            width=int(24*10.67)=256 in order to match the training conditions.
    """

    # 设置解码器名称
    unet_name = "decoder"
    # 设置文本编码器名称
    text_encoder_name = "text_encoder"
    # 定义模型在 CPU 上的卸载顺序
    model_cpu_offload_seq = "text_encoder->decoder->vqgan"
    # 注册需要回调的张量输入
    _callback_tensor_inputs = [
        "latents",  # 潜在变量
        "prompt_embeds_pooled",  # 处理后的提示嵌入
        "negative_prompt_embeds",  # 负面提示嵌入
        "image_embeddings",  # 图像嵌入
    ]

    # 初始化方法
    def __init__(
        self,
        decoder: StableCascadeUNet,  # 解码器模型
        tokenizer: CLIPTokenizer,  # 分词器
        text_encoder: CLIPTextModel,  # 文本编码器
        scheduler: DDPMWuerstchenScheduler,  # 调度器
        vqgan: PaellaVQModel,  # VQGAN 模型
        latent_dim_scale: float = 10.67,  # 潜在维度缩放因子
    ) -> None:
        # 调用父类构造方法
        super().__init__()
        # 注册模块
        self.register_modules(
            decoder=decoder,  # 注册解码器
            tokenizer=tokenizer,  # 注册分词器
            text_encoder=text_encoder,  # 注册文本编码器
            scheduler=scheduler,  # 注册调度器
            vqgan=vqgan,  # 注册 VQGAN
        )
        # 将潜在维度缩放因子注册到配置中
        self.register_to_config(latent_dim_scale=latent_dim_scale)

    # 准备潜在变量的方法
    def prepare_latents(
        self, 
        batch_size,  # 批大小
        image_embeddings,  # 图像嵌入
        num_images_per_prompt,  # 每个提示生成的图像数量
        dtype,  # 数据类型
        device,  # 设备信息
        generator,  # 随机数生成器
        latents,  # 潜在变量
        scheduler  # 调度器
    ):
        # 获取图像嵌入的形状信息
        _, channels, height, width = image_embeddings.shape
        # 定义潜在变量的形状
        latents_shape = (
            batch_size * num_images_per_prompt,  # 总图像数量
            4,  # 通道数
            int(height * self.config.latent_dim_scale),  # 潜在图像高度
            int(width * self.config.latent_dim_scale),  # 潜在图像宽度
        )

        # 如果没有提供潜在变量,则生成随机潜在变量
        if latents is None:
            latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果提供的潜在变量形状不符合预期,则抛出异常
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            # 将潜在变量转移到指定设备
            latents = latents.to(device)

        # 将潜在变量与调度器的初始噪声标准差相乘
        latents = latents * scheduler.init_noise_sigma
        # 返回准备好的潜在变量
        return latents

    # 编码提示的方法
    def encode_prompt(
        self,
        device,  # 设备信息
        batch_size,  # 批大小
        num_images_per_prompt,  # 每个提示生成的图像数量
        do_classifier_free_guidance,  # 是否进行无分类器引导
        prompt=None,  # 提示文本
        negative_prompt=None,  # 负面提示文本
        prompt_embeds: Optional[torch.Tensor] = None,  # 提示嵌入(可选)
        prompt_embeds_pooled: Optional[torch.Tensor] = None,  # 处理后的提示嵌入(可选)
        negative_prompt_embeds: Optional[torch.Tensor] = None,  # 负面提示嵌入(可选)
        negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,  # 处理后的负面提示嵌入(可选)
    # 检查输入参数的有效性
        def check_inputs(
            self,  # 当前类实例
            prompt,  # 正向提示文本
            negative_prompt=None,  # 负向提示文本,默认为 None
            prompt_embeds=None,  # 正向提示的嵌入表示,默认为 None
            negative_prompt_embeds=None,  # 负向提示的嵌入表示,默认为 None
            callback_on_step_end_tensor_inputs=None,  # 回调函数输入,默认为 None
        ):
            # 检查回调函数输入是否为 None,并验证每个输入是否在预定义的回调输入列表中
            if callback_on_step_end_tensor_inputs is not None and not all(
                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
            ):
                # 如果输入无效,抛出错误并列出无效的输入
                raise ValueError(
                    f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
                )
    
            # 检查正向提示和正向嵌入是否同时提供
            if prompt is not None and prompt_embeds is not None:
                # 抛出错误,说明不能同时提供两者
                raise ValueError(
                    f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )
            # 检查正向提示和正向嵌入是否都未提供
            elif prompt is None and prompt_embeds is None:
                # 抛出错误,要求至少提供一个
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
            # 检查正向提示的类型是否正确
            elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                # 抛出错误,说明类型不符合要求
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    
            # 检查负向提示和负向嵌入是否同时提供
            if negative_prompt is not None and negative_prompt_embeds is not None:
                # 抛出错误,说明不能同时提供两者
                raise ValueError(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
    
            # 检查正向嵌入和负向嵌入的形状是否匹配
            if prompt_embeds is not None and negative_prompt_embeds is not None:
                if prompt_embeds.shape != negative_prompt_embeds.shape:
                    # 抛出错误,说明形状不匹配
                    raise ValueError(
                        "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                        f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                        f" {negative_prompt_embeds.shape}."
                    )
    
        # 定义一个属性,获取指导比例
        @property
        def guidance_scale(self):
            # 返回指导比例的值
            return self._guidance_scale
    
        # 定义一个属性,判断是否进行无分类器指导
        @property
        def do_classifier_free_guidance(self):
            # 返回指导比例是否大于 1
            return self._guidance_scale > 1
    
        # 定义一个属性,获取时间步数
        @property
        def num_timesteps(self):
            # 返回时间步数的值
            return self._num_timesteps
    
        # 禁用梯度计算的装饰器
        @torch.no_grad()
        # 替换示例文档字符串的装饰器
        @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用类的方法,用于生成图像
        def __call__(
            self,
            # 输入的图像嵌入,可以是单个张量或张量列表
            image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
            # 提示词,可以是字符串或字符串列表
            prompt: Union[str, List[str]] = None,
            # 推理步骤的数量,默认为10
            num_inference_steps: int = 10,
            # 引导尺度,默认为0.0
            guidance_scale: float = 0.0,
            # 负提示词,可以是字符串或字符串列表,默认为None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 提示词嵌入,默认为None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 提示词池化后的嵌入,默认为None
            prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 负提示词嵌入,默认为None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示词池化后的嵌入,默认为None
            negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 每个提示生成的图像数量,默认为1
            num_images_per_prompt: int = 1,
            # 随机数生成器,可以是单个或列表,默认为None
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,默认为None
            latents: Optional[torch.Tensor] = None,
            # 输出类型,默认为"pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典格式,默认为True
            return_dict: bool = True,
            # 每步结束时的回调函数,默认为None
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 每步结束时使用的张量输入列表,默认为包含"latents"
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],

.\diffusers\pipelines\stable_cascade\pipeline_stable_cascade_combined.py

# 版权声明,说明该代码的所有权
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 使用 Apache License 2.0 进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在遵循许可证的情况下使用
# you may not use this file except in compliance with the License.
# 可以在以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在“原样”基础上分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入所需的类型定义
from typing import Callable, Dict, List, Optional, Union

# 导入图像处理库
import PIL
# 导入 PyTorch
import torch
# 从 transformers 库导入 CLIP 模型及处理器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

# 从本地模型中导入 StableCascadeUNet
from ...models import StableCascadeUNet
# 从调度器中导入 DDPMWuerstchenScheduler
from ...schedulers import DDPMWuerstchenScheduler
# 导入工具函数
from ...utils import is_torch_version, replace_example_docstring
# 从管道工具中导入 DiffusionPipeline
from ..pipeline_utils import DiffusionPipeline
# 从 VQ 模型中导入 PaellaVQModel
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
# 导入 StableCascade 解码器管道
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
# 导入 StableCascade 优先管道
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline


# 文档字符串示例,展示如何使用文本转图像功能
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableCascadeCombinedPipeline

        # 从预训练模型创建管道实例
        >>> pipe = StableCascadeCombinedPipeline.from_pretrained(
        ...     "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
        ... )
        # 启用模型的 CPU 离线加载
        >>> pipe.enable_model_cpu_offload()
        # 定义图像生成的提示
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        # 生成图像
        >>> images = pipe(prompt=prompt)
        ```
"""

# 定义稳定级联组合管道类,用于文本到图像生成
class StableCascadeCombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for text-to-image generation using Stable Cascade.

    该模型继承自 [`DiffusionPipeline`]。检查父类文档以了解库为所有管道实现的通用方法
    (例如下载或保存、在特定设备上运行等。)
    # 文档字符串,描述初始化方法参数的含义
        Args:
            tokenizer (`CLIPTokenizer`):
                用于文本输入的解码器分词器。
            text_encoder (`CLIPTextModel`):
                用于文本输入的解码器文本编码器。
            decoder (`StableCascadeUNet`):
                用于解码器图像生成管道的解码模型。
            scheduler (`DDPMWuerstchenScheduler`):
                用于解码器图像生成管道的调度器。
            vqgan (`PaellaVQModel`):
                用于解码器图像生成管道的 VQGAN 模型。
            feature_extractor ([`~transformers.CLIPImageProcessor`]):
                从生成图像中提取特征的模型,作为 `image_encoder` 的输入。
            image_encoder ([`CLIPVisionModelWithProjection`]):
                冻结的 CLIP 图像编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
            prior_prior (`StableCascadeUNet`):
                用于先验管道的先验模型。
            prior_scheduler (`DDPMWuerstchenScheduler`):
                用于先验管道的调度器。
        """
    
        # 设置加载连接管道的标志为 True
        _load_connected_pipes = True
        # 定义可选组件的列表
        _optional_components = ["prior_feature_extractor", "prior_image_encoder"]
    
        # 初始化方法
        def __init__(
            # 定义参数类型及名称
            self,
            tokenizer: CLIPTokenizer,
            text_encoder: CLIPTextModel,
            decoder: StableCascadeUNet,
            scheduler: DDPMWuerstchenScheduler,
            vqgan: PaellaVQModel,
            prior_prior: StableCascadeUNet,
            prior_text_encoder: CLIPTextModel,
            prior_tokenizer: CLIPTokenizer,
            prior_scheduler: DDPMWuerstchenScheduler,
            prior_feature_extractor: Optional[CLIPImageProcessor] = None,
            prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
        ):
            # 调用父类初始化方法
            super().__init__()
    
            # 注册多个模块以便于管理
            self.register_modules(
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                decoder=decoder,
                scheduler=scheduler,
                vqgan=vqgan,
                prior_text_encoder=prior_text_encoder,
                prior_tokenizer=prior_tokenizer,
                prior_prior=prior_prior,
                prior_scheduler=prior_scheduler,
                prior_feature_extractor=prior_feature_extractor,
                prior_image_encoder=prior_image_encoder,
            )
            # 初始化先验管道
            self.prior_pipe = StableCascadePriorPipeline(
                prior=prior_prior,
                text_encoder=prior_text_encoder,
                tokenizer=prior_tokenizer,
                scheduler=prior_scheduler,
                image_encoder=prior_image_encoder,
                feature_extractor=prior_feature_extractor,
            )
            # 初始化解码器管道
            self.decoder_pipe = StableCascadeDecoderPipeline(
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                decoder=decoder,
                scheduler=scheduler,
                vqgan=vqgan,
            )
    # 启用 xformers 的内存高效注意力机制
        def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
            # 调用解码管道以启用内存高效的注意力机制
            self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
    
    # 启用模型的 CPU 离线加载
        def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
            r"""
            使用 accelerate 将所有模型转移到 CPU,降低内存使用,性能影响较小。与 `enable_sequential_cpu_offload` 相比,该方法在调用模型的 `forward` 方法时一次移动整个模型到 GPU,并在下一个模型运行之前保持在 GPU 中。内存节省低于 `enable_sequential_cpu_offload`,但由于 `unet` 的迭代执行,性能更好。
            """
            # 启用 CPU 离线加载到优先管道
            self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
            # 启用 CPU 离线加载到解码管道
            self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
    
    # 启用顺序 CPU 离线加载
        def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
            r"""
            使用 🤗 Accelerate 将所有模型(`unet`、`text_encoder`、`vae` 和 `safety checker` 状态字典)转移到 CPU,显著减少内存使用。模型被移动到 `torch.device('meta')`,仅在调用其特定子模块的 `forward` 方法时加载到 GPU。离线加载是基于子模块进行的。内存节省高于使用 `enable_model_cpu_offload`,但性能较低。
            """
            # 启用顺序 CPU 离线加载到优先管道
            self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
            # 启用顺序 CPU 离线加载到解码管道
            self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
    
    # 处理进度条的显示
        def progress_bar(self, iterable=None, total=None):
            # 在优先管道中显示进度条
            self.prior_pipe.progress_bar(iterable=iterable, total=total)
            # 在解码管道中显示进度条
            self.decoder_pipe.progress_bar(iterable=iterable, total=total)
    
    # 设置进度条的配置
        def set_progress_bar_config(self, **kwargs):
            # 设置优先管道的进度条配置
            self.prior_pipe.set_progress_bar_config(**kwargs)
            # 设置解码管道的进度条配置
            self.decoder_pipe.set_progress_bar_config(**kwargs)
    
    # 禁用梯度计算以节省内存
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
    # 定义可调用方法,允许使用多个参数进行推理
        def __call__(
            self,
            # 输入的提示,可以是字符串或字符串列表
            prompt: Optional[Union[str, List[str]]] = None,
            # 输入的图像,可以是张量或 PIL 图像,支持列表形式
            images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
            # 生成图像的高度,默认值为 512
            height: int = 512,
            # 生成图像的宽度,默认值为 512
            width: int = 512,
            # 推理步骤的数量,用于先验模型,默认值为 60
            prior_num_inference_steps: int = 60,
            # 先验指导尺度,控制生成的样式强度,默认值为 4.0
            prior_guidance_scale: float = 4.0,
            # 推理步骤的数量,控制图像生成的细致程度,默认值为 12
            num_inference_steps: int = 12,
            # 解码器指导尺度,影响图像的多样性,默认值为 0.0
            decoder_guidance_scale: float = 0.0,
            # 负面提示,可以是字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 提示嵌入,供模型使用的预计算张量
            prompt_embeds: Optional[torch.Tensor] = None,
            # 池化后的提示嵌入,增强模型的理解能力
            prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 负面提示嵌入,供模型使用的预计算张量
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 池化后的负面提示嵌入,增强模型的理解能力
            negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 每个提示生成的图像数量,默认值为 1
            num_images_per_prompt: int = 1,
            # 随机数生成器,控制生成过程的随机性
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,用于图像生成的输入张量
            latents: Optional[torch.Tensor] = None,
            # 输出类型,默认使用 PIL 图像格式
            output_type: Optional[str] = "pil",
            # 是否返回字典格式的结果,默认值为 True
            return_dict: bool = True,
            # 先验回调函数,处理每个步骤结束时的操作
            prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 先验回调使用的张量输入,默认包含 'latents'
            prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 回调函数,处理每个步骤结束时的操作
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 回调使用的张量输入,默认包含 'latents'
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],

.\diffusers\pipelines\stable_cascade\pipeline_stable_cascade_prior.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下位置获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则根据许可证分发的软件是按“原样”基础分发的,
# 不附带任何形式的保证或条件,无论是明示还是暗示。
# 有关许可证所涵盖权限和限制的具体语言,请参见许可证。

from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from math import ceil  # 从 math 模块导入 ceil 函数,用于向上取整
from typing import Callable, Dict, List, Optional, Union  # 导入类型提示相关的类型

import numpy as np  # 导入 numpy 库并简化为 np
import PIL  # 导入 PIL 库用于图像处理
import torch  # 导入 PyTorch 库
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection  # 从 transformers 模块导入 CLIP 相关的处理器和模型

from ...models import StableCascadeUNet  # 从当前包的模型模块导入 StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler  # 从当前包的调度器模块导入 DDPMWuerstchenScheduler
from ...utils import BaseOutput, logging, replace_example_docstring  # 从当前包的工具模块导入 BaseOutput、logging 和 replace_example_docstring
from ...utils.torch_utils import randn_tensor  # 从当前包的 PyTorch 工具模块导入 randn_tensor 函数
from ..pipeline_utils import DiffusionPipeline  # 从上级包的管道工具模块导入 DiffusionPipeline

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# 定义 DEFAULT_STAGE_C_TIMESTEPS 为线性空间的时间步列表
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]

# 示例文档字符串,展示如何使用 StableCascadePriorPipeline
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import StableCascadePriorPipeline  # 从 diffusers 模块导入 StableCascadePriorPipeline

        >>> prior_pipe = StableCascadePriorPipeline.from_pretrained(  # 创建预训练的 StableCascadePriorPipeline 实例
        ...     "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16  # 指定模型路径和数据类型
        ... ).to("cuda")  # 将管道移动到 CUDA 设备

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"  # 定义输入提示
        >>> prior_output = pipe(prompt)  # 使用管道生成图像输出
        ```py
"""

@dataclass  # 使用 dataclass 装饰器定义数据类
class StableCascadePriorPipelineOutput(BaseOutput):  # 定义 StableCascadePriorPipelineOutput 类,继承自 BaseOutput
    """
    WuerstchenPriorPipeline 的输出类。

    Args:
        image_embeddings (`torch.Tensor` or `np.ndarray`)  # 图像嵌入,表示文本提示的图像特征
            Prior image embeddings for text prompt
        prompt_embeds (`torch.Tensor`):  # 文本提示的嵌入
            Text embeddings for the prompt.
        negative_prompt_embeds (`torch.Tensor`):  # 负文本提示的嵌入
            Text embeddings for the negative prompt.
    """

    image_embeddings: Union[torch.Tensor, np.ndarray]  # 定义图像嵌入属性,类型为 torch.Tensor 或 np.ndarray
    prompt_embeds: Union[torch.Tensor, np.ndarray]  # 定义提示嵌入属性,类型为 torch.Tensor 或 np.ndarray
    prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]  # 定义池化提示嵌入属性,类型为 torch.Tensor 或 np.ndarray
    negative_prompt_embeds: Union[torch.Tensor, np.ndarray]  # 定义负提示嵌入属性,类型为 torch.Tensor 或 np.ndarray
    negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]  # 定义池化负提示嵌入属性,类型为 torch.Tensor 或 np.ndarray

class StableCascadePriorPipeline(DiffusionPipeline):  # 定义 StableCascadePriorPipeline 类,继承自 DiffusionPipeline
    """
    生成 Stable Cascade 的图像先验的管道。

    此模型继承自 [`DiffusionPipeline`]。请查看超类文档以获取库为所有管道实现的通用方法(如下载或保存,运行在特定设备上等)。
    # 函数的参数说明
        Args:
            prior ([`StableCascadeUNet`]):  # 稳定级联生成网络,用于近似从文本和/或图像嵌入得到的图像嵌入。
                The Stable Cascade prior to approximate the image embedding from the text and/or image embedding.
            text_encoder ([`CLIPTextModelWithProjection`]):  # 冻结的文本编码器,用于处理文本输入。
                Frozen text-encoder
                ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
            feature_extractor ([`~transformers.CLIPImageProcessor`]):  # 从生成的图像中提取特征的模型,作为输入传递给图像编码器。
                Model that extracts features from generated images to be used as inputs for the `image_encoder`.
            image_encoder ([`CLIPVisionModelWithProjection`]):  # 冻结的 CLIP 图像编码器,用于处理图像输入。
                Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
            tokenizer (`CLIPTokenizer`):  # 处理文本的分词器,能够将文本转为模型可理解的格式。
                Tokenizer of class
                [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
            scheduler ([`DDPMWuerstchenScheduler`]):  # 调度器,与 prior 结合使用,以生成图像嵌入。
                A scheduler to be used in combination with `prior` to generate image embedding.
            resolution_multiple ('float', *optional*, defaults to 42.67):  # 生成多张图像时的默认分辨率。
                Default resolution for multiple images generated.
        """
    
        unet_name = "prior"  # 设置 prior 的名称为 "prior"
        text_encoder_name = "text_encoder"  # 设置文本编码器的名称
        model_cpu_offload_seq = "image_encoder->text_encoder->prior"  # 定义模型的 CPU 卸载顺序
        _optional_components = ["image_encoder", "feature_extractor"]  # 定义可选组件列表
        _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]  # 定义回调张量输入列表
    
        def __init__(  # 初始化方法
            self,
            tokenizer: CLIPTokenizer,  # 传入的分词器对象
            text_encoder: CLIPTextModelWithProjection,  # 传入的文本编码器对象
            prior: StableCascadeUNet,  # 传入的生成网络对象
            scheduler: DDPMWuerstchenScheduler,  # 传入的调度器对象
            resolution_multiple: float = 42.67,  # 设置默认分辨率的参数
            feature_extractor: Optional[CLIPImageProcessor] = None,  # 可选的特征提取器对象
            image_encoder: Optional[CLIPVisionModelWithProjection] = None,  # 可选的图像编码器对象
        ) -> None:  # 方法返回类型
            super().__init__()  # 调用父类的初始化方法
            self.register_modules(  # 注册各个模块
                tokenizer=tokenizer,  # 注册分词器
                text_encoder=text_encoder,  # 注册文本编码器
                image_encoder=image_encoder,  # 注册图像编码器
                feature_extractor=feature_extractor,  # 注册特征提取器
                prior=prior,  # 注册生成网络
                scheduler=scheduler,  # 注册调度器
            )
            self.register_to_config(resolution_multiple=resolution_multiple)  # 将分辨率参数注册到配置中
    
        def prepare_latents(  # 准备潜在变量的方法
            self,  # 指向实例本身
            batch_size,  # 批处理的大小
            height,  # 图像的高度
            width,  # 图像的宽度
            num_images_per_prompt,  # 每个提示生成的图像数量
            dtype,  # 数据类型
            device,  # 设备类型(CPU/GPU)
            generator,  # 随机数生成器
            latents,  # 潜在变量
            scheduler  # 调度器
    ):
        # 定义潜在形状,包括每个提示的图像数量和批处理大小等信息
        latent_shape = (
            num_images_per_prompt * batch_size,
            self.prior.config.in_channels,
            ceil(height / self.config.resolution_multiple),
            ceil(width / self.config.resolution_multiple),
        )

        # 如果潜在变量为空,则随机生成一个张量
        if latents is None:
            latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
        else:
            # 检查潜在变量的形状是否与预期匹配,不匹配则引发错误
            if latents.shape != latent_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
            # 将潜在变量转移到指定设备
            latents = latents.to(device)

        # 将潜在变量乘以调度器的初始噪声标准差
        latents = latents * scheduler.init_noise_sigma
        # 返回处理后的潜在变量
        return latents

    def encode_prompt(
        # 定义编码提示的函数,接收多个参数
        self,
        device,
        batch_size,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_pooled: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
    def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt):
        # 定义编码图像的函数,初始化图像嵌入列表
        image_embeds = []
        for image in images:
            # 提取特征并将图像转为张量形式
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
            # 将图像张量移动到指定设备和数据类型
            image = image.to(device=device, dtype=dtype)
            # 编码图像并将结果嵌入添加到列表中
            image_embed = self.image_encoder(image).image_embeds.unsqueeze(1)
            image_embeds.append(image_embed)
        # 将所有图像嵌入按维度1拼接在一起
        image_embeds = torch.cat(image_embeds, dim=1)

        # 重复图像嵌入以匹配批处理大小和每个提示的图像数量
        image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
        # 创建与图像嵌入形状相同的零张量作为负面图像嵌入
        negative_image_embeds = torch.zeros_like(image_embeds)

        # 返回正面和负面图像嵌入
        return image_embeds, negative_image_embeds

    def check_inputs(
        # 定义输入检查函数,接收多个可能为空的参数
        self,
        prompt,
        images=None,
        image_embeds=None,
        negative_prompt=None,
        prompt_embeds=None,
        prompt_embeds_pooled=None,
        negative_prompt_embeds=None,
        negative_prompt_embeds_pooled=None,
        callback_on_step_end_tensor_inputs=None,
    @property
    # 定义属性以获取引导比例
    def guidance_scale(self):
        return self._guidance_scale

    @property
    # 定义属性以判断是否使用无分类器引导
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    # 定义属性以获取时间步数
    def num_timesteps(self):
        return self._num_timesteps

    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
        # 定义获取时间步比率的条件函数
        s = torch.tensor([0.003])
        clamp_range = [0, 1]
        # 计算最小方差
        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        var = alphas_cumprod[t]
        # 将方差限制在指定范围内
        var = var.clamp(*clamp_range)
        s, min_var = s.to(var.device), min_var.to(var.device)
        # 计算并返回比率
        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return ratio

    @torch.no_grad()
    # 装饰器用于不计算梯度
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义一个可调用的类方法,处理输入的各种参数
        def __call__(
            # 提示文本,可以是单个字符串或字符串列表
            self,
            prompt: Optional[Union[str, List[str]]] = None,
            # 输入图像,可以是单个 Tensor、PIL 图像或它们的列表
            images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
            # 输出图像的高度,默认值为 1024
            height: int = 1024,
            # 输出图像的宽度,默认值为 1024
            width: int = 1024,
            # 推理步骤的数量,默认值为 20
            num_inference_steps: int = 20,
            # 时间步列表,决定生成图像的时间步数
            timesteps: List[float] = None,
            # 引导尺度,用于调整生成图像的质量,默认值为 4.0
            guidance_scale: float = 4.0,
            # 负提示文本,可以是单个字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 预计算的提示嵌入,可以为 Tensor
            prompt_embeds: Optional[torch.Tensor] = None,
            # 预计算的池化提示嵌入,可以为 Tensor
            prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 预计算的负提示嵌入,可以为 Tensor
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 预计算的池化负提示嵌入,可以为 Tensor
            negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
            # 输入图像的嵌入,可以为 Tensor
            image_embeds: Optional[torch.Tensor] = None,
            # 每个提示生成的图像数量,默认值为 1
            num_images_per_prompt: Optional[int] = 1,
            # 随机数生成器,可以是单个生成器或生成器的列表
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,可以为 Tensor
            latents: Optional[torch.Tensor] = None,
            # 输出类型,默认为 "pt"
            output_type: Optional[str] = "pt",
            # 是否返回字典格式的结果,默认值为 True
            return_dict: bool = True,
            # 每一步结束时调用的回调函数,接受步数、总步数和状态字典
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 在步骤结束时的张量输入回调列表,默认包括 "latents"
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
posted @ 2024-10-22 12:34  绝不原创的飞龙  阅读(60)  评论(0)    收藏  举报