Lucidrains-系列项目源码解析-十八-

Lucidrains 系列项目源码解析(十八)

.\lucidrains\geometric-vector-perceptron\examples\data_utils.py

# 作者:Eric Alcaide

# 导入必要的库
import os 
import sys
# 科学计算库
import torch
import torch_sparse
import numpy as np 
from einops import repeat, rearrange
# 导入自定义工具 - 来自 https://github.com/EleutherAI/mp_nerf
from data_handler import *

# 新数据构建函数
def get_atom_ids_dict():
    """ 获取将每个原子映射到一个标记的字典 """
    # 初始化原子集合
    ids = set(["N", "CA", "C", "O"])

    # 遍历 SC_BUILD_INFO 中的键值对
    for k,v in SC_BUILD_INFO.items():
        # 遍历原子名称列表
        for name in v["atom-names"]:
            ids.add(name)
            
    # 返回原子到标记的映射字典
    return {k: i for i,k in enumerate(sorted(ids))}

#################################
##### 原始项目数据 #####
#################################

# 定义氨基酸序列和对应的数字
AAS = "ARNDCQEGHILKMFPSTWYV_"
AAS2NUM = {k: AAS.index(k) for k in AAS}
# 获取原子标记字典
ATOM_IDS = get_atom_ids_dict()
# 定义氨基酸的键值对应的键为氨基酸,值为键值对应的原子键值对
GVP_DATA = { 
    'A': {
        'bonds': [[0,1], [1,2], [2,3], [1,4]] 
         },
    'R': {
        'bonds': [[0,1], [1,2], [2,3], [2,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [8,10]] 
         },
    # 其他氨基酸的键值对应的键值对
    # ...
    '_': {
        'bonds': []
        }
    }

#################################
##### 原始项目数据 #####
#################################

def graph_laplacian_embedds(edges, eigen_k, center_idx=1, norm=False):
    """ 返回图拉普拉斯的前 K 个特征向量中点的嵌入。
        输入:
        * edges: (2, N). 长整型张量或列表。足够表示无向边。
        * eigen_k: 整数。要返回嵌入的前 N 个特征向量。
        * center_idx: 整数。用作嵌入中心的索引。
        * norm: 布尔值。是否使用归一化拉普拉斯。不建议使用。
        输出:(n_points, eigen_k)
    """
    # 如果 edges 是列表,则转换为长整型张量
    if isinstance(edges, list):
        edges = torch.tensor(edges).long()
        # 纠正维度
        if edges.shape[0] != 2:
            edges = edges.t()
        # 如果为空,则返回零张量
        if edges.shape[0] == 0:
            return torch.zeros(1, eigen_k)
    # 获取参数
    # 计算边的最大值并加1,作为邻接矩阵的大小
    size = torch.max(edges)+1
    # 获取边所在设备信息
    device = edges.device
    # 创建邻接矩阵
    adj_mat = torch.eye(size, device=device) 
    # 遍历边的起始点和终点,将邻接矩阵对应位置设为1
    for i,j in edges.t():
        adj_mat[i,j] = adj_mat[j,i] = 1.
        
    # 计算度矩阵
    deg_mat = torch.eye(size) * adj_mat.sum(dim=-1, keepdim=True)
    # 计算拉普拉斯矩阵
    laplace = deg_mat - adj_mat
    # 如果传入了norm参数,则使用规范化的拉普拉斯矩阵
    if norm:
        # 遍历边的起始点和终点,更新拉普拉斯矩阵的值
        for i,j in edges.t():
            laplace[i,j] = laplace[j,i] = -1 / (deg_mat[i,i] * deg_mat[j,j])**0.5
    # 对拉普拉斯矩阵进行特征值分解,获取特征值和特征向量
    e, v = torch.symeig(laplace, eigenvectors=True)
    # 根据特征值的绝对值降序排列,获取排序后的索引
    idxs = torch.sort( e.abs(), descending=True)[1]
    # 获取前eigen_k个特征向量作为嵌入向量
    embedds = v[:, idxs[:eigen_k]]
    # 将嵌入向量减去中心点的嵌入向量
    embedds = embedds - embedds[center_idx].unsqueeze(-2)
    # 返回处理后的嵌入向量
    return embedds
# 返回每个氨基酸中每个原子的标记
def make_atom_id_embedds(k):
    # 创建一个长度为14的零张量
    mask = torch.zeros(14).long()
    # 定义氨基酸中的原子列表
    atom_list = ["N", "CA", "C", "O"] + SC_BUILD_INFO[k]["atom-names"]
    # 遍历原子列表,将每个原子的标记存储在mask中
    for i,atom in enumerate(atom_list):
        mask[i] = ATOM_IDS[atom]
    return mask


#################################
########## SAVE INFO ############
#################################

# 创建包含各种信息的字典
SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k),
                    "bond_mask": make_bond_mask(k),
                    "theta_mask": make_theta_mask(k),
                    "torsion_mask": make_torsion_mask(k),
                    "idx_mask": make_idx_mask(k),
                    #
                    "eigen_embedd": graph_laplacian_embedds(GVP_DATA[k]["bonds"], eigen_k = 3),
                    "atom_id_embedd": make_atom_id_embedds(k)
                    } 
                for k in "ARNDCQEGHILKMFPSTWYV_"}

#################################
######### RANDOM UTILS ##########
#################################


# 使用正弦和余弦编码距离
def encode_dist(x, scales=[1,2,4,8], include_self = True):
    """ Encodes a distance with sines and cosines. 
        Inputs:
        * x: (batch, N) or (N,). data to encode.
              Infer devic and type (f16, f32, f64) from here.
        * scales: (s,) or list. lower or higher depending on distances.
        Output: (..., num_scales*2 + 1) if include_self or (..., num_scales*2) 
    """
    x = x.unsqueeze(-1)
    # 推断设备
    device, precise = x.device, x.type()
    # 转换为张量
    if isinstance(scales, list):
        scales = torch.tensor([scales], device=device).type(precise)
    # 获取正弦编码
    sines   = torch.sin(x / scales)
    cosines = torch.cos(x / scales)
    # 连接并返回
    enc_x = torch.cat([sines, cosines], dim=-1)
    return torch.cat([enc_x, x], dim=-1) if include_self else enc_x

# 解码距离
def decode_dist(x, scales=[1,2,4,8], include_self = False):
    """ Encodes a distance with sines and cosines. 
        Inputs:
        * x: (batch, N, 2*fourier_feats (+1) ) or (N,). data to encode.
              Infer devic and type (f16, f32, f64) from here.
        * scales: (s,) or list. lower or higher depending on distances.
        * include_self: whether to average with raw prediction or not.
        Output: (batch, N)
    """
    device, precise = x.device, x.type()
    # 转换为张量
    if isinstance(scales, list):
        scales = torch.tensor([scales], device=device).type(precise)
    # 通过 atan2 解码并校正负角度
    half = x.shape[-1]//2
    decodes = torch.atan2(x[..., :half], x[..., half:2*half])
    decodes += (decodes<0).type(precise) * 2*np.pi 
    # 调整偏移量
    offsets = torch.zeros_like(decodes)
    for i in range(decodes.shape[-1]-1, 0, -1):
        offsets[:, i-1] = 2 * ( offsets[:, i] + (decodes[:, i]>np.pi).type(precise) * np.pi )
    decodes += offsets
    avg_dec = (decodes * scales).mean(dim=-1, keepdim=True)
    if include_self:
        return 0.5*(avg_dec + x[..., -1:])
    return avg_dec

# 计算第n次邻接矩阵
def nth_deg_adjacency(adj_mat, n=1, sparse=False):
    """ Calculates the n-th degree adjacency matrix.
        Performs mm of adj_mat and adds the newly added.
        Default is dense. Mods for sparse version are done when needed.
        Inputs: 
        * adj_mat: (N, N) adjacency tensor
        * n: int. degree of the output adjacency
        * sparse: bool. whether to use torch-sparse module
        Outputs: 
        * edge_idxs: the ij positions of the adjacency matrix
        * edge_attrs: the degree of connectivity (1 for neighs, 2 for neighs^2 )
    """
    adj_mat = adj_mat.float()
    attr_mat = torch.zeros_like(adj_mat)
    # 遍历范围为n的循环
    for i in range(n):
        # 如果i为0,则将属性矩阵与邻接矩阵相加
        if i == 0:
            attr_mat += adj_mat
            continue

        # 如果i为1且sparse为True
        if i == 1 and sparse: 
            # 创建稀疏邻接张量
            adj_mat = torch.sparse.FloatTensor(adj_mat.nonzero().t(),
                                                adj_mat[adj_mat != 0]).to(adj_mat.device).coalesce()
            idxs, vals = adj_mat.indices(), adj_mat.values()
            m, k, n = 3 * [adj_mat.shape[0]]  # (m, n) * (n, k) , 但adj_mats是方阵:m=n=k

        # 如果sparse为True
        if sparse:
            # 使用torch_sparse库中的spspmm函数进行稀疏矩阵乘法
            idxs, vals = torch_sparse.spspmm(idxs, vals, idxs, vals, m=m, k=k, n=n)
            adj_mat = torch.zeros_like(attr_mat)
            adj_mat[idxs[0], idxs[1]] = vals.bool().float()
        else:
            # 如果sparse为False,则将邻接矩阵平方,转换为布尔型矩阵
            adj_mat = (adj_mat @ adj_mat).bool().float() 

        # 更新属性矩阵
        attr_mat[(adj_mat - attr_mat.bool().float()).bool()] += i + 1

    # 返回更新后的邻接矩阵和属性矩阵
    return adj_mat, attr_mat
# 返回蛋白质的共价键的索引
def prot_covalent_bond(seq, adj_degree=1, cloud_mask=None):
    """ 返回蛋白质的共价键的索引。
        输入
        * seq: str. 用1字母氨基酸代码表示的蛋白质序列。
        * cloud_mask: 选择存在原子的掩码。
        输出: edge_idxs
    """
    # 创建或推断 cloud_mask
    if cloud_mask is None: 
        cloud_mask = scn_cloud_mask(seq).bool()
    device, precise = cloud_mask.device, cloud_mask.type()
    # 获取每个氨基酸的起始位置
    scaff = torch.zeros_like(cloud_mask)
    scaff[:, 0] = 1
    idxs = scaff[cloud_mask].nonzero().view(-1)
    # 从包含 GVP_DATA 的字典中获取姿势 + 索引 - 返回所有边
    adj_mat = torch.zeros(idxs.amax()+14, idxs.amax()+14)
    for i,idx in enumerate(idxs):
        # 与下一个氨基酸的键
        extra = []
        if i < idxs.shape[0]-1:
            extra = [[2, (idxs[i+1]-idx).item()]]

        bonds = idx + torch.tensor( GVP_DATA[seq[i]]['bonds'] + extra ).long().t() 
        adj_mat[bonds[0], bonds[1]] = 1.
    # 转换为无向图
    adj_mat = adj_mat + adj_mat.t()
    # 进行 N 次邻接
    adj_mat, attr_mat = nth_deg_adjacency(adj_mat, n=adj_degree, sparse=True)

    edge_idxs = attr_mat.nonzero().t().long()
    edge_attrs = attr_mat[edge_idxs[0], edge_idxs[1]]
    return edge_idxs, edge_attrs


def dist2ca(x, mask=None, eps=1e-7):
    """ 计算每个点到 C-alfa 的距离。
        输入:
        * x: (L, 14, D)
        * mask: (L, 14) 的布尔掩码
        返回单位向量和范数。
    """
    x = x - x[:, 1].unsqueeze(1)
    norm = torch.norm(x, dim=-1, keepdim=True)
    x_norm = x / (norm+eps)
    if mask:
        return x_norm[mask], norm[mask]
    return x_norm, norm


def orient_aa(x, mask=None, eps=1e-7):
    """ 计算主链特征的单位向量和范数。
        输入:
        * x: (L, 14, D). Sidechainnet 格式的坐标。
        返回单位向量 (5) 和范数 (3)。
    """
    # 获取张量信息
    device, precise = x.device, x.type()

    vec_wrap  = torch.zeros(5, x.shape[0], 3, device=device) # (feats, L, dims+1)
    norm_wrap = torch.zeros(3, x.shape[0], device=device)
    # 第一个特征是 CB-CA
    vec_wrap[0]  = x[:, 4] - x[:, 1]
    norm_wrap[0] = torch.norm(vec_wrap[0], dim=-1)
    vec_wrap[0] /= norm_wrap[0].unsqueeze(dim=-1) + eps
    # 第二个是 CA+ - CA :
    vec_wrap[1, :-1]  = x[:-1, 1] - x[1:, 1]
    norm_wrap[1, :-1] = torch.norm(vec_wrap[1, :-1], dim=-1)
    vec_wrap[1, :-1] /= norm_wrap[1, :-1].unsqueeze(dim=-1) + eps
    # 同样但是反向向量
    vec_wrap[2] = (-1)*vec_wrap[1]
    # 第三个是 CA - CA-
    vec_wrap[3, 1:]  = x[:-1, 1] - x[1:, 1]
    norm_wrap[2, 1:] = torch.norm(vec_wrap[3, 1:], dim=-1)
    vec_wrap[3, 1:] /= norm_wrap[2, 1:].unsqueeze(dim=-1) + eps
    # 现在反向顺序的向量
    vec_wrap[4] = (-1)*vec_wrap[3]

    return vec_wrap, norm_wrap


def chain2atoms(x, mask=None):
    """ 从 (L, other) 扩展到 (L, C, other)。"""
    device, precise = x.device, x.type()
    # 获取掩码
    wrap = torch.ones(x.shape[0], 14, *x.shape[1:]).type(precise).to(device)
    # 分配
    wrap = wrap * x.unsqueeze(1)
    if mask is not None:
        return wrap[mask]
    return wrap


def from_encode_to_pred(whole_point_enc, use_fourier=False, embedd_info=None, needed_info=None, vec_dim=3):
    """ 将上述函数的编码转换为标签/预测格式。
        仅包含位置恢复所需的基本信息 (径向单位向量 + 范数)
        输入: 包含以下内容的输入元组:
        * whole_point_enc: (atoms, vector_dims+scalar_dims)
                           与上述函数相同的形状。
                           径向单位向量必须是第一个向量维度
        * embedd_info: 字典。包含标量和向量特征的数量。
    """
    vec_dims = vec_dim * embedd_info["point_n_vectors"]
    start_pos = 2*len(needed_info["atom_pos_scales"])+vec_dims
    # 如果使用傅立叶变换
    if use_fourier:
        # 解码整个点编码中的部分向量维度,不包括自身
        decoded_dist = decode_dist(whole_point_enc[:, vec_dims:start_pos+1],
                                    scales=needed_info["atom_pos_scales"],
                                    include_self=False)
    else:
        # 如果不使用傅立叶变换,直接取整个点编码中的特定维度
        decoded_dist = whole_point_enc[:, start_pos:start_pos+1]
    # 返回连接后的张量,包括单位径向向量和向量范数
    return torch.cat([whole_point_enc[:, :3], decoded_dist], dim=-1)
def encode_whole_bonds(x, x_format="coords", embedd_info={},
                       needed_info = {"cutoffs": [2,5,10],
                                      "bond_scales": [.5, 1, 2],
                                      "adj_degree": 1},
                       free_mem=False, eps=1e-7):
    """ Given some coordinates, and the needed info,
        encodes the bonds from point information.
        * x: (N, 3) or prediction format
        * x_format: one of ["coords" or "prediction"]
        * embedd_info: dict. contains the needed embedding info
        * needed_info: dict. contains additional needed info
            { cutoffs: list. cutoff distances for bonds.
                       can be a string for the k closest (ex: "30_closest"),
                       empty list for just covalent.
              bond_scales: list. fourier encodings
              adj_degree: int. degree of adj (2 means adj of adj is my adj)
                               0 for no adjacency
            }
        * free_mem: whether to delete variables
        * eps: constant for numerical stability
    """ 
    device, precise = x.device, x.type()
    # convert to 3d coords if passed as preds
    if x_format == "encode":
        pred_x = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)
        x = pred_x[:, :3] * pred_x[:, 3:4]

    # encode bonds

    # 1. BONDS: find the covalent bond_indices - allow arg -> DRY
    native = None
    if "prot_covalent_bond" in needed_info.keys():
        native = True
        native_bonds = needed_info["covalent_bond"]
    elif needed_info["adj_degree"]:
        native = True
        native_bonds  = prot_covalent_bond(needed_info["seq"], needed_info["adj_degree"])
        
    if native: 
        native_idxs, native_attrs = native_bonds[0].to(device), native_bonds[1].to(device)

    # determine kind of cutoff (hard distance threshold or closest points)
    closest = None
    if len(needed_info["cutoffs"]) > 0: 
        cutoffs = needed_info["cutoffs"].copy() 
        if sum( isinstance(ci, str) for ci in cutoffs ) > 0:
            cutoffs = [-1e-3] # negative so no bond is taken  
            closest = int( needed_info["cutoffs"][0].split("_")[0] ) 

        # points under cutoff = d(i - j) < X 
        cutoffs = torch.tensor(cutoffs, device=device).type(precise)
        dist_mat = torch.cdist(x, x, p=2)

    # normal buckets
    bond_buckets = torch.zeros(*x.shape[:-1], x.shape[-2], device=device).type(precise)
    if len(needed_info["cutoffs"]) > 0 and not closest:
        # count from latest degree of adjacency given
        bond_buckets = torch.bucketize(dist_mat, cutoffs)
        bond_buckets[native_idxs[0], native_idxs[1]] = cutoffs.shape[0]
        # find the indexes - symmetric and we dont want the diag
        bond_buckets   += cutoffs.shape[0] * torch.eye(bond_buckets.shape[0], device=device).long()
        close_bond_idxs = ( bond_buckets < cutoffs.shape[0] ).nonzero().t()
        # move away from poses reserved for native
        bond_buckets[close_bond_idxs[0], close_bond_idxs[1]] += needed_info["adj_degree"]+1

    # the K closest (covalent bonds excluded) are considered bonds 
    # 如果存在最近的键,执行以下操作
    elif closest:
        # 将距离矩阵复制一份,并将共价键屏蔽掉
        masked_dist_mat = dist_mat.clone()
        masked_dist_mat += torch.eye(masked_dist_mat.shape[0], device=device) * torch.amax(masked_dist_mat)
        masked_dist_mat[native_idxs[0], native_idxs[1]] = masked_dist_mat[0,0].clone()
        # 根据距离排序,*(-1)使得最小值在前
        _, sorted_col_idxs = torch.topk(-masked_dist_mat, k=k, dim=-1)
        # 连接索引并重复行索引以匹配列索引的数量
        sorted_col_idxs = rearrange(sorted_col_idxs[:, :k], '... n k -> ... (n k)')
        sorted_row_idxs = torch.repeat_interleave( torch.arange(dist_mat.shape[0]).long(), repeats=k ).to(device)
        close_bond_idxs = torch.stack([ sorted_row_idxs, sorted_col_idxs ], dim=0)
        # 将远离保留给原生的姿势
        bond_buckets = torch.ones_like(dist_mat) * (needed_info["adj_degree"]+1)

    # 合并所有键
    if len(needed_info["cutoffs"]) > 0:
        if close_bond_idxs.shape[0] > 0:
            whole_bond_idxs = torch.cat([native_idxs, close_bond_idxs], dim=-1)
    else:
        whole_bond_idxs = native_idxs

    # 2. ATTRS: 将键编码为属性
    bond_vecs  = x[ whole_bond_idxs[0] ] - x[ whole_bond_idxs[1] ]
    bond_norms = torch.norm(bond_vecs, dim=-1)
    bond_vecs /= (bond_norms + eps).unsqueeze(-1)
    bond_norms_enc = encode_dist(bond_norms, scales=needed_info["bond_scales"]).squeeze()

    if native:
        bond_buckets[native_idxs[0], native_idxs[1]] = native_attrs
    bond_attrs = bond_buckets[whole_bond_idxs[0] , whole_bond_idxs[1]]
    # 打包标量和向量 - 额外的令牌用于共价键
    bond_n_vectors = 1
    bond_n_scalars = (2 * len(needed_info["bond_scales"]) + 1) + 1 # 最后一个是大小为1+len(cutoffs)的嵌入
    whole_bond_enc = torch.cat([bond_vecs, # 1个向量 - 不需要反转 - 我们做2倍的键(对称性)
                                # 标量
                                bond_norms_enc, # 2 * len(scales)
                                (bond_attrs-1).unsqueeze(-1) # 1 
                               ], dim=-1) 
    # 释放 GPU 内存
    if free_mem:
        del bond_buckets, bond_norms_enc, bond_vecs, dist_mat,\
            close_bond_idxs, native_bond_idxs
        if closest: 
            del masked_dist_mat, sorted_col_idxs, sorted_row_idxs

    embedd_info = {"bond_n_vectors": bond_n_vectors, 
                   "bond_n_scalars": bond_n_scalars, 
                   "bond_embedding_nums": [ len(needed_info["cutoffs"]) + needed_info["adj_degree"] ]} # 额外一个用于共价键(默认)

    return whole_bond_idxs, whole_bond_enc, embedd_info
def encode_whole_protein(seq, true_coords, angles, padding_seq,
                         needed_info = { "cutoffs": [2, 5, 10],
                                          "bond_scales": [0.5, 1, 2]}, free_mem=False):
    """ Encodes a whole protein. In points + vectors. """
    # 获取设备和数据类型
    device, precise = true_coords.device, true_coords.type()
    #################
    # encode points #
    #################
    # 创建云掩码
    cloud_mask = torch.tensor(scn_cloud_mask(seq[:-padding_seq or None])).bool().to(device)
    flat_mask = rearrange(cloud_mask, 'l c -> (l c)')
    # 嵌入所有内容

    # 一般位置嵌入
    center_coords = true_coords - true_coords.mean(dim=0)
    pos_unit_norms = torch.norm(center_coords, dim=-1, keepdim=True)
    pos_unit_vecs  = center_coords / pos_unit_norms
    pos_unit_norms_enc = encode_dist(pos_unit_norms, scales=needed_info["atom_pos_scales"]).squeeze()
    # 重新格式化坐标到scn (L, 14, 3) - 待解决如果填充=0
    coords_wrap = rearrange(center_coords, '(l c) d -> l c d', c=14)[:-padding_seq or None] 

    # 蛋白质中的位置嵌入
    aa_pos = encode_dist( torch.arange(len(seq[:-padding_seq or None]), device=device).float(), scales=needed_info["aa_pos_scales"])
    atom_pos = chain2atoms(aa_pos)[cloud_mask]

    # 原子标识嵌入
    atom_id_embedds = torch.stack([SUPREME_INFO[k]["atom_id_embedd"] for k in seq[:-padding_seq or None]], 
                                  dim=0)[cloud_mask].to(device)
    # 氨基酸嵌入
    seq_int = torch.tensor([AAS2NUM[aa] for aa in seq[:-padding_seq or None]], device=device).long()
    aa_id_embedds   = chain2atoms(seq_int, mask=cloud_mask)

    # CA - SC 距离
    dist2ca_vec, dist2ca_norm = dist2ca(coords_wrap) 
    dist2ca_norm_enc = encode_dist(dist2ca_norm, scales=needed_info["dist2ca_norm_scales"]).squeeze()

    # 主链特征
    vecs, norms    = orient_aa(coords_wrap)
    bb_vecs_atoms  = chain2atoms(torch.transpose(vecs, 0, 1), mask=cloud_mask)
    bb_norms_atoms = chain2atoms(torch.transpose(norms, 0, 1), mask=cloud_mask)
    bb_norms_atoms_enc = encode_dist(bb_norms_atoms, scales=[0.5])

    ################
    # encode bonds #
    ################
    bond_info = encode_whole_bonds(x = coords_wrap[cloud_mask],
                                   x_format = "coords",
                                   embedd_info = {},
                                   needed_info = needed_info )
    whole_bond_idxs, whole_bond_enc, bond_embedd_info = bond_info
    #########
    # merge #
    #########

    # 连接以使最终为[矢量维度,标量维度]
    point_n_vectors = 1 + 1 + 5
    point_n_scalars = 2*len(needed_info["atom_pos_scales"]) + 1 +\
                      2*len(needed_info["aa_pos_scales"]) + 1 +\
                      2*len(needed_info["dist2ca_norm_scales"]) + 1+\
                      rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)').shape[1] +\
                      2 # 最后2个尚未嵌入

    whole_point_enc = torch.cat([ pos_unit_vecs[ :-padding_seq*14 or None ][ flat_mask ], # 1
                                  dist2ca_vec[cloud_mask], # 1
                                  rearrange(bb_vecs_atoms, 'atoms n d -> atoms (n d)'), # 5
                                  # 标量
                                  pos_unit_norms_enc[ :-padding_seq*14 or None ][ flat_mask ], # 2n+1
                                  atom_pos, # 2n+1
                                  dist2ca_norm_enc[cloud_mask], # 2n+1
                                  rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)'), # 2n+1
                                  atom_id_embedds.unsqueeze(-1),
                                  aa_id_embedds.unsqueeze(-1) ], dim=-1) # 最后2个尚未嵌入
    if free_mem:
        del pos_unit_vecs, dist2ca_vec, bb_vecs_atoms, pos_unit_norms_enc, cloud_mask,\
            atom_pos, dist2ca_norm_enc, bb_norms_atoms_enc, atom_id_embedds, aa_id_embedds
    # 记录嵌入维度信息,包括点向量数量和标量数量
    point_embedd_info = {"point_n_vectors": point_n_vectors,
                         "point_n_scalars": point_n_scalars,}

    # 合并点和键的嵌入信息
    embedd_info = {**point_embedd_info, **bond_embedd_info}

    # 返回整体点编码、整体键索引、整体键编码和嵌入信息
    return whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info
def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=True):
    """ Gets a protein from sidechainnet and returns
        the right attrs for training. 
        Inputs: 
        * dataloader_: sidechainnet iterator over dataset
        * vocab_: sidechainnet VOCAB class
        * min_len: int. minimum sequence length
        * max_len: int. maximum sequence length
        * verbose: bool. verbosity level
    """
    # 遍历数据加载器中的训练数据批次
    for batch in dataloader_['train']:
        # 尝试在两个循环中同时中断
        try:
            # 遍历当前批次中的序列
            for i in range(batch.int_seqs.shape[0]):
                # 获取变量
                seq     = ''.join([vocab_.int2char(aa) for aa in batch.int_seqs[i].numpy()])
                int_seq = batch.int_seqs[i]
                angles  = batch.angs[i]
                mask    = batch.msks[i]
                # 获取填充
                padding_angles = (torch.abs(angles).sum(dim=-1) == 0).long().sum()
                padding_seq    = (batch.int_seqs[i] == 20).sum()
                # 仅接受具有正确维度且没有缺失坐标的序列
                # 大于0以避免后续负索引错误
                if batch.crds[i].shape[0]//14 == int_seq.shape[0]:
                    if ( max_len > len(seq) and len(seq) > min_len ) and padding_seq == padding_angles: 
                        if verbose:
                            print("stopping at sequence of length", len(seq))
                            # print(len(seq), angles.shape, "paddings: ", padding_seq, padding_angles)
                        # 触发 StopIteration 异常
                        raise StopIteration
                    else:
                        # print("found a seq of length:", len(seq),
                        #        "but oustide the threshold:", min_len, max_len)
                        pass
        except StopIteration:
            # 中断外部循环
            break
            
    # 返回序列、坐标、角度、填充序列、掩码和蛋白质ID
    return seq, batch.crds[i], angles, padding_seq, batch.msks[i], batch.pids[i]

GVP - Point Cloud

Geometric Vector Perceptron applied to Point Clouds

To install:

  1. git clone ${repo_url}
  2. install packages:
  3. Try to run the notebooks (they should run, report errors if encountered)
    • proto_dev_model.ipynb: shows how to gather the data and train a simple model on it, then reconstruct original struct and calculate improvement.

Descritpion:

  1. encode a protein (3d) into some features (scalars and position vectors)
    • we encode both point features and edge features
  2. train the model to predict the right point features back
  3. reconstruct the 3d case to see the improvement

TO DO LIST:

See the issues tab?

Contribute

PRs and ideas are welcome. Describe a list of the changes you've made and provide tests/examples if possible (they're not requiered, but surely helps understanding).

.\lucidrains\geometric-vector-perceptron\examples\scn_data_module.py

# 导入必要的模块
from argparse import ArgumentParser
from typing import List, Optional
from typing import Union

import numpy as np
import pytorch_lightning as pl
import sidechainnet
from sidechainnet.dataloaders.collate import get_collate_fn
from sidechainnet.utils.sequence import ProteinVocabulary
from torch.utils.data import DataLoader, Dataset

# 定义自定义数据集类
class ScnDataset(Dataset):
    def __init__(self, dataset, max_len: int):
        super(ScnDataset, self).__init__()
        self.dataset = dataset

        self.max_len = max_len
        self.scn_collate_fn = get_collate_fn(False)
        self.vocab = ProteinVocabulary()

    # 定义数据集的拼接函数
    def collate_fn(self, batch):
        batch = self.scn_collate_fn(batch)
        real_seqs = [
            "".join([self.vocab.int2char(aa) for aa in seq])
            for seq in batch.int_seqs.numpy()
        ]
        seq = real_seqs[0][: self.max_len]
        true_coords = batch.crds[0].view(-1, 14, 3)[: self.max_len].view(-1, 3)
        angles = batch.angs[0, : self.max_len]
        mask = batch.msks[0, : self.max_len]

        # 计算填充序列的长度
        padding_seq = (np.array([*seq]) == "_").sum()
        return {
            "seq": seq,
            "true_coords": true_coords,
            "angles": angles,
            "padding_seq": padding_seq,
            "mask": mask,
        }

    # 获取数据集中指定索引的数据
    def __getitem__(self, index: int):
        return self.dataset[index]

    # 返回数据集的长度
    def __len__(self) -> int:
        return len(self.dataset)

# 定义数据模块类
class ScnDataModule(pl.LightningDataModule):
    # 添加数据特定参数
    @staticmethod
    def add_data_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--casp_version", type=int, default=7)
        parser.add_argument("--scn_dir", type=str, default="./sidechainnet_data")
        parser.add_argument("--train_batch_size", type=int, default=1)
        parser.add_argument("--eval_batch_size", type=int, default=1)
        parser.add_argument("--num_workers", type=int, default=1)
        parser.add_argument("--train_max_len", type=int, default=256)
        parser.add_argument("--eval_max_len", type=int, default=256)

        return parser

    # 初始化数据模块
    def __init__(
        self,
        casp_version: int = 7,
        scn_dir: str = "./sidechainnet_data",
        train_batch_size: int = 1,
        eval_batch_size: int = 1,
        num_workers: int = 1,
        train_max_len: int = 256,
        eval_max_len: int = 256,
        **kwargs,
    ):
        super().__init__()

        assert train_batch_size == eval_batch_size == 1, "batch size must be 1 for now"

        self.casp_version = casp_version
        self.scn_dir = scn_dir
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.train_max_len = train_max_len
        self.eval_max_len = eval_max_len

    # 设置数据模块
    def setup(self, stage: Optional[str] = None):
        dataloaders = sidechainnet.load(
            casp_version=self.casp_version,
            scn_dir=self.scn_dir,
            with_pytorch="dataloaders",
        )
        print(
            dataloaders.keys()
        )  # ['train', 'train_eval', 'valid-10', ..., 'valid-90', 'test']

        self.train = ScnDataset(dataloaders["train"].dataset, self.train_max_len)
        self.val = ScnDataset(dataloaders["valid-90"].dataset, self.eval_max_len)
        self.test = ScnDataset(dataloaders["test"].dataset, self.eval_max_len)

    # 获取训练数据加载器
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(
            self.train,
            batch_size=self.train_batch_size,
            shuffle=True,
            collate_fn=self.train.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )
    # 定义用于验证数据集的数据加载器函数,返回一个数据加载器对象或对象列表
    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        # 创建一个数据加载器对象,用于加载验证数据集
        return DataLoader(
            self.val,  # 使用验证数据集
            batch_size=self.eval_batch_size,  # 指定批量大小
            shuffle=False,  # 不打乱数据集顺序
            collate_fn=self.val.collate_fn,  # 使用验证数据集的数据整理函数
            num_workers=self.num_workers,  # 指定数据加载器的工作进程数
            pin_memory=True,  # 将数据加载到 CUDA 固定内存中
        )

    # 定义用于测试数据集的数据加载器函数,返回一个数据加载器对象或对象列表
    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        # 创建一个数据加载器对象,用于加载测试数据集
        return DataLoader(
            self.test,  # 使用测试数据集
            batch_size=self.eval_batch_size,  # 指定批量大小
            shuffle=False,  # 不打乱数据集顺序
            collate_fn=self.test.collate_fn,  # 使用测试数据集的数据整理函数
            num_workers=self.num_workers,  # 指定数据加载器的工作进程数
            pin_memory=True,  # 将数据加载到 CUDA 固定内存中
        )
# 如果当前脚本作为主程序运行
if __name__ == "__main__":
    # 创建一个 ScnDataModule 的实例对象
    dm = ScnDataModule()
    # 设置数据模块
    dm.setup()

    # 获取训练数据加载器
    train = dm.train_dataloader()
    # 打印训练数据加载器的长度
    print("train length", len(train))

    # 获取验证数据加载器
    valid = dm.val_dataloader()
    # 打印验证数据加载器的长度
    print("valid length", len(valid))

    # 获取测试数据加载器
    test = dm.test_dataloader()
    # 打印测试数据加载器的长度
    print("test length", len(test))

    # 遍历测试数据加载器
    for batch in test:
        # 打印当前批次的数据
        print(batch)
        # 跳出循环,只打印第一个批次的数据
        break

.\lucidrains\geometric-vector-perceptron\examples\train_lightning.py

import gc
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange
from loguru import logger
from pytorch_lightning.callbacks import (
    GPUStatsMonitor,
    LearningRateMonitor,
    ModelCheckpoint,
    ProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger

from examples.data_handler import kabsch_torch, scn_cloud_mask
from examples.data_utils import (
    encode_whole_bonds,
    encode_whole_protein,
    from_encode_to_pred,
    prot_covalent_bond,
)
from examples.scn_data_module import ScnDataModule
from geometric_vector_perceptron.geometric_vector_perceptron import GVP_Network

# 定义一个继承自 LightningModule 的结构模型类
class StructureModel(pl.LightningModule):
    # 静态方法,用于添加模型特定参数
    @staticmethod
    def add_model_specific_args(parent_parser):
        # 创建参数解析器
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # 添加模型参数
        parser.add_argument("--depth", type=int, default=4)
        parser.add_argument("--cutoffs", type=float, default=1.0)
        parser.add_argument("--noise", type=float, default=1.0)
        # 添加优化器和调度器参数
        parser.add_argument("--init_lr", type=float, default=1e-3)

        return parser

    # 初始化方法,接受模型参数
    def __init__(
        self,
        depth: int = 1,
        cutoffs: float = 1.0,
        noise: float = 1.0,
        init_lr: float = 1e-3,
        **kwargs,
    ):
        super().__init__()

        # 保存超参数
        self.save_hyperparameters()
        
        # 定义需要的信息字典
        self.needed_info = {
            "cutoffs": [cutoffs], # -1e-3 for just covalent, "30_closest", 5. for under 5A, etc
            "bond_scales": [1, 2, 4],
            "aa_pos_scales": [1, 2, 4, 8, 16, 32, 64, 128],
            "atom_pos_scales": [1, 2, 4, 8, 16, 32],
            "dist2ca_norm_scales": [1, 2, 4],
            "bb_norms_atoms": [0.5],  # will encode 3 vectors with this
        }

        # 创建 GVP_Network 模型
        self.model = GVP_Network(
            n_layers=depth,
            feats_x_in=48,
            vectors_x_in=7,
            feats_x_out=48,
            vectors_x_out=7,
            feats_edge_in=8,
            vectors_edge_in=1,
            feats_edge_out=8,
            vectors_edge_out=1,
            embedding_nums=[36, 20],
            embedding_dims=[16, 16],
            edge_embedding_nums=[2],
            edge_embedding_dims=[2],
            residual=True,
            recalc=1
        )

        self.noise = noise
        self.init_lr = init_lr

        self.baseline_losses = [] # 存储基准损失
        self.epoch_losses = [] # 存储每个 epoch 的损失
    # 定义前向传播函数,接受序列、真实坐标、角度、填充序列、掩码作为输入
    def forward(self, seq, true_coords, angles, padding_seq, mask):
        # 获取需要的信息
        needed_info = self.needed_info
        # 获取设备信息
        device = true_coords.device

        # 将序列截取到填充序列之前的部分
        needed_info["seq"] = seq[: (-padding_seq) or None]
        # 计算蛋白质的共价键
        needed_info["covalent_bond"] = prot_covalent_bond(needed_info["seq"])

        # 对整个蛋白质进行编码
        pre_target = encode_whole_protein(
            seq,
            true_coords,
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )
        pre_target_x, _, _, embedd_info = pre_target

        # 对蛋白质进行编码并加入噪声
        encoded = encode_whole_protein(
            seq,
            true_coords + self.noise * torch.randn_like(true_coords),
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )

        x, edge_index, edge_attrs, embedd_info = encoded

        # 创建批次信息
        batch = torch.tensor([0 for i in range(x.shape[0])], device=x.device).long()

        # 添加位置坐标
        cloud_mask = scn_cloud_mask(seq[: (-padding_seq) or None]).to(device)
        chain_mask = mask[: (-padding_seq) or None].unsqueeze(-1) * cloud_mask
        flat_chain_mask = rearrange(chain_mask.bool(), "l c -> (l c)")
        cloud_mask = cloud_mask.bool()
        flat_cloud_mask = rearrange(cloud_mask, "l c -> (l c)")

        # 部分重新计算边
        recalc_edge = partial(
            encode_whole_bonds,
            x_format="encode",
            embedd_info=embedd_info,
            needed_info=needed_info,
            free_mem=True,
        )

        # 预测
        scores = self.model.forward(
            x,
            edge_index,
            batch=batch,
            edge_attr=edge_attrs,
            recalc_edge=recalc_edge,
            verbose=False,
        )

        # 格式化预测、基线和目标
        target = from_encode_to_pred(
            pre_target_x, embedd_info=embedd_info, needed_info=needed_info
        )
        pred = from_encode_to_pred(
            scores, embedd_info=embedd_info, needed_info=needed_info
        )
        base = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)

        # 计算误差

        # 选项1:损失是输出令牌的均方误差
        # loss_ = (target-pred)**2
        # loss  = loss_.mean()

        # 选项2:损失是重构坐标的RMSD
        target_coords = target[:, 3:4] * target[:, :3]
        pred_coords = pred[:, 3:4] * pred[:, :3]
        base_coords = base[:, 3:4] * base[:, :3]

        ## 对齐 - 有时svc失败 - 不知道为什么
        try:
            pred_aligned, target_aligned = kabsch_torch(pred_coords.t(), target_coords.t()) # (3, N)
            base_aligned, _ = kabsch_torch(base_coords.t(), target_coords.t())
            loss = ( (pred_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5 
            loss_base = ( (base_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5 
        except:
            pred_aligned, target_aligned = None, None
            print("svd failed convergence, ep:", ep)
            loss = ( (pred_coords.t() - target_coords.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5
            loss_base = ( (base_coords - target_coords)[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5 

        # 释放GPU内存
        del true_coords, angles, pre_target_x, edge_index, edge_attrs
        del scores, target_coords, pred_coords, base_coords
        del encoded, pre_target, target_aligned, pred_aligned
        gc.collect()

        # 返回损失
        return {"loss": loss, "loss_base": loss_base}

    # 配置优化器
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.init_lr)
        return optimizer

    # 训练开始时的操作
    def on_train_start(self) -> None:
        self.baseline_losses = []
        self.epoch_losses = []
    # 训练步骤,接收一个批次数据和批次索引
    def training_step(self, batch, batch_idx):
        # 调用前向传播函数得到输出
        output = self.forward(**batch)
        # 获取损失值和基准损失值
        loss = output["loss"]
        loss_base = output["loss_base"]

        # 如果损失值为空或为 NaN,则返回 None
        if loss is None or torch.isnan(loss):
            return None

        # 将损失值和基准损失值添加到对应的列表中
        self.epoch_losses.append(loss.item())
        self.baseline_losses.append(loss_base.item())

        # 记录训练损失值到日志中,显示在进度条中
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_loss_base", output["loss_base"], on_epoch=True, prog_bar=False)

        # 返回损失值
        return loss

    # 训练结束时的操作
    def on_train_end(self) -> None:
        # 创建一个图形窗口
        plt.figure(figsize=(15, 6))
        # 设置图形标题
        plt.title(
            f"Loss Evolution - Denoising of Gaussian-masked Coordinates (mu=0, sigma={self.noise})"
        )
        # 绘制训练损失值随时间的变化曲线

        # 绘制滑动窗口平均值曲线
        for window in [8, 16, 32]:
            # 计算滑动窗口平均值
            plt.plot(
                [
                    np.mean(self.epoch_losses[:window][0 : i + 1])
                    for i in range(min(window, len(self.epoch_losses))
                ]
                + [
                    np.mean(self.epoch_losses[i : i + window + 1])
                    for i in range(len(self.epoch_losses) - window)
                ],
                label="Window mean n={0}".format(window),
            )

        # 绘制基准损失值的水平虚线
        plt.plot(
            np.ones(len(self.epoch_losses)) * np.mean(self.baseline_losses),
            "k--",
            label="Baseline",
        )

        # 设置 x 轴范围
        plt.xlim(-0.01 * len(self.epoch_losses), 1.01 * len(self.epoch_losses))
        # 设置 y 轴标签
        plt.ylabel("RMSD")
        # 设置 x 轴标签
        plt.xlabel("Batch number")
        # 添加图例
        plt.legend()
        # 保存图形为 PDF 文件
        plt.savefig("loss.pdf")

    # 验证步骤,接收一个批次数据和批次索引
    def validation_step(self, batch, batch_idx):
        # 调用前向传播函数得到输出,并记录验证损失值到日志中
        output = self.forward(**batch)
        self.log("val_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("val_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)

    # 测试步骤,接收一个批次数据和批次索引
    def test_step(self, batch, batch_idx):
        # 调用前向传播函数得到输出,并记录测试损失值到日志中
        output = self.forward(**batch)
        self.log("test_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("test_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)
# 根据参数获取训练器对象
def get_trainer(args):
    # 设置随机种子
    pl.seed_everything(args.seed)

    # 创建日志记录器
    root_dir = Path(args.default_root_dir).expanduser().resolve()
    root_dir.mkdir(parents=True, exist_ok=True)
    tb_save_dir = root_dir / "tb"
    tb_logger = TensorBoardLogger(save_dir=tb_save_dir)
    loggers = [tb_logger]
    logger.info(f"Run tensorboard --logdir {tb_save_dir}")

    # 创建回调函数
    ckpt_cb = ModelCheckpoint(verbose=True)
    lr_cb = LearningRateMonitor(logging_interval="step")
    pb_cb = ProgressBar(refresh_rate=args.progress_bar_refresh_rate)
    callbacks = [lr_cb, pb_cb]

    callbacks.append(ckpt_cb)

    gpu_cb = GPUStatsMonitor()
    callbacks.append(gpu_cb)

    plugins = []
    # 根据参数创建训练器对象
    trainer = pl.Trainer.from_argparse_args(
        args, logger=loggers, callbacks=callbacks, plugins=plugins
    )

    return trainer


def main(args):
    # 创建数据模块对象
    dm = ScnDataModule(**vars(args))
    # 创建模型对象
    model = StructureModel(**vars(args))
    # 获取训练器对象
    trainer = get_trainer(args)
    # 训练模型
    trainer.fit(model, datamodule=dm)
    # 测试模型并获取指标
    metrics = trainer.test(model, datamodule=dm)
    print("test", metrics)


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("--seed", type=int, default=23333, help="Seed everything.")

    # 添加模型特定参数
    parser = StructureModel.add_model_specific_args(parser)

    # 添加数据特定参数
    parser = ScnDataModule.add_data_specific_args(parser)

    # 添加训练器参数
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # 打印参数
    pprint(vars(args))
    # 执行主函数
    main(args)

.\lucidrains\geometric-vector-perceptron\examples\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\geometric_vector_perceptron.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch_geometric.nn 模块中导入 MessagePassing 类
from torch_geometric.nn import MessagePassing

# types

# 导入类型提示相关的模块和类型
from typing import Optional, List, Union
from torch_geometric.typing import OptPairTensor, Adj, Size, OptTensor, Tensor

# helper functions

# 定义一个函数,判断输入值是否存在
def exists(val):
    return val is not None

# classes

# 定义 GVP 类,继承自 nn.Module 类
class GVP(nn.Module):
    def __init__(
        self,
        *,
        dim_vectors_in,
        dim_vectors_out,
        dim_feats_in,
        dim_feats_out,
        feats_activation = nn.Sigmoid(),
        vectors_activation = nn.Sigmoid(),
        vector_gating = False
    ):
        super().__init__()
        self.dim_vectors_in = dim_vectors_in
        self.dim_feats_in = dim_feats_in

        self.dim_vectors_out = dim_vectors_out
        dim_h = max(dim_vectors_in, dim_vectors_out)

        # 初始化权重参数
        self.Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h))
        self.Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out))

        self.vectors_activation = vectors_activation

        # 定义输出特征的网络结构
        self.to_feats_out = nn.Sequential(
            nn.Linear(dim_h + dim_feats_in, dim_feats_out),
            feats_activation
        )

        # 根据 vector_gating 参数选择是否使用向量门控
        self.scalar_to_vector_gates = nn.Linear(dim_feats_out, dim_vectors_out) if vector_gating else None

    # 前向传播函数
    def forward(self, data):
        feats, vectors = data
        b, n, _, v, c  = *feats.shape, *vectors.shape

        # 断言向量维度和特征维度是否匹配
        assert c == 3 and v == self.dim_vectors_in, 'vectors have wrong dimensions'
        assert n == self.dim_feats_in, 'scalar features have wrong dimensions'

        # 计算 Vh 和 Vu
        Vh = einsum('b v c, v h -> b h c', vectors, self.Wh)
        Vu = einsum('b h c, h u -> b u c', Vh, self.Wu)

        # 计算向量的模长
        sh = torch.norm(Vh, p = 2, dim = -1)

        # 拼接特征和模长
        s = torch.cat((feats, sh), dim = 1)

        # 计算特征输出
        feats_out = self.to_feats_out(s)

        # 如果存在 scalar_to_vector_gates,则计算门控
        if exists(self.scalar_to_vector_gates):
            gating = self.scalar_to_vector_gates(feats_out)
            gating = gating.unsqueeze(dim = -1)
        else:
            gating = torch.norm(Vu, p = 2, dim = -1, keepdim = True)

        # 计算向量输出
        vectors_out = self.vectors_activation(gating) * Vu
        return (feats_out, vectors_out)

# 定义 GVPDropout 类,继承自 nn.Module 类
class GVPDropout(nn.Module):
    """ Separate dropout for scalars and vectors. """
    def __init__(self, rate):
        super().__init__()
        self.vector_dropout = nn.Dropout2d(rate)
        self.feat_dropout = nn.Dropout(rate)

    # 前向传播函数
    def forward(self, feats, vectors):
        return self.feat_dropout(feats), self.vector_dropout(vectors)

# 定义 GVPLayerNorm 类,继承自 nn.Module 类
class GVPLayerNorm(nn.Module):
    """ Normal layer norm for scalars, nontrainable norm for vectors. """
    def __init__(self, feats_h_size, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.feat_norm = nn.LayerNorm(feats_h_size)

    # 前向传播函数
    def forward(self, feats, vectors):
        vector_norm = vectors.norm(dim=(-1,-2), keepdim=True)
        normed_feats = self.feat_norm(feats)
        normed_vectors = vectors / (vector_norm + self.eps)
        return normed_feats, normed_vectors

# 定义 GVP_MPNN 类,继承自 MessagePassing 类
class GVP_MPNN(MessagePassing):
    r"""The Geometric Vector Perceptron message passing layer
        introduced in https://openreview.net/forum?id=1YLJDvSx6J4.
        
        Uses a Geometric Vector Perceptron instead of the normal 
        MLP in aggregation phase.

        Inputs will be a concatenation of (vectors, features)

        Args:
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * vector_dim: int. dimensions of the space containing the vectors.
        * verbose: bool. verbosity level.
    """
    # 初始化函数,接受多个参数
    def __init__(self, feats_x_in, vectors_x_in,
                       feats_x_out, vectors_x_out,
                       feats_edge_in, vectors_edge_in,
                       feats_edge_out, vectors_edge_out,
                       dropout, residual=False, vector_dim=3, 
                       verbose=False, **kwargs):
        # 调用父类的初始化函数,设置聚合方式为"mean"
        super(GVP_MPNN, self).__init__(aggr="mean",**kwargs)
        # 记录是否输出详细信息
        self.verbose = verbose
        # 记录输入特征和向量的维度
        self.feats_x_in    = feats_x_in 
        self.vectors_x_in  = vectors_x_in # 输入中的 N 个向量特征
        self.feats_x_out   = feats_x_out 
        self.vectors_x_out = vectors_x_out # 输出中的 N 个向量特征
        # 记录边属性的维度
        self.feats_edge_in    = feats_edge_in 
        self.vectors_edge_in  = vectors_edge_in # 输入中的 N 个向量特征
        self.feats_edge_out   = feats_edge_out 
        self.vectors_edge_out = vectors_edge_out # 输出中的 N 个向量特征
        # 辅助层
        self.vector_dim = vector_dim
        # 初始化归一化层
        self.norm = nn.ModuleList([GVPLayerNorm(self.feats_x_out), # + self.feats_edge_out
                                   GVPLayerNorm(self.feats_x_out)])
        # 初始化 dropout 层
        self.dropout = GVPDropout(dropout)
        # 是否使用残差连接
        self.residual = residual
        # 接收 vec_in 消息和接收节点
        self.W_EV = nn.Sequential(GVP(
                                      dim_vectors_in = self.vectors_x_in + self.vectors_edge_in, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_in + self.feats_edge_in, 
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ), 
                                  GVP(
                                      dim_vectors_in = self.vectors_x_out + self.feats_edge_out, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ),
                                  GVP(
                                      dim_vectors_in = self.vectors_x_out + self.feats_edge_out, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ))
        
        # 初始化 W_dh 层
        self.W_dh = nn.Sequential(GVP(
                                      dim_vectors_in = self.vectors_x_out,
                                      dim_vectors_out = 2*self.vectors_x_out,
                                      dim_feats_in = self.feats_x_out,
                                      dim_feats_out = 4*self.feats_x_out
                                  ),
                                  GVP(
                                      dim_vectors_in = 2*self.vectors_x_out,
                                      dim_vectors_out = self.vectors_x_out,
                                      dim_feats_in = 4*self.feats_x_out,
                                      dim_feats_out = self.feats_x_out
                                  ))
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """"""
        # 获取输入张量 x 的最后一个维度的大小
        x_size = list(x.shape)[-1]
        # 分别聚合特征和向量
        feats, vectors = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # 聚合
        feats, vectors = self.dropout(feats, vectors.reshape(vectors.shape[0], -1, self.vector_dim))
        # 获取与节点相关的信息 - 不返回边
        feats_nodes  = feats[:, :self.feats_x_in]
        vector_nodes = vectors[:, :self.vectors_x_in]
        # 将向量部分重塑为最后一个 3D
        x_vectors    = x[:, :self.vectors_x_in * self.vector_dim].reshape(x.shape[0], -1, self.vector_dim)
        feats, vectors = self.norm[0]( x[:, self.vectors_x_in * self.vector_dim:]+feats_nodes, x_vectors+vector_nodes )
        # 更新位置感知前馈
        feats_, vectors_ = self.dropout( *self.W_dh( (feats, vectors) ) )
        feats, vectors   = self.norm[1]( feats+feats_, vectors+vectors_ )
        # 使其成为残差
        new_x = torch.cat( [feats, vectors.flatten(start_dim=-2)], dim=-1 )
        if self.residual:
          return new_x + x
        return new_x


    def message(self, x_j, edge_attr) -> Tensor:
        # 拼接特征和边属性
        feats   = torch.cat([ x_j[:, self.vectors_x_in * self.vector_dim:],
                              edge_attr[:, self.vectors_edge_in * self.vector_dim:] ], dim=-1)
        vectors = torch.cat([ x_j[:, :self.vectors_x_in * self.vector_dim], 
                              edge_attr[:, :self.vectors_edge_in * self.vector_dim] ], dim=-1).reshape(x_j.shape[0],-1,self.vector_dim)
        feats, vectors = self.W_EV( (feats, vectors) )
        return feats, vectors.flatten(start_dim=-2)


    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        r"""The initial call to start propagating messages.
        Args:
            adj (Tensor or SparseTensor): `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        feats, vectors = self.message(**msg_kwargs)
        # 聚合它们
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        out_feats   = self.aggregate(feats, **aggr_kwargs)
        out_vectors = self.aggregate(vectors, **aggr_kwargs)
        # 返回元组
        update_kwargs = self.inspector.distribute('update', coll_dict)
        return self.update((out_feats, out_vectors), **update_kwargs)

        
    def __repr__(self):
        dict_print = { "feats_x_in": self.feats_x_in,
                       "vectors_x_in": self.vectors_x_in,
                       "feats_x_out": self.feats_x_out,
                       "vectors_x_out": self.vectors_x_out,
                       "feats_edge_in": self.feats_edge_in,
                       "vectors_edge_in": self.vectors_edge_in,
                       "feats_edge_out": self.feats_edge_out,
                       "vectors_edge_out": self.vectors_edge_out,
                       "vector_dim": self.vector_dim }
        return  'GVP_MPNN Layer with the following attributes: ' + str(dict_print)
class GVP_Network(nn.Module):
    r"""Sample GNN model architecture that uses the Geometric Vector Perceptron
        message passing layer to learn over point clouds. 
        Main MPNN layer introduced in https://openreview.net/forum?id=1YLJDvSx6J4.

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * embedding_nums: list. number of unique keys to embedd. for points
                          1 entry per embedding needed. 
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed. 
        * edge_embedding_nums: list. number of unique keys to embedd. for edges.
                               1 entry per embedding needed. 
        * edge_embedding_dims: list. point - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed. 
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * vector_dim: int. dimensions of the space containing the vectors.
        * recalc: bool. Whether to recalculate edge features between MPNN layers.
        * verbose: bool. verbosity level.
    """
    # 初始化函数,接受多个参数,包括层数、输入特征和向量、输出特征和向量、边特征和向量等
    def __init__(self, n_layers, 
                       feats_x_in, vectors_x_in,
                       feats_x_out, vectors_x_out,
                       feats_edge_in, vectors_edge_in,
                       feats_edge_out, vectors_edge_out,
                       embedding_nums=[], embedding_dims=[],
                       edge_embedding_nums=[], edge_embedding_dims=[],
                       dropout=0.0, residual=False, vector_dim=3,
                       recalc=1, verbose=False):
        # 调用父类的初始化函数
        super().__init__()

        # 初始化各种属性
        self.n_layers         = n_layers 
        self.embedding_nums   = embedding_nums
        self.embedding_dims   = embedding_dims
        self.emb_layers       = torch.nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers     = torch.nn.ModuleList()
        
        # 实例化点和边的嵌入层
        for i in range( len(self.embedding_dims) ):
            self.emb_layers.append(nn.Embedding(num_embeddings = embedding_nums[i],
                                                embedding_dim  = embedding_dims[i]))
            feats_x_in += embedding_dims[i] - 1
            feats_x_out += embedding_dims[i] - 1
        for i in range( len(self.edge_embedding_dims) ):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings = edge_embedding_nums[i],
                                                     embedding_dim  = edge_embedding_dims[i]))
            feats_edge_in += edge_embedding_dims[i] - 1
            feats_edge_out += edge_embedding_dims[i] - 1
        
        # 初始化其他属性
        self.fc_layers        = torch.nn.ModuleList()
        self.gcnn_layers      = torch.nn.ModuleList()
        self.feats_x_in       = feats_x_in
        self.vectors_x_in     = vectors_x_in
        self.feats_x_out      = feats_x_out
        self.vectors_x_out    = vectors_x_out
        self.feats_edge_in    = feats_edge_in
        self.vectors_edge_in  = vectors_edge_in
        self.feats_edge_out   = feats_edge_out
        self.vectors_edge_out = vectors_edge_out
        self.dropout          = dropout
        self.residual         = residual
        self.vector_dim       = vector_dim
        self.recalc           = recalc
        self.verbose          = verbose
        
        # 实例化GCNN层
        for i in range(n_layers):
            layer = GVP_MPNN(feats_x_in, vectors_x_in,
                             feats_x_out, vectors_x_out,
                             feats_edge_in, vectors_edge_in,
                             feats_edge_out, vectors_edge_out,
                             dropout, residual=residual,
                             vector_dim=vector_dim, verbose=verbose)
            self.gcnn_layers.append(layer)
    # 定义一个前向传播函数,接受输入 x、边索引 edge_index、批次 batch、边属性 edge_attr
    # bsize 为批次大小,recalc_edge 为重新计算边特征的函数,verbose 为是否输出详细信息的标志
    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Embedding of inputs when necessary, then pass layers.
            Recalculate edge features every time with the
            `recalc_edge` function.
        """
        # 复制输入数据,用于后续恢复原始数据
        original_x = x.clone()
        original_edge_index = edge_index.clone()
        original_edge_attr = edge_attr.clone()
        
        # 当需要时进行嵌入
        # 选择要嵌入的部分,逐个进行嵌入并添加到输入中
        
        # 提取要嵌入的部分
        to_embedd = x[:, -len(self.embedding_dims):].long()
        for i, emb_layer in enumerate(self.emb_layers):
            # 在第一次迭代时,对应于 `to_embedd` 部分的部分会被丢弃
            stop_concat = -len(self.embedding_dims) if i == 0 else x.shape[-1]
            x = torch.cat([x[:, :stop_concat], 
                           emb_layer(to_embedd[:, i])], dim=-1)
        
        # 传递层
        for i, layer in enumerate(self.gcnn_layers):
            # 嵌入边属性(每次都需要,因为边属性和索引在每次传递时都会重新计算)
            to_embedd = edge_attr[:, -len(self.edge_embedding_dims):].long()
            for j, edge_emb_layer in enumerate(self.edge_emb_layers):
                # 在第一次迭代时,对应于 `to_embedd` 部分的部分会被丢弃
                stop_concat = -len(self.edge_embedding_dims) if j == 0 else x.shape[-1]
                edge_attr = torch.cat([edge_attr[:, :-len(self.edge_embedding_dims) + j], 
                                       edge_emb_layer(to_embedd[:, j])], dim=-1)
            
            # 传递层
            x = layer(x, edge_index, edge_attr, size=bsize)

            # 每 self.recalc 步重新计算边信息
            # 但如果是最后一层的最后一次迭代,则不需要重新计算
            if (1 % self.recalc == 0) and not (i == self.n_layers - 1):
                edge_index, edge_attr, _ = recalc_edge(x)  # 返回属性、索引、嵌入信息
            else:
                edge_attr = original_edge_attr.clone()
                edge_index = original_edge_index.clone()
            
            if verbose:
                print("========")
                print("iter:", j, "layer:", i, "nlinks:", edge_attr.shape)
        
        return x

    # 定义对象的字符串表示形式
    def __repr__(self):
        return 'GVP_Network of: {0} layers'.format(len(self.gcnn_layers))

.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\__init__.py

# 从 geometric_vector_perceptron 模块中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network 类
from geometric_vector_perceptron.geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network

Geometric Vector Perceptron

Implementation of Geometric Vector Perceptron, a simple circuit with 3d rotation equivariance for learning over large biomolecules, in Pytorch. The repository may also contain experimentation to see if this could be easily extended to self-attention.

Install

$ pip install geometric-vector-perceptron

Functionality

  • GVP: Implementing the basic geometric vector perceptron.
  • GVPDropout: Adapted dropout for GVP in MPNN context
  • GVPLayerNorm: Adapted LayerNorm for GVP in MPNN context
  • GVP_MPNN: Adapted instance of Message Passing class from torch-geometric package. Still not tested.
  • GVP_Network: Functional model architecture ready for working with arbitary point clouds.

Usage

import torch
from geometric_vector_perceptron import GVP

model = GVP(
    dim_vectors_in = 1024,
    dim_feats_in = 512,
    dim_vectors_out = 256,
    dim_feats_out = 512,
    vector_gating = True   # use the vector gating as proposed in https://arxiv.org/abs/2106.03843
)

feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))

feats_out, vectors_out = model( (feats, vectors) ) # (1, 256), (1, 512, 3)

With the specialized dropout and layernorm as described in the paper

import torch
from torch import nn
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm

model = GVP(
    dim_vectors_in = 1024,
    dim_feats_in = 512,
    dim_vectors_out = 256,
    dim_feats_out = 512,
    vector_gating = True
)

dropout = GVPDropout(0.2)
norm = GVPLayerNorm(512)

feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))

feats, vectors = model( (feats, vectors) )
feats, vectors = dropout(feats, vectors)
feats, vectors = norm(feats, vectors)  # (1, 256), (1, 512, 3)

TF implementation:

The original implementation in TF by the paper authors can be found here: https://github.com/drorlab/gvp/

Citations

@inproceedings{anonymous2021learning,
    title   = {Learning from Protein Structure with Geometric Vector Perceptrons},
    author  = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year    = {2021},
    url     = {https://openreview.net/forum?id=1YLJDvSx6J4}
}
@misc{jing2021equivariant,
    title   = {Equivariant Graph Neural Networks for 3D Macromolecular Structure}, 
    author  = {Bowen Jing and Stephan Eismann and Pratham N. Soni and Ron O. Dror},
    year    = {2021},
    eprint  = {2106.03843},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\geometric-vector-perceptron\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'geometric-vector-perceptron', # 包的名称
  packages = find_packages(), # 查找所有包
  version = '0.0.14', # 版本号
  license='MIT', # 许可证
  description = 'Geometric Vector Perceptron - Pytorch', # 描述
  author = 'Phil Wang, Eric Alcaide', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/geometric-vector-perceptron', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'proteins',
    'biomolecules',
    'equivariance'
  ],
  install_requires=[ # 安装依赖
    'torch>=1.6',
    'torch-scatter',
    'torch-sparse',
    'torch-cluster',
    'torch-spline-conv',
    'torch-geometric'
  ],
  setup_requires=[ # 设置需要的依赖
    'pytest-runner',
  ],
  tests_require=[ # 测试需要的依赖
    'pytest'
  ],
  classifiers=[ # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\geometric-vector-perceptron\tests\tests.py

# 导入 torch 库
import torch
# 从 geometric_vector_perceptron 库中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN

# 定义容差值
TOL = 1e-2

# 生成随机旋转矩阵
def random_rotation():
    q, r = torch.qr(torch.randn(3, 3))
    return q

# 计算向量之间的差值矩阵
def diff_matrix(vectors):
    b, _, d = vectors.shape
    diff = vectors[..., None, :] - vectors[:, None, ...]
    return diff.reshape(b, -1, d)

# 测试等变性
def test_equivariance():
    R = random_rotation()

    # 创建 GVP 模型
    model = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )

    feats = torch.randn(1, 512)
    vectors = torch.randn(1, 32, 3)

    feats_out, vectors_out = model( (feats, diff_matrix(vectors)) )
    feats_out_r, vectors_out_r = model( (feats, diff_matrix(vectors @ R)) )

    err = ((vectors_out @ R) - vectors_out_r).max()
    assert err < TOL, 'equivariance must be respected'

# 测试所有层类型
def test_all_layer_types():
    R = random_rotation()

    # 创建 GVP 模型
    model = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )
    dropout = GVPDropout(0.2)
    layer_norm = GVPLayerNorm(512)

    feats = torch.randn(1, 512)
    message = torch.randn(1, 512)
    vectors = torch.randn(1, 32, 3)

    # GVP 层
    feats_out, vectors_out = model( (feats, diff_matrix(vectors)) )
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

    # GVP Dropout
    feats_out, vectors_out = dropout(feats_out, vectors_out)
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

    # GVP Layer Norm
    feats_out, vectors_out = layer_norm(feats_out, vectors_out)
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

# 测试 MPNN
def test_mpnn():
    # 输入数据
    x = torch.randn(5, 32)
    edge_idx = torch.tensor([[0,2,3,4,1], [1,1,3,3,4]]).long()
    edge_attr = torch.randn(5, 16)
    # 节点 (8 个标量和 8 个向量) || 边 (4 个标量和 3 个向量)
    dropout = 0.1
    # 定义层
    gvp_mpnn = GVP_MPNN(feats_x_in = 8,
                        vectors_x_in = 8,
                        feats_x_out = 8,
                        vectors_x_out = 8, 
                        feats_edge_in = 4,
                        vectors_edge_in = 4,
                        feats_edge_out = 4,
                        vectors_edge_out = 4,
                        dropout=0.1 )
    x_out = gvp_mpnn(x, edge_idx, edge_attr)

    assert x.shape == x_out.shape, "Input and output shapes don't match"

# 主函数入口
if __name__ == "__main__":
    # 执行等变性测试
    test_equivariance()
    # 执行所有层类型测试
    test_all_layer_types()
    # 执行 MPNN 测试
    test_mpnn()

.\lucidrains\gigagan-pytorch\gigagan_pytorch\attend.py

# 导入必要的库和模块
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

# 定义一个命名元组,用于存储注意力机制的配置信息
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义一个辅助函数,用于检查值是否存在
def exists(val):
    return val is not None

# 定义一个装饰器函数,确保被装饰的函数只执行一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用装饰器once包装print函数,确保只打印一次
print_once = once(print)

# 主要类定义
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定在cuda和cpu上的高效注意力配置
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    # 实现flash attention的方法
    def flash_attn(self, q, k, v):
        is_cuda = q.is_cuda

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # 检查是否有兼容的设备支持flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        # 使用torch.backends.cuda.sdp_kernel函数应用flash attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # 前向传播函数
    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # 计算相似度
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # 注意力计算
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # 聚合数值
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

.\lucidrains\gigagan-pytorch\gigagan_pytorch\data.py

# 导入必要的库
from functools import partial
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from torchvision import transforms as T

from beartype.door import is_bearable
from beartype.typing import Tuple

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 将图像转换为指定格式的函数
def convert_image_to_fn(img_type, image):
    if image.mode == img_type:
        return image

    return image.convert(img_type)

# 自定义数据集拼接函数
# 使数据集可以返回字符串并将其拼接成 List[str]
def collate_tensors_or_str(data):
    is_one_data = not isinstance(data[0], tuple)

    if is_one_data:
        data = torch.stack(data)
        return (data,)

    outputs = []
    for datum in zip(*data):
        if is_bearable(datum, Tuple[str, ...]):
            output = list(datum)
        else:
            output = torch.stack(datum)

        outputs.append(output)

    return tuple(outputs)

# 数据集类

# 图像数据集类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        augment_horizontal_flip = False,
        convert_image_to = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # 获取文件夹中指定扩展名的所有文件路径
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # 断言确保文件路径数量大于0
        assert len(self.paths) > 0, 'your folder contains no images'
        # 断言确保文件路径数量大于100
        assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)'

        # 创建转换函数
        maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()

        # 图像转换操作序列
        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    # 获取数据加载器
    def get_dataloader(self, *args, **kwargs):
        return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs)

    # 返回数据集长度
    def __len__(self):
        return len(self.paths)

    # 获取数据���中的数据
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# 文本图像数据集类
class TextImageDataset(Dataset):
    def __init__(self):
        raise NotImplementedError

    # 获取数据加载器
    def get_dataloader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)

# 模拟文本图像数据集类
class MockTextImageDataset(TextImageDataset):
    def __init__(
        self,
        image_size,
        length = int(1e5),
        channels = 3
    ):
        self.image_size = image_size
        self.channels = channels
        self.length = length

    # 获取数据加载器
    def get_dataloader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)

    # 返回数据集长度
    def __len__(self):
        return self.length

    # 获取数据集中的数据
    def __getitem__(self, index):
        mock_image = torch.randn(self.channels, self.image_size, self.image_size)
        return mock_image, 'mock text'

.\lucidrains\gigagan-pytorch\gigagan_pytorch\distributed.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function
# 导入 torch 分布式模块
import torch.distributed as dist
# 从 einops 库中导入 rearrange 函数

from einops import rearrange

# helpers

# 判断变量是否存在的辅助函数
def exists(val):
    return val is not None

# 在指定维度上对张量进行填充的辅助函数
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# distributed helpers

# 在所有进程中收集具有可变维度的张量的辅助函数
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, world_size = t.device, dist.get_world_size()

    if not exists(sizes):
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        dist.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# 自定义 Function 类 AllGather
class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        is_dist = dist.is_initialized() and dist.get_world_size() > 1
        ctx.is_dist = is_dist

        if not is_dist:
            return x, None

        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        if not ctx.is_dist:
            return grads, None, None

        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# 将 AllGather 类应用为函数
all_gather = AllGather.apply

.\lucidrains\gigagan-pytorch\gigagan_pytorch\gigagan_pytorch.py

# 导入必要的库
from collections import namedtuple
from pathlib import Path
from math import log2, sqrt
from random import random
from functools import partial

from torchvision import utils

import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.autograd import grad as torch_grad
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler

from beartype import beartype
from beartype.typing import List, Optional, Tuple, Dict, Union, Iterable

from einops import rearrange, pack, unpack, repeat, reduce
from einops.layers.torch import Rearrange, Reduce

from kornia.filters import filter2d

from ema_pytorch import EMA

from gigagan_pytorch.version import __version__
from gigagan_pytorch.open_clip import OpenClipAdapter
from gigagan_pytorch.optimizer import get_optimizer
from gigagan_pytorch.distributed import all_gather

from tqdm import tqdm

from numerize import numerize

from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs

# helpers

# 检查值是否存在
def exists(val):
    return val is not None

# 检查数组是否为空
@beartype
def is_empty(arr: Iterable):
    return len(arr) == 0

# 返回第一个非空值
def default(*vals):
    for val in vals:
        if exists(val):
            return val
    return None

# 将输入转换为元组
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# 检查数字是否为2的幂
def is_power_of_two(n):
    return log2(n).is_integer()

# 安全地从数组中取出第一个元素
def safe_unshift(arr):
    if len(arr) == 0:
        return None
    return arr.pop(0)

# 检查数字是否可以被另一个数字整除
def divisible_by(numer, denom):
    return (numer % denom) == 0

# 将数组按照指定数量分组
def group_by_num_consecutive(arr, num):
    out = []
    for ind, el in enumerate(arr):
        if ind > 0 and divisible_by(ind, num):
            yield out
            out = []

        out.append(el)

    if len(out) > 0:
        yield out

# 检查数组中的元素是否唯一
def is_unique(arr):
    return len(set(arr)) == len(arr)

# 无限循环生成数据
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 将数字分成指定数量的组
def num_to_groups(num, divisor):
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# 如果路径不存在,则创建目录
def mkdir_if_not_exists(path):
    path.mkdir(exist_ok = True, parents = True)

# 设置模型参数是否需要梯度
@beartype
def set_requires_grad_(
    m: nn.Module,
    requires_grad: bool
):
    for p in m.parameters():
        p.requires_grad = requires_grad

# 激活函数

# Leaky ReLU 激活函数
def leaky_relu(neg_slope = 0.2):
    return nn.LeakyReLU(neg_slope)

# 创建 3x3 的卷积层
def conv2d_3x3(dim_in, dim_out):
    return nn.Conv2d(dim_in, dim_out, 3, padding = 1)

# 张量操作辅助函数

# 对张量取对数
def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

# 计算梯度惩罚
def gradient_penalty(
    images,
    outputs,
    grad_output_weights = None,
    weight = 10,
    scaler: Optional[GradScaler] = None,
    eps = 1e-4
):
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    if exists(scaler):
        outputs = [*map(scaler.scale, outputs)]

    if not exists(grad_output_weights):
        grad_output_weights = (1,) * len(outputs)

    maybe_scaled_gradients, *_ = torch_grad(
        outputs = outputs,
        inputs = images,
        grad_outputs = [(torch.ones_like(output) * weight) for output, weight in zip(outputs, grad_output_weights)],
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )

    gradients = maybe_scaled_gradients

    if exists(scaler):
        scale = scaler.get_scale()
        inv_scale = 1. / max(scale, eps)
        gradients = maybe_scaled_gradients * inv_scale

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# Hinge GAN 损失函数

# 生成器的 Hinge 损失
def generator_hinge_loss(fake):
    return fake.mean()

# 判别器的 Hinge 损失
def discriminator_hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

# 辅助损失函数

# 辅助匹配损失
def aux_matching_loss(real, fake):
    """
    # 计算负对数似然损失,因为在这个框架中,鉴别器对于真实数据为0,对于生成数据为高值。GANs可以任意交换这一点,只要生成器和鉴别器是对立的即可
    """
    # 返回真实数据和生成数据的负对数似然损失的均值
    return (log(1 + (-real).exp()) + log(1 + (-fake).exp())).mean()
# 使用装饰器 @beartype 对 aux_clip_loss 函数进行类型检查
@beartype
# 定义函数 aux_clip_loss,接受 OpenClipAdapter 类型的 clip 对象、Tensor 类型的 images 和可选的 List[str] 类型的 texts 或 Tensor 类型的 text_embeds
def aux_clip_loss(
    clip: OpenClipAdapter,
    images: Tensor,
    texts: Optional[List[str]] = None,
    text_embeds: Optional[Tensor] = None
):
    # 断言 texts 和 text_embeds 中只有一个存在
    assert exists(texts) ^ exists(text_embeds)

    # 将 images 在所有进程中进行收集
    images, batch_sizes = all_gather(images, 0, None)

    # 如果存在 texts,则使用 clip 对象的 embed_texts 方法获取 text_embeds,并在所有进程中进行收集
    if exists(texts):
        text_embeds, _ = clip.embed_texts(texts)
        text_embeds, _ = all_gather(text_embeds, 0, batch_sizes)

    # 返回 clip 对象的 contrastive_loss 方法计算的损失值
    return clip.contrastive_loss(images = images, text_embeds = text_embeds)

# 不同iable augmentation - Karras et al. stylegan-ada
# 从水平翻转开始

# 定义类 DiffAugment,继承自 nn.Module
class DiffAugment(nn.Module):
    # 初始化方法,接受概率 prob、是否进行水平翻转 horizontal_flip 和水平翻转概率 horizontal_flip_prob
    def __init__(
        self,
        *,
        prob,
        horizontal_flip,
        horizontal_flip_prob = 0.5
    ):
        super().__init__()
        self.prob = prob
        assert 0 <= prob <= 1.

        self.horizontal_flip = horizontal_flip
        self.horizontal_flip_prob = horizontal_flip_prob

    # 前向传播方法,接受 images 和 rgbs
    def forward(
        self,
        images,
        rgbs: List[Tensor]
    ):
        # 如果随机数大于等于概率 prob,则直接返回 images 和 rgbs
        if random() >= self.prob:
            return images, rgbs

        # 如果随机数小于水平翻转概率 horizontal_flip_prob,则对 images 和 rgbs 进行水平翻转
        if random() < self.horizontal_flip_prob:
            images = torch.flip(images, (-1,))
            rgbs = [torch.flip(rgb, (-1,)) for rgb in rgbs]

        return images, rgbs

# rmsnorm(新论文显示在 layernorm 中进行均值中心化不是必要的)

# 定义类 ChannelRMSNorm,继承自 nn.Module
class ChannelRMSNorm(nn.Module):
    # 初始化方法,接受维度 dim
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))

    # 前向传播方法,对输入 x 进行归一化处���并乘以缩放因子和 gamma 参数
    def forward(self, x):
        normed = F.normalize(x, dim = 1)
        return normed * self.scale * self.gamma

# 定义类 RMSNorm,继承自 nn.Module
class RMSNorm(nn.Module):
    # 初始化方法,接受维度 dim
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    # 前向传播方法,对输入 x 进行归一化处理并乘以缩放因子和 gamma 参数
    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# 下采样和上采样

# 定义类 Blur,继承自 nn.Module
class Blur(nn.Module):
    # 初始化方法,创建一个张量 f,并注册为缓冲区
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    # 前向传播方法,对输入 x 进行二维卷积滤波
    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f [None, :, None]
        return filter2d(x, f, normalized = True)

# 定义��数 Upsample,返回一个包含上采样和模糊处理的序列
def Upsample(*args):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
        Blur()
    )

# 定义类 PixelShuffleUpsample,继承自 nn.Module
class PixelShuffleUpsample(nn.Module):
    # 初始化方法,接受维度 dim
    def __init__(self, dim):
        super().__init__()
        conv = nn.Conv2d(dim, dim * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    # 初始化卷积层的权重
    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    # 前向传播方法,对输入 x 进行处理
    def forward(self, x):
        return self.net(x)

# 定义函数 Downsample,返回一个包含下采样和卷积层的序列
def Downsample(dim):
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim, 1)
    )

# 跳跃层激励

# 定义函数 SqueezeExcite,返回一个包含减少、线性层、SiLU 激活函数、线性层、Sigmoid 激活函数和重排维度的序列
def SqueezeExcite(dim, dim_out, reduction = 4, dim_min = 32):
    dim_hidden = max(dim_out // reduction, dim_min)

    return nn.Sequential(
        Reduce('b c h w -> b c', 'mean'),
        nn.Linear(dim, dim_hidden),
        nn.SiLU(),
        nn.Linear(dim_hidden, dim_out),
        nn.Sigmoid(),
        Rearrange('b c -> b c 1 1')
    )

# 自适应卷积
# 论文的主要创新 - 他们提出根据文本嵌入学习 N 个卷积核的 softmax 加权和

# 定义函数 get_same_padding,计算卷积层的 padding 大小
def get_same_padding(size, kernel, dilation, stride):
    return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

# 定义类 AdaptiveConv2DMod,继承自 nn.Module
class AdaptiveConv2DMod(nn.Module):
    # 初始化函数,设置卷积层的参数
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        demod = True,
        stride = 1,
        dilation = 1,
        eps = 1e-8,
        num_conv_kernels = 1 # set this to be greater than 1 for adaptive
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置 epsilon 值
        self.eps = eps

        # 设置输出维度
        self.dim_out = dim_out

        # 设置卷积核大小、步长、膨胀率
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        # 是否使用自适应卷积核
        self.adaptive = num_conv_kernels > 1

        # 初始化权重参数
        self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)))

        # 是否使用 demodulation
        self.demod = demod

        # 使用 kaiming_normal 初始化权重
        nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')

    # 前向传播函数
    def forward(
        self,
        fmap,
        mod: Optional[Tensor] = None,
        kernel_mod: Optional[Tensor] = None
    ):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """

        # 获取 batch 大小和特征图高度
        b, h = fmap.shape[0], fmap.shape[-2]

        # 考虑特征图在第一维度上由于多尺度输入和输出而扩展的情况

        # 如果 mod 的 batch 大小不等于 b,则进行重复操作
        if mod.shape[0] != b:
            mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])

        # 如果存在 kernel_mod
        if exists(kernel_mod):
            kernel_mod_has_el = kernel_mod.numel() > 0

            # 如果使用自适应卷积核,kernel_mod 必须为空
            assert self.adaptive or not kernel_mod_has_el

            # 如果 kernel_mod 不为空且其 batch 大小不等于 b,则进行重复操作
            if kernel_mod_has_el and kernel_mod.shape[0] != b:
                kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])

        # 准备用于调制的权重

        weights = self.weights

        # 如果使用自适应卷积核
        if self.adaptive:
            # 对权重进行重复操作
            weights = repeat(weights, '... -> b ...', b = b)

            # 确定自适应权重并使用 softmax 选择要使用的卷积核
            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim = -1)
            kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')

            weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')

        # 进行调制和解调制,类似 stylegan2 中的操作

        mod = rearrange(mod, 'b i -> b 1 i 1 1')

        weights = weights * (mod + 1)

        # 如果使用解调制
        if self.demod:
            inv_norm = reduce(weights ** 2, 'b o i k1 k2 -> b o 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
            weights = weights * inv_norm

        fmap = rearrange(fmap, 'b c h w -> 1 (b c) h w')

        weights = rearrange(weights, 'b o ... -> (b o) ...')

        # 计算填充值
        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        # 使用卷积操作
        fmap = F.conv2d(fmap, weights, padding = padding, groups = b)

        # 重新排列特征图
        return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
# 定义 SelfAttention 类,用于实现自注意力机制
class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dot_product = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.dot_product = dot_product

        self.norm = ChannelRMSNorm(dim)

        self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_inner, 1, bias = False) if dot_product else None
        self.to_v = nn.Conv2d(dim, dim_inner, 1, bias = False)

        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)

    # 实现前向传播函数
    def forward(self, fmap):
        """
        einstein notation

        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """
        batch = fmap.shape[0]

        fmap = self.norm(fmap)

        x, y = fmap.shape[-2:]

        h = self.heads

        q, v = self.to_q(fmap), self.to_v(fmap)

        k = self.to_k(fmap) if exists(self.to_k) else q

        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = self.heads), (q, k, v))

        # add a null key / value, so network can choose to pay attention to nothing

        nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # l2 distance or dot product

        if self.dot_product:
            sim = einsum('b i d, b j d -> b i j', q, k)
        else:
            # using pytorch cdist leads to nans in lightweight gan training framework, at least
            q_squared = (q * q).sum(dim = -1)
            k_squared = (k * k).sum(dim = -1)
            l2dist_squared = rearrange(q_squared, 'b i -> b i 1') + rearrange(k_squared, 'b j -> b 1 j') - 2 * einsum('b i d, b j d -> b i j', q, k) # hope i'm mathing right
            sim = -l2dist_squared

        # scale

        sim = sim * self.scale

        # attention

        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)

        return self.to_out(out)

# 定义 CrossAttention 类,用于实现交叉注意力机制
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_context,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads
        kv_input_dim = default(dim_context, dim)

        self.norm = ChannelRMSNorm(dim)
        self.norm_context = RMSNorm(kv_input_dim)

        self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
        self.to_kv = nn.Linear(kv_input_dim, dim_inner * 2, bias = False)
        self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)
    # 定义一个前向传播函数,接受特征图、上下文和可选的掩码作为输入
    def forward(self, fmap, context, mask = None):
        """
        einstein notation

        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """

        # 对特征图进行归一化处理
        fmap = self.norm(fmap)
        # 对上下文进行归一化处理
        context = self.norm_context(context)

        # 获取特征图的高度和宽度
        x, y = fmap.shape[-2:]

        # 获取头数
        h = self.heads

        # 将特征图转换为查询、键、值
        q, k, v = (self.to_q(fmap), *self.to_kv(context).chunk(2, dim = -1))

        # 将键和值重排维度
        k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (k, v))

        # 重排查询维度
        q = rearrange(q, 'b (h d) x y -> (b h) (x y) d', h = self.heads)

        # 计算查询和键之间的相似度
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # 如果存在掩码,则进行掩码处理
        if exists(mask):
            mask = repeat(mask, 'b j -> (b h) 1 j', h = self.heads)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 对相似度进行 softmax 操作得到注意力权重
        attn = sim.softmax(dim = -1)

        # 根据注意力权重计算输出
        out = einsum('b i j, b j d -> b i d', attn, v)

        # 重排输出维度
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)

        # 将输出转换为最终输出
        return self.to_out(out)
# 定义经典的 transformer 注意力机制,使用 L2 距离

class TextAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)  # 初始化 RMS 归一化层
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)  # 初始化线性层,用于计算查询、键、值

        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))  # 初始化空键/值参数

        self.to_out = nn.Linear(dim_inner, dim, bias = False)  # 初始化输出线性层

    def forward(self, encodings, mask = None):
        """
        einstein notation

        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """
        batch = encodings.shape[0]

        encodings = self.norm(encodings)  # 对输入进行归一化处理

        h = self.heads

        q, k, v = self.to_qkv(encodings).chunk(3, dim = -1)  # 将查询、键、值分割
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))  # 重排形状

        # 添加一个空键/值,以便网络可以选择不关注任何内容

        nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)  # 重复空键/值

        k = torch.cat((nk, k), dim = -2)  # 拼接键
        v = torch.cat((nv, v), dim = -2)  # 拼接值

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale  # 计算相似度

        # 键填充掩码

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)  # 对掩码进行填充
            mask = repeat(mask, 'b n -> (b h) 1 n', h = h)  # 重复掩码
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)  # 对掩码外的值进行替换

        # 注意力

        attn = sim.softmax(dim = -1)  # 计算注意力权重
        out = einsum('b i j, b j d -> b i d', attn, v)  # 计算输出

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)  # 重排输出形状

        return self.to_out(out)  # 返回输出结果

# 前馈网络

def FeedForward(
    dim,
    mult = 4,
    channel_first = False
):
    dim_hidden = int(dim * mult)
    norm_klass = ChannelRMSNorm if channel_first else RMSNorm
    proj = partial(nn.Conv2d, kernel_size = 1) if channel_first else nn.Linear

    return nn.Sequential(
        norm_klass(dim),  # 初始化归一化层
        proj(dim, dim_hidden),  # 线性变换到隐藏维度
        nn.GELU(),  # GELU 激活函数
        proj(dim_hidden, dim)  # 线性变换回原始维度
    )

# 不同类型的 transformer 块或 transformer(多个块)

class SelfAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        dot_product = False
    ):
        super().__init__()
        self.attn = SelfAttention(dim = dim, dim_head = dim_head, heads = heads, dot_product = dot_product)  # 初始化自注意力层
        self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)  # 初始化前馈网络

    def forward(self, x):
        x = self.attn(x) + x  # 自注意力操作后加上残差连接
        x = self.ff(x) + x  # 前馈网络操作后加上残差连接
        return x

class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_context,
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.attn = CrossAttention(dim = dim, dim_context = dim_context, dim_head = dim_head, heads = heads)  # 初始化交叉注意力层
        self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)  # 初始化前馈网络

    def forward(self, x, context, mask = None):
        x = self.attn(x, context = context, mask = mask) + x  # 交叉注意力操作后加上残差连接
        x = self.ff(x) + x  # 前馈网络操作后加上残差连接
        return x

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                TextAttention(dim = dim, dim_head = dim_head, heads = heads),  # 添加文本注意力层
                FeedForward(dim = dim, mult = ff_mult)  # 添加前馈网络
            ]))

        self.norm = RMSNorm(dim)  # 初始化 RMS 归一化层
    # 定义前向传播函数,接受输入 x 和掩码 mask,默认为 None
    def forward(self, x, mask = None):
        # 遍历每个注意力层和前馈神经网络层
        for attn, ff in self.layers:
            # 使用注意力层处理输入 x,并将结果与 x 相加
            x = attn(x, mask = mask) + x
            # 使用前馈神经网络层处理输入 x,并将结果与 x 相加
            x = ff(x) + x

        # 对处理后的 x 进行归一化处理
        return self.norm(x)
# 文本编码器类
class TextEncoder(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        *,
        dim,
        depth,
        clip: Optional[OpenClipAdapter] = None,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.dim = dim

        # 如果 clip 不存在,则创建一个 OpenClipAdapter 对象
        if not exists(clip):
            clip = OpenClipAdapter()

        self.clip = clip
        # 设置 clip 不需要梯度
        set_requires_grad_(clip, False)

        # 创建一个学习到的全局标记
        self.learned_global_token = nn.Parameter(torch.randn(dim))

        # 根据 clip.dim_latent 是否等于 dim 来选择 Linear 层或 Identity 函数
        self.project_in = nn.Linear(clip.dim_latent, dim) if clip.dim_latent != dim else nn.Identity()

        # 创建一个 Transformer 模型
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads
        )

    # 前向传播函数
    @beartype
    def forward(
        self,
        texts: Optional[List[str]] = None,
        text_encodings: Optional[Tensor] = None
    ):
        # texts 和 text_encodings 必须有且只有一个存在
        assert exists(texts) ^ exists(text_encodings)

        # 如果 text_encodings 不存在,则使用 texts 通过 clip.embed_texts 方法获取
        if not exists(text_encodings):
            with torch.no_grad():
                self.clip.eval()
                _, text_encodings = self.clip.embed_texts(texts)

        # 创建一个 mask,用于标记 text_encodings 中不为 0 的位置
        mask = (text_encodings != 0.).any(dim = -1)

        # 对 text_encodings 进行线性变换
        text_encodings = self.project_in(text_encodings)

        # 在 mask 前面填充一个 True 值,用于表示全局标记
        mask_with_global = F.pad(mask, (1, 0), value = True)

        # 获取 batch 大小,并重复学习到的全局标记
        batch = text_encodings.shape[0]
        global_tokens = repeat(self.learned_global_token, 'd -> b d', b = batch)

        # 打包全局标记和 text_encodings
        text_encodings, ps = pack([global_tokens, text_encodings], 'b * d')

        # 使用 Transformer 模型进行编码
        text_encodings = self.transformer(text_encodings, mask = mask_with_global)

        # 解包结果,获取全局标记和编码结果
        global_tokens, text_encodings = unpack(text_encodings, ps, 'b * d')

        # 返回全局标记��编码结果和 mask
        return global_tokens, text_encodings, mask

# 等权线性层
class EqualLinear(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        dim_out,
        lr_mul = 1,
        bias = True
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out))

        self.lr_mul = lr_mul

    # 前向传播函数
    def forward(self, input):
        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)

# 风格网络类
class StyleNetwork(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        depth,
        lr_mul = 0.1,
        dim_text_latent = 0
    ):
        super().__init__()
        self.dim = dim
        self.dim_text_latent = dim_text_latent

        layers = []
        # 构建深度为 depth 的网络层
        for i in range(depth):
            is_first = i == 0
            dim_in = (dim + dim_text_latent) if is_first else dim

            layers.extend([EqualLinear(dim_in, dim, lr_mul), leaky_relu()])

        self.net = nn.Sequential(*layers)

    # 前向传播函数
    def forward(
        self,
        x,
        text_latent = None
    ):
        # 对输入 x 进行归一化
        x = F.normalize(x, dim = 1)

        # 如果 dim_text_latent 大于 0,则将 text_latent 拼接到 x 中
        if self.dim_text_latent > 0:
            assert exists(text_latent)
            x = torch.cat((x, text_latent), dim = -1)

        # 返回网络处理后的结果
        return self.net(x)

# 噪声类
class Noise(nn.Module):
    # 初始化函数
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim, 1, 1))

    # 前向传播函数
    def forward(
        self,
        x,
        noise = None
    ):
        b, _, h, w, device = *x.shape, x.device

        # 如果 noise 不存在,则创建一个随机噪声
        if not exists(noise):
            noise = torch.randn(b, 1, h, w, device = device)

        # 返回加上噪声的结果
        return x + self.weight * noise

# 生成器基类
class BaseGenerator(nn.Module):
    pass

# 生成器类
class Generator(BaseGenerator):
    # 初始化函数
    @beartype
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        image_size,  # 图像尺寸
        dim_capacity = 16,  # 容量维度
        dim_max = 2048,  # 最大维度
        channels = 3,  # 通道数
        style_network: Optional[Union[StyleNetwork, Dict]] = None,  # 风格网络
        style_network_dim = None,  # 风格网络维度
        text_encoder: Optional[Union[TextEncoder, Dict]] = None,  # 文本编码器
        dim_latent = 512,  # 潜在维度
        self_attn_resolutions: Tuple[int, ...] = (32, 16),  # 自注意力分辨率
        self_attn_dim_head = 64,  # 自注意力头维度
        self_attn_heads = 8,  # 自注意力头数
        self_attn_dot_product = True,  # 自注意力是否使用点积
        self_attn_ff_mult = 4,  # 自注意力前馈倍数
        cross_attn_resolutions: Tuple[int, ...] = (32, 16),  # 交叉注意力分辨率
        cross_attn_dim_head = 64,  # 交叉注意力头维度
        cross_attn_heads = 8,  # 交叉注意力头数
        cross_attn_ff_mult = 4,  # 交叉注意力前馈倍数
        num_conv_kernels = 2,  # 自适应卷积核数量
        num_skip_layers_excite = 0,  # 激励跳层数量
        unconditional = False,  # 是否无条件
        pixel_shuffle_upsample = False  # 像素混洗上采样
    def init_(self, m):
        # 初始化函数,使用 kaiming_normal 初始化卷积和全连接层的权重
        if type(m) in {nn.Conv2d, nn.Linear}:
            nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')

    @property
    def total_params(self):
        # 计算模型总参数数量
        return sum([p.numel() for p in self.parameters() if p.requires_grad])

    @property
    def device(self):
        # 获取模型所在设备
        return next(self.parameters()).device

    @beartype
    def forward(
        self,
        styles = None,  # 风格
        noise = None,  # 噪声
        texts: Optional[List[str]] = None,  # 文本列表
        text_encodings: Optional[Tensor] = None,  # 文本编码
        global_text_tokens = None,  # 全局文本标记
        fine_text_tokens = None,  # 精细文本标记
        text_mask = None,  # 文本掩码
        batch_size = 1,  # 批量大小
        return_all_rgbs = False  # 是否返回所有 RGB
        ):
            # 处理文本编码
            # 需要全局文本令牌来自适应选择主要贡献中的内核
            # 需要细文本令牌来使用交叉注意力

            if not self.unconditional:
                if exists(texts) or exists(text_encodings):
                    assert exists(texts) ^ exists(text_encodings), '要么传入原始文本作为 List[str],要么传入文本编码(来自 clip)作为 Tensor,但不能同时传入'
                    assert exists(self.text_encoder)

                    if exists(texts):
                        text_encoder_kwargs = dict(texts = texts)
                    elif exists(text_encodings):
                        text_encoder_kwargs = dict(text_encodings = text_encodings)

                    global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(**text_encoder_kwargs)
                else:
                    assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask)]), '未传入原始文本或文本嵌入以进行条件训练'
            else:
                assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])

            # 确定风格

            if not exists(styles):
                assert exists(self.style_network)

                if not exists(noise):
                    noise = torch.randn((batch_size, self.style_network_dim), device = self.device)

                styles = self.style_network(noise, global_text_tokens)

            # 将风格投影到卷积调制

            conv_mods = self.style_to_conv_modulations(styles)
            conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
            conv_mods = iter(conv_mods)

            # 准备初始块

            batch_size = styles.shape[0]

            x = repeat(self.init_block, 'c h w -> b c h w', b = batch_size)
            x = self.init_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))

            rgb = torch.zeros((batch_size, self.channels, 4, 4), device = self.device, dtype = x.dtype)

            # 跳过层挤压激发

            excitations = [None] * self.num_skip_layers_excite

            # 保存生成器每一层的所有 rgb 用于多分辨率输入判别

            rgbs = []

            # 主网络

            for squeeze_excite, (resnet_conv1, noise1, act1, resnet_conv2, noise2, act2), to_rgb_conv, self_attn, cross_attn, upsample, upsample_rgb in self.layers:

                if exists(upsample):
                    x = upsample(x)

                if exists(squeeze_excite):
                    skip_excite = squeeze_excite(x)
                    excitations.append(skip_excite)

                excite = safe_unshift(excitations)
                if exists(excite):
                    x = x * excite

                x = resnet_conv1(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
                x = noise1(x)
                x = act1(x)

                x = resnet_conv2(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
                x = noise2(x)
                x = act2(x)

                if exists(self_attn):
                    x = self_attn(x)

                if exists(cross_attn):
                    x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

                layer_rgb = to_rgb_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))

                rgb = rgb + layer_rgb

                rgbs.append(rgb)

                if exists(upsample_rgb):
                    rgb = upsample_rgb(rgb)

            # 检查

            assert is_empty([*conv_mods]), '卷积错误调制'

            if return_all_rgbs:
                return rgb, rgbs

            return rgb
# 定义一个简单的解码器类,继承自 nn.Module
@beartype
class SimpleDecoder(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        *,
        dims: Tuple[int, ...],
        patch_dim: int = 1,
        frac_patches: float = 1.,
        dropout: float = 0.5
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 断言确保 frac_patches 在 0 到 1 之间
        assert 0 < frac_patches <= 1.

        # 初始化一些参数
        self.patch_dim = patch_dim
        self.frac_patches = frac_patches

        # 创建一个 dropout 层
        self.dropout = nn.Dropout(dropout)

        # 将 dim 和 dims 组成一个列表
        dims = [dim, *dims]

        # 初始化一个空的层列表
        layers = [conv2d_3x3(dim, dim)]

        # 遍历 dims 列表,创建卷积层和激活函数层
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Sequential(
                Upsample(dim_in),
                conv2d_3x3(dim_in, dim_out),
                leaky_relu()
            ))

        # 创建一个包含所有层的神经网络
        self.net = nn.Sequential(*layers)

    # 定义一个属性,返回参数的设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 前向传播函数,接受特征图和原始图像作为输入
    def forward(
        self,
        fmap,
        orig_image
    ):
        # 对特征图进行 dropout
        fmap = self.dropout(fmap)

        # 如果 frac_patches 小于 1
        if self.frac_patches < 1.:
            # 获取 batch 大小和 patch 维度
            batch, patch_dim = fmap.shape[0], self.patch_dim
            fmap_size, img_size = fmap.shape[-1], orig_image.shape[-1]

            # 断言确保特征图和图像大小能够整除 patch 维度
            assert divisible_by(fmap_size, patch_dim), f'feature map dimensions are {fmap_size}, but the patch dim was designated to be {patch_dim}'
            assert divisible_by(img_size, patch_dim), f'image size is {img_size} but the patch dim was specified to be {patch_dim}'

            # 重排特征图和原始图像的维度
            fmap, orig_image = map(lambda t: rearrange(t, 'b c (p1 h) (p2 w) -> b (p1 p2) c h w', p1 = patch_dim, p2 = patch_dim), (fmap, orig_image))

            # 计算总 patch 数量和需要重建的 patch 数量
            total_patches = patch_dim ** 2
            num_patches_recon = max(int(self.frac_patches * total_patches), 1)

            # 创建一个 batch 的索引和随机排列的索引
            batch_arange = torch.arange(batch, device = self.device)[..., None]
            batch_randperm = torch.randn((batch, total_patches)).sort(dim = -1).indices
            patch_indices = batch_randperm[..., :num_patches_recon]

            # 从特征图和原始图像中选择对应的 patch
            fmap, orig_image = map(lambda t: t[batch_arange, patch_indices], (fmap, orig_image))
            fmap, orig_image = map(lambda t: rearrange(t, 'b p ... -> (b p) ...'), (fmap, orig_image))

        # 将选定的 patch 输入神经网络进行重建
        recon = self.net(fmap)
        # 返回重建图像和原始图像的均方误差损失
        return F.mse_loss(recon, orig_image)

# 定义一个随机固定投影类,继承自 nn.Module
class RandomFixedProjection(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        dim_out,
        channel_first = True
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 生成随机权重并初始化
        weights = torch.randn(dim, dim_out)
        nn.init.kaiming_normal_(weights, mode = 'fan_out', nonlinearity = 'linear')

        # 初始化一些参数
        self.channel_first = channel_first
        self.register_buffer('fixed_weights', weights)

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 如果 channel_first 为 False,则返回 x 与固定权重的矩阵乘积
        if not self.channel_first:
            return x @ self.fixed_weights

        # 如果 channel_first 为 True,则返回 x 与固定权重的张量乘积
        return einsum('b c ..., c d -> b d ...', x, self.fixed_weights)

# 定义一个视觉辅助鉴别器类,继承自 nn.Module
class VisionAidedDiscriminator(nn.Module):
    """ the vision-aided gan loss """

    # 初始化函数,接受多个参数
    @beartype
    def __init__(
        self,
        *,
        depth = 2,
        dim_head = 64,
        heads = 8,
        clip: Optional[OpenClipAdapter] = None,
        layer_indices = (-1, -2, -3),
        conv_dim = None,
        text_dim = None,
        unconditional = False,
        num_conv_kernels = 2
    ):
        # 调用父类的构造函数
        super().__init__()

        # 如果指定的 clip 不存在,则使用 OpenClipAdapter() 创建一个 clip 对象
        if not exists(clip):
            clip = OpenClipAdapter()

        # 设置对象的 clip 属性为传入的 clip 参数
        self.clip = clip
        # 获取 clip 对象的 _dim_image_latent 属性值
        dim = clip._dim_image_latent

        # 设置 unconditional 属性为传入的 unconditional 参数,如果未传入则使用 dim 的值
        self.unconditional = unconditional
        text_dim = default(text_dim, dim)
        conv_dim = default(conv_dim, dim)

        # 初始化 layer_discriminators 为一个空的 nn.ModuleList
        self.layer_discriminators = nn.ModuleList([])
        # 设置 layer_indices 属性为传入的 layer_indices 参数

        # 根据 unconditional 的值选择不同的卷积类
        conv_klass = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) if not unconditional else conv2d_3x3

        # 遍历 layer_indices,为每个索引创建一个包含不同模块的 nn.ModuleList
        for _ in layer_indices:
            self.layer_discriminators.append(nn.ModuleList([
                RandomFixedProjection(dim, conv_dim),
                conv_klass(conv_dim, conv_dim),
                nn.Linear(text_dim, conv_dim) if not unconditional else None,
                nn.Linear(text_dim, num_conv_kernels) if not unconditional else None,
                nn.Sequential(
                    conv2d_3x3(conv_dim, 1),
                    Rearrange('b 1 ... -> b ...')
                )
            ]))

    # 返回 layer_discriminators 中所有模块的参数
    def parameters(self):
        return self.layer_discriminators.parameters()

    # 返回 layer_discriminators 中所有模块参数的总数量
    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    # 前向传播函数,接收 images、texts、text_embeds 和 return_clip_encodings 参数
    @beartype
    def forward(
        self,
        images,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        return_clip_encodings = False
    ):

        # 断言条件,确保在有条件生成时存在 text_embeds 或 texts
        assert self.unconditional or (exists(text_embeds) ^ exists(texts))

        # 在无条件生成且存在 texts 时,使用 clip 对象的 embed_texts 属性作为 text_embeds
        with torch.no_grad():
            if not self.unconditional and exists(texts):
                self.clip.eval()
                text_embeds = self.clip.embed_texts

        # 获取 images 的编码结果
        _, image_encodings = self.clip.embed_images(images)

        # 初始化 logits 列表
        logits = []

        # 遍历 layer_indices 和 layer_discriminators 中的模块,计算 logits
        for layer_index, (rand_proj, conv, to_conv_mod, to_conv_kernel_mod, to_logits) in zip(self.layer_indices, self.layer_discriminators):
            image_encoding = image_encodings[layer_index]

            cls_token, rest_tokens = image_encoding[:, :1], image_encoding[:, 1:]
            height_width = int(sqrt(rest_tokens.shape[-2])) # 假设为正方形

            img_fmap = rearrange(rest_tokens, 'b (h w) d -> b d h w', h = height_width)

            img_fmap = img_fmap + rearrange(cls_token, 'b 1 d -> b d 1 1 ') # 将 cls token 汇入其余 token

            img_fmap = rand_proj(img_fmap)

            if self.unconditional:
                img_fmap = conv(img_fmap)
            else:
                assert exists(text_embeds)

                img_fmap = conv(
                    img_fmap,
                    mod = to_conv_mod(text_embeds),
                    kernel_mod = to_conv_kernel_mod(text_embeds)
                )

            layer_logits = to_logits(img_fmap)

            logits.append(layer_logits)

        # 如果不需要返回 clip 编码,则返回 logits
        if not return_clip_encodings:
            return logits

        # 否则返回 logits 和 image_encodings
        return logits, image_encodings
# 定义一个预测器类,继承自 nn.Module
class Predictor(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        depth = 4,
        num_conv_kernels = 2,
        unconditional = False
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置是否无条件的标志
        self.unconditional = unconditional
        # 创建一个卷积层,用于残差连接
        self.residual_fn = nn.Conv2d(dim, dim, 1)
        # 设置残差缩放因子
        self.residual_scale = 2 ** -0.5

        # 创建一个空的模块列表
        self.layers = nn.ModuleList([])

        # 根据是否无条件,选择不同的卷积类
        klass = nn.Conv2d if unconditional else partial(AdaptiveConv2DMod, num_conv_kernels = num_conv_kernels)
        klass_kwargs = dict(padding = 1) if unconditional else dict()

        # 循环创建深度次数的卷积层
        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                klass(dim, dim, 3, **klass_kwargs),
                leaky_relu(),
                klass(dim, dim, 3, **klass_kwargs),
                leaky_relu()
            ]))

        # 创建一个转换为 logits 的卷积层
        self.to_logits = nn.Conv2d(dim, 1, 1)

    # 前向传播函数,接受多个参数
    def forward(
        self,
        x,
        mod = None,
        kernel_mod = None
    ):
        # 计算残差
        residual = self.residual_fn(x)

        kwargs = dict()

        # 如果不是无条件的,则传入 mod 和 kernel_mod 参数
        if not self.unconditional:
            kwargs = dict(mod = mod, kernel_mod = kernel_mod)

        # 循环处理每一层
        for conv1, activation, conv2, activation in self.layers:

            inner_residual = x

            x = conv1(x, **kwargs)
            x = activation(x)
            x = conv2(x, **kwargs)
            x = activation(x)

            x = x + inner_residual
            x = x * self.residual_scale

        # 加上残差并返回 logits
        x = x + residual
        return self.to_logits(x)

# 定义一个鉴别器类,继承自 nn.Module
class Discriminator(nn.Module):
    # 初始化函数,接受多个参数
    @beartype
    def __init__(
        self,
        *,
        dim_capacity = 16,
        image_size,
        dim_max = 2048,
        channels = 3,
        attn_resolutions: Tuple[int, ...] = (32, 16),
        attn_dim_head = 64,
        attn_heads = 8,
        self_attn_dot_product = False,
        ff_mult = 4,
        text_encoder: Optional[Union[TextEncoder, Dict]] = None,
        text_dim = None,
        filter_input_resolutions: bool = True,
        multiscale_input_resolutions: Tuple[int, ...] = (64, 32, 16, 8),
        multiscale_output_skip_stages: int = 1,
        aux_recon_resolutions: Tuple[int, ...] = (8,),
        aux_recon_patch_dims: Tuple[int, ...] = (2,),
        aux_recon_frac_patches: Tuple[float, ...] = (0.25,),
        aux_recon_fmap_dropout: float = 0.5,
        resize_mode = 'bilinear',
        num_conv_kernels = 2,
        num_skip_layers_excite = 0,
        unconditional = False,
        predictor_depth = 2
    def init_(self, m):
        # 初始化函数,对卷积和线性层进行初始化
        if type(m) in {nn.Conv2d, nn.Linear}:
            nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')

    # 将图像调整大小到指定分辨率
    def resize_image_to(self, images, resolution):
        return F.interpolate(images, resolution, mode = self.resize_mode)

    # 将真实图像调整大小到多个分辨率
    def real_images_to_rgbs(self, images):
        return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions]

    # 返回模型的总参数数量
    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    # 返回模型所在设备
    @property
    def device(self):
        return next(self.parameters()).device

    # 前向传播函数,接受多个参数
    @beartype
    def forward(
        self,
        images,
        rgbs: List[Tensor],                   # 生成器的多分辨率输入
        texts: Optional[List[str]] = None,
        text_encodings: Optional[Tensor] = None,
        text_embeds = None,
        real_images = None,                   # 如果传入真实图像,网络将自动将其附加到传入的生成图像中,并通过适当的调整大小和连接生成所有中间分辨率
        return_multiscale_outputs = True,     # 可以强制不返回多尺度 logits
        calc_aux_loss = True
# gan

# 定义训练鉴别器损失的命名元组
TrainDiscrLosses = namedtuple('TrainDiscrLosses', [
    'divergence',
    'multiscale_divergence',
    'vision_aided_divergence',
    'total_matching_aware_loss',
    'gradient_penalty',
    'aux_reconstruction'
])
# 定义一个命名元组,包含训练生成器的损失值
TrainGenLosses = namedtuple('TrainGenLosses', [
    'divergence',
    'multiscale_divergence',
    'total_vd_divergence',
    'contrastive_loss'
])

# 定义 GigaGAN 类,继承自 nn.Module
class GigaGAN(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        *,
        generator: Union[BaseGenerator, Dict],  # 生成器对象或字典
        discriminator: Union[Discriminator, Dict],  # 判别器对象或字典
        vision_aided_discriminator: Optional[Union[VisionAidedDiscriminator, Dict]] = None,  # 辅助视觉判别器对象或字典
        diff_augment: Optional[Union[DiffAugment, Dict]] = None,  # 数据增强对象或字典
        learning_rate = 2e-4,  # 学习率
        betas = (0.5, 0.9),  # Adam 优化器的 beta 参数
        weight_decay = 0.,  # 权重衰减
        discr_aux_recon_loss_weight = 1.,  # 判别器辅助重建损失权重
        multiscale_divergence_loss_weight = 0.1,  # 多尺度散度损失权重
        vision_aided_divergence_loss_weight = 0.5,  # 视觉辅助散度损失权重
        generator_contrastive_loss_weight = 0.1,  # 生成器对比损失权重
        matching_awareness_loss_weight = 0.1,  # 匹配感知损失权重
        calc_multiscale_loss_every = 1,  # 计算多尺度损失的频率
        apply_gradient_penalty_every = 4,  # 应用梯度惩罚的频率
        resize_image_mode = 'bilinear',  # 图像调整模式
        train_upsampler = False,  # 是否训练上采样器
        log_steps_every = 20,  # 每隔多少步记录日志
        create_ema_generator_at_init = True,  # 是否在初始化时创建 EMA 生成器
        save_and_sample_every = 1000,  # 保存和采样的频率
        early_save_thres_steps = 2500,  # 早期保存的阈值步数
        early_save_and_sample_every = 100,  # 早期保存和采样的频率
        num_samples = 25,  # 采样数量
        model_folder = './gigagan-models',  # 模型保存文件夹路径
        results_folder = './gigagan-results',  # 结果保存文件夹路径
        sample_upsampler_dl: Optional[DataLoader] = None,  # 上采样器数据加载器
        accelerator: Optional[Accelerator] = None,  # 加速器对象
        accelerate_kwargs: dict = {},  # 加速参数
        find_unused_parameters = True,  # 是否查找未使用的参数
        amp = False,  # 是否使用混合精度训练
        mixed_precision_type = 'fp16'  # 混合精度类型
    # 保存模型的方法
    def save(self, path, overwrite = True):
        # 将路径转换为 Path 对象
        path = Path(path)
        # 如果父目录不存在,则创建
        mkdir_if_not_exists(path.parents[0])

        # 断言是否覆盖保存或路径不存在
        assert overwrite or not path.exists()

        # 创建包含模型参数的字典
        pkg = dict(
            G = self.unwrapped_G.state_dict(),  # 生成器的状态字典
            D = self.unwrapped_D.state_dict(),  # 判别器的状态字典
            G_opt = self.G_opt.state_dict(),  # 生成器优化器的状态字典
            D_opt = self.D_opt.state_dict(),  # 判别器优化器的状态字典
            steps = self.steps.item(),  # 训练步数
            version = __version__  # 版本号
        )

        # 如果存在生成器优化器的 scaler,则保存其状态字典
        if exists(self.G_opt.scaler):
            pkg['G_scaler'] = self.G_opt.scaler.state_dict()

        # 如果存在判别器优化器的 scaler,则保存其状态字典
        if exists(self.D_opt.scaler):
            pkg['D_scaler'] = self.D_opt.scaler.state_dict()

        # 如果存在视觉辅助判别器,则保存其状态字典
        if exists(self.VD):
            pkg['VD'] = self.unwrapped_VD.state_dict()
            pkg['VD_opt'] = self.VD_opt.state_dict()

            # 如果存在视觉辅助判别器的 scaler,则保存其状态字典
            if exists(self.VD_opt.scaler):
                pkg['VD_scaler'] = self.VD_opt.scaler.state_dict()

        # 如果存在 EMA 生成器,则保存其状态字典
        if self.has_ema_generator:
            pkg['G_ema'] = self.G_ema.state_dict()

        # 使用 torch 保存模型参数字典到指定路径
        torch.save(pkg, str(path))
    # 从指定路径加载模型参数
    def load(self, path, strict = False):
        # 将路径转换为 Path 对象
        path = Path(path)
        # 断言路径存在
        assert path.exists()

        # 加载模型参数
        pkg = torch.load(str(path))

        # 检查加载的模型参数版本是否与当前版本一致
        if 'version' in pkg and pkg['version'] != __version__:
            print(f"trying to load from version {pkg['version']}")

        # 加载生成器和判别器的状态字典
        self.unwrapped_G.load_state_dict(pkg['G'], strict = strict)
        self.unwrapped_D.load_state_dict(pkg['D'], strict = strict)

        # 如果存在 VD 模型,则加载其状态字典
        if exists(self.VD):
            self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict)

        # 如果有 EMA 生成器,则加载其状态字典
        if self.has_ema_generator:
            self.G_ema.load_state_dict(pkg['G_ema'])

        # 如果模型参数中包含步数信息,则更新当前步数
        if 'steps' in pkg:
            self.steps.copy_(torch.tensor([pkg['steps']]))

        # 如果模型参数中包含优化器状态字典,则加载优化器状态
        if 'G_opt'not in pkg or 'D_opt' not in pkg:
            return

        try:
            # 加载生成器和判别器的优化器状态字典
            self.G_opt.load_state_dict(pkg['G_opt'])
            self.D_opt.load_state_dict(pkg['D_opt'])

            # 如果存在 VD 模型,则加载其优化器状态字典
            if exists(self.VD):
                self.VD_opt.load_state_dict(pkg['VD_opt'])

            # 如果模型参数中包含生成器的缩放器状态字典,则加载
            if 'G_scaler' in pkg and exists(self.G_opt.scaler):
                self.G_opt.scaler.load_state_dict(pkg['G_scaler'])

            # 如果模型参数中包含判别器的缩放器状态字典,则加载
            if 'D_scaler' in pkg and exists(self.D_opt.scaler):
                self.D_opt.scaler.load_state_dict(pkg['D_scaler'])

            # 如果模型参数中包含 VD 的缩放器状态字典,则加载
            if 'VD_scaler' in pkg and exists(self.VD_opt.scaler):
                self.VD_opt.scaler.load_state_dict(pkg['VD_scaler'])

        except Exception as e:
            # 加载优化器状态字典出错时打印错误信息
            self.print(f'unable to load optimizers {e.msg}- optimizer states will be reset')
            pass

    # 加速相关

    # 获取设备信息
    @property
    def device(self):
        return self.accelerator.device

    # 获取未包装的生成器模型
    @property
    def unwrapped_G(self):
        return self.accelerator.unwrap_model(self.G)

    # 获取未包装的判别器模型
    @property
    def unwrapped_D(self):
        return self.accelerator.unwrap_model(self.D)

    # 获取未包装的 VD 模型
    @property
    def unwrapped_VD(self):
        return self.accelerator.unwrap_model(self.VD)

    # 是否需要视觉辅助判别器
    @property
    def need_vision_aided_discriminator(self):
        return exists(self.VD) and self.vision_aided_divergence_loss_weight > 0.

    # 是否需要对比损失
    @property
    def need_contrastive_loss(self):
        return self.generator_contrastive_loss_weight > 0. and not self.unconditional

    # 打印信息
    def print(self, msg):
        self.accelerator.print(msg)

    # 是否分布式
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # 是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 是否为本地主进程
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # 调整图像大小
    def resize_image_to(self, images, resolution):
        return F.interpolate(images, resolution, mode = self.resize_image_mode)

    # 设置数据加载器
    @beartype
    def set_dataloader(self, dl: DataLoader):
        assert not exists(self.train_dl), 'training dataloader has already been set'

        self.train_dl = dl
        self.train_dl_batch_size = dl.batch_size

        self.train_dl = self.accelerator.prepare(self.train_dl)

    # 生成函数

    @torch.inference_mode()
    def generate(self, *args, **kwargs):
        model = self.G_ema if self.has_ema_generator else self.G
        model.eval()
        return model(*args, **kwargs)

    # 创建 EMA 生成器

    def create_ema_generator(
        self,
        update_every = 10,
        update_after_step = 100,
        decay = 0.995
    ):
        if not self.is_main:
            return

        assert not self.has_ema_generator, 'EMA generator has already been created'

        self.G_ema = EMA(self.unwrapped_G, update_every = update_every, update_after_step = update_after_step, beta = decay)
        self.has_ema_generator = True
    # 生成传递给生成器的参数
    def generate_kwargs(self, dl_iter, batch_size):
        # 根据训练是否为上采样器或非条件性来确定传递给生成器的内容

        # 可能的文本参数字典
        maybe_text_kwargs = dict()
        if self.train_upsampler or not self.unconditional:
            assert exists(dl_iter)

            if self.unconditional:
                real_images = next(dl_iter)
            else:
                result = next(dl_iter)
                assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
                real_images, texts = result

                maybe_text_kwargs['texts'] = texts[:batch_size]

            real_images = real_images.to(self.device)

        # 如果训练上采样生成器,则需要对真实图像进行下采样
        if self.train_upsampler:
            size = self.unwrapped_G.input_image_size
            lowres_real_images = F.interpolate(real_images, (size, size))

            G_kwargs = dict(lowres_image = lowres_real_images)
        else:
            assert exists(batch_size)

            G_kwargs = dict(batch_size = batch_size)

        # 创建噪声
        noise = torch.randn(batch_size, self.unwrapped_G.style_network.dim, device = self.device)

        G_kwargs.update(noise = noise)

        return G_kwargs, maybe_text_kwargs
    
    # 训练鉴别器的步骤
    @beartype
    def train_discriminator_step(
        self,
        dl_iter: Iterable,
        grad_accum_every = 1,
        apply_gradient_penalty = False,
        calc_multiscale_loss = True
    # 训练生成器的步骤
    def train_generator_step(
        self,
        batch_size = None,
        dl_iter: Optional[Iterable] = None,
        grad_accum_every = 1,
        calc_multiscale_loss = True
        ):
        # 初始化各种损失值
        total_divergence = 0.
        total_multiscale_divergence = 0. if calc_multiscale_loss else None
        total_vd_divergence = 0.
        contrastive_loss = 0.

        # 设置生成器和判别器为训练模式
        self.G.train()
        self.D.train()

        # 清空生成器和判别器的梯度
        self.D_opt.zero_grad()
        self.G_opt.zero_grad()

        # 初始化存储所有图像和文本的列表
        all_images = []
        all_texts = []

        for _ in range(grad_accum_every):

            # 生成器部分

            # 生成生成器所需的参数和文本参数
            G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)

            # 自动混合精度加速
            with self.accelerator.autocast():
                # 生成图像和 RGB 值
                images, rgbs = self.G(
                    **G_kwargs,
                    **maybe_text_kwargs,
                    return_all_rgbs = True
                )

                # 使用不同的数据增强方法
                if exists(self.diff_augment):
                    images, rgbs = self.diff_augment(images, rgbs)

                # 如果需要对比损失,累积所有图像和文本
                if self.need_contrastive_loss:
                    all_images.append(images)
                    all_texts.extend(maybe_text_kwargs['texts'])

                # 判别器部分

                # 获取判别器的输出
                logits, multiscale_logits, _ = self.D(
                    images,
                    rgbs,
                    **maybe_text_kwargs,
                    return_multiscale_outputs = calc_multiscale_loss,
                    calc_aux_loss = False
                )

                # 生成器的 Hinge 损失和判别器的多尺度输出
                divergence = generator_hinge_loss(logits)

                total_divergence += (divergence.item() / grad_accum_every)

                total_loss = divergence

                # 如果多尺度分歧损失权重大于 0 并且有多尺度输出
                if self.multiscale_divergence_loss_weight > 0. and len(multiscale_logits) > 0:
                    multiscale_divergence = 0.

                    for multiscale_logit in multiscale_logits:
                        multiscale_divergence = multiscale_divergence + generator_hinge_loss(multiscale_logit)

                    total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)

                    total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight

                # 视觉辅助生成器的 Hinge 损失
                if self.need_vision_aided_discriminator:
                    vd_loss = 0.

                    logits = self.VD(images, **maybe_text_kwargs)

                    for logit in logits:
                        vd_loss = vd_loss + generator_hinge_loss(logit)

                    total_vd_divergence += (vd_loss.item() / grad_accum_every)

                    total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight

            # 反向传播
            self.accelerator.backward(total_loss / grad_accum_every, retain_graph = self.need_contrastive_loss)

        # 如果需要生成器对比损失
        # 收集所有图像和文本并计算损失
        if self.need_contrastive_loss:
            all_images = torch.cat(all_images, dim = 0)

            contrastive_loss = aux_clip_loss(
                clip = self.G.text_encoder.clip,
                texts = all_texts,
                images = all_images
            )

            self.accelerator.backward(contrastive_loss * self.generator_contrastive_loss_weight)

        # 生成器优化器步骤
        self.G_opt.step()

        # 更新指数移动平均生成器
        self.accelerator.wait_for_everyone()

        if self.is_main and self.has_ema_generator:
            self.G_ema.update()

        # 返回训练生成器的损失
        return TrainGenLosses(
            total_divergence,
            total_multiscale_divergence,
            total_vd_divergence,
            contrastive_loss
        )
    # 定义一个方法用于生成样本,接受模型、数据迭代器和批量大小作为参数
    def sample(self, model, dl_iter, batch_size):
        # 生成生成器参数和可能的文本参数
        G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)

        # 使用加速器自动混合精度
        with self.accelerator.autocast():
            # 调用模型生成输出
            generator_output = model(**G_kwargs, **maybe_text_kwargs)

        # 如果不需要训练上采样器,则直接返回生成器输出
        if not self.train_upsampler:
            return generator_output

        # 获取生成器输出的大小
        output_size = generator_output.shape[-1]
        # 获取低分辨率图像
        lowres_image = G_kwargs['lowres_image']
        # 将低分辨率图像插值到与生成器输出相同的大小
        lowres_image = F.interpolate(lowres_image, (output_size, output_size))

        # 返回拼接后的图像
        return torch.cat([lowres_image, generator_output])

    # 进入推断模式的装饰器
    @torch.inference_mode()
    # 定义一个保存样本的方法,接受批量大小和数据迭代器作为参数
    def save_sample(
        self,
        batch_size,
        dl_iter = None
    ):
        # 计算当前里程碑
        milestone = self.steps.item() // self.save_and_sample_every
        # 如果训练上采样器,则设置 nrow_mult 为 2,否则为 1
        nrow_mult = 2 if self.train_upsampler else 1
        # 将样本数量分组成批次
        batches = num_to_groups(self.num_samples, batch_size)

        # 如果训练上采样器,则使用默认的上采样器数据迭代器
        if self.train_upsampler:
            dl_iter = default(self.sample_upsampler_dl_iter, dl_iter)

        # 断言数据迭代器存在
        assert exists(dl_iter)

        # 定义保存模型和输出文件名的列表
        sample_models_and_output_file_name = [(self.unwrapped_G, f'sample-{milestone}.png')]

        # 如果有 EMA 生成器,则添加到列表中
        if self.has_ema_generator:
            sample_models_and_output_file_name.append((self.G_ema, f'ema-sample-{milestone}.png'))

        # 遍历模型和文件名列表
        for model, filename in sample_models_and_output_file_name:
            # 将模型设置为评估模式
            model.eval()

            # 获取所有图像列表
            all_images_list = list(map(lambda n: self.sample(model, dl_iter, n), batches))
            # 拼接所有图像
            all_images = torch.cat(all_images_list, dim = 0)

            # 将图像像素值限制在 0 到 1 之间
            all_images.clamp_(0., 1.)

            # 保存图像
            utils.save_image(
                all_images,
                str(self.results_folder / filename),
                nrow = int(sqrt(self.num_samples)) * nrow_mult
            )

        # 可能的操作:包括一些指标以保存改进的内容,包括一些采样器字典文本条目
        # 保存模型
        self.save(str(self.model_folder / f'model-{milestone}.ckpt'))

    # 使用 beartype 装饰器定义前向传播方法,接受步数和梯度累积频率作为参数
    @beartype
    def forward(
        self,
        *,
        steps,
        grad_accum_every = 1
        ):
        # 断言训练数据加载器已设置,否则提示需要通过运行.set_dataloader(dl: Dataloader)来设置数据加载器
        assert exists(self.train_dl), 'you need to set the dataloader by running .set_dataloader(dl: Dataloader)'

        # 获取训练数据加载器的批量大小
        batch_size = self.train_dl_batch_size
        # 创建数据加载器的迭代器
        dl_iter = cycle(self.train_dl)

        # 初始化上一次的梯度惩罚损失、多尺度判别器损失和多尺度生成器损失
        last_gp_loss = 0.
        last_multiscale_d_loss = 0.
        last_multiscale_g_loss = 0.

        # 循环执行训练步骤
        for _ in tqdm(range(steps), initial = self.steps.item()):
            # 获取当前步骤数
            steps = self.steps.item()
            # 判断是否为第一步
            is_first_step = steps == 1

            # 判断是否需要应用梯度惩罚
            apply_gradient_penalty = self.apply_gradient_penalty_every > 0 and divisible_by(steps, self.apply_gradient_penalty_every)
            # 判断是否需要计算多尺度损失
            calc_multiscale_loss =  self.calc_multiscale_loss_every > 0 and divisible_by(steps, self.calc_multiscale_loss_every)

            # 调用训练判别器步骤函数,获取各种损失值
            (
                d_loss,
                multiscale_d_loss,
                vision_aided_d_loss,
                matching_aware_loss,
                gp_loss,
                recon_loss
            ) = self.train_discriminator_step(
                dl_iter = dl_iter,
                grad_accum_every = grad_accum_every,
                apply_gradient_penalty = apply_gradient_penalty,
                calc_multiscale_loss = calc_multiscale_loss
            )

            # 等待所有进程完成
            self.accelerator.wait_for_everyone()

            # 调用训练生成器步骤函数,获取各种损失值
            (
                g_loss,
                multiscale_g_loss,
                vision_aided_g_loss,
                contrastive_loss
            ) = self.train_generator_step(
                dl_iter = dl_iter,
                batch_size = batch_size,
                grad_accum_every = grad_accum_every,
                calc_multiscale_loss = calc_multiscale_loss
            )

            # 如果梯度惩罚损失存在,则更新上一次的梯度惩罚损失
            if exists(gp_loss):
                last_gp_loss = gp_loss

            # 如果多尺度判别器损失存在,则更新上一次的多尺度判别器损失
            if exists(multiscale_d_loss):
                last_multiscale_d_loss = multiscale_d_loss

            # 如果多尺度生成器损失存在,则更新上一次的多尺度生成器损失
            if exists(multiscale_g_loss):
                last_multiscale_g_loss = multiscale_g_loss

            # 如果是第一步或者步骤数能被log_steps_every整除,则输出损失信息
            if is_first_step or divisible_by(steps, self.log_steps_every):

                # 构建损失信息元组
                losses = (
                    ('G', g_loss),
                    ('MSG', last_multiscale_g_loss),
                    ('VG', vision_aided_g_loss),
                    ('D', d_loss),
                    ('MSD', last_multiscale_d_loss),
                    ('VD', vision_aided_d_loss),
                    ('GP', last_gp_loss),
                    ('SSL', recon_loss),
                    ('CL', contrastive_loss),
                    ('MAL', matching_aware_loss)
                )

                # 将损失信息转换为字符串格式
                losses_str = ' | '.join([f'{loss_name}: {loss:.2f}' for loss_name, loss in losses])

                # 打印损失信息
                self.print(losses_str)

            # 等待所有进程完成
            self.accelerator.wait_for_everyone()

            # 如果是主进程且是第一步或者步骤数能被save_and_sample_every整除或者步骤数小于early_save_thres_steps且能被early_save_and_sample_every整除,则保存样本
            if self.is_main and (is_first_step or divisible_by(steps, self.save_and_sample_every) or (steps <= self.early_save_thres_steps and divisible_by(steps, self.early_save_and_sample_every))):
                self.save_sample(batch_size, dl_iter)
            
            # 更新步骤数
            self.steps += 1

        # 打印完成训练步骤数
        self.print(f'complete {steps} training steps')

.\lucidrains\gigagan-pytorch\gigagan_pytorch\open_clip.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
import open_clip

from einops import rearrange

from beartype import beartype
from beartype.typing import List, Optional

# 检查值是否存在
def exists(val):
    return val is not None

# 对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# OpenClipAdapter 类,继承自 nn.Module
class OpenClipAdapter(nn.Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32',
        tokenizer_name = 'ViT-B-32-quickgelu',
        eos_id = 49407
    ):
        super().__init__()

        # 创建 OpenCLIP 模型和预处理
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
        tokenizer = open_clip.get_tokenizer(tokenizer_name)

        self.clip = clip
        self.tokenizer = tokenizer
        self.eos_id = eos_id

        # 获取文本表示的钩子
        text_attention_final = self.find_layer('ln_final')
        self._dim_latent = text_attention_final.weight.shape[0]
        self.text_handle = text_attention_final.register_forward_hook(self._text_hook)

        # 获取图像表示的钩子
        self._dim_image_latent = self.find_layer('visual.ln_post').weight.shape[0]

        num_visual_layers = len(clip.visual.transformer.resblocks)
        self.image_handles = []

        for visual_layer in range(num_visual_layers):
            image_attention_final = self.find_layer(f'visual.transformer.resblocks.{visual_layer}')

            handle = image_attention_final.register_forward_hook(self._image_hook)
            self.image_handles.append(handle)

        # 归一化函数
        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    # 获取设备信息
    @property
    def device(self):
        return next(self.parameters()).device

    # 查找指定层
    def find_layer(self,  layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    # 清除钩子
    def clear(self):
        if self.cleared:
            return

        self.text_handle()
        self.image_handle()

    # 文本钩子函数
    def _text_hook(self, _, inputs, outputs):
        self.text_encodings = outputs

    # 图像钩子函数
    def _image_hook(self, _, inputs, outputs):
        if not hasattr(self, 'image_encodings'):
            self.image_encodings = []

        self.image_encodings.append(outputs)

    # 获取潜在维度
    @property
    def dim_latent(self):
        return self._dim_latent

    # 获取图像尺寸
    @property
    def image_size(self):
        image_size = self.clip.visual.image_size
        if isinstance(image_size, tuple):
            return max(image_size)
        return image_size

    # 获取图像通道数
    @property
    def image_channels(self):
        return 3

    # 获取最大文本长度
    @property
    def max_text_len(self):
        return self.clip.positional_embedding.shape[0]

    # 嵌入文本
    @beartype
    def embed_texts(
        self,
        texts: List[str]
    ):
        ids = self.tokenizer(texts)
        ids = ids.to(self.device)
        ids = ids[..., :self.max_text_len]

        is_eos_id = (ids == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (ids != 0)
        assert not self.cleared

        text_embed = self.clip.encode_text(ids)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return l2norm(text_embed.float()), text_encodings.float()

    # 嵌入图像
    def embed_images(self, images):
        if images.shape[-1] != self.image_size:
            images = F.interpolate(images, self.image_size)

        assert not self.cleared
        images = self.clip_normalize(images)
        image_embeds = self.clip.encode_image(images)

        image_encodings = rearrange(self.image_encodings, 'l n b d -> l b n d')
        del self.image_encodings

        return l2norm(image_embeds.float()), image_encodings.float()

    @beartype
    # 对比损失函数,用于计算文本和图像之间的相似性损失
    def contrastive_loss(
        self,
        images,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None
    ):
        # 断言文本或文本嵌入至少存在一个
        assert exists(texts) ^ exists(text_embeds)

        # 如果文本嵌入不存在,则通过文本获取文本嵌入
        if not exists(text_embeds):
            text_embeds, _ = self.embed_texts(texts)

        # 通过图像获取图像嵌入
        image_embeds, _ = self.embed_images(images)

        # 获取文本嵌入的数量
        n = text_embeds.shape[0]

        # 获取温度参数
        temperature = self.clip.logit_scale.exp()
        # 计算文本嵌入和图像嵌入之间的相似性
        sim = einsum('i d, j d -> i j', text_embeds, image_embeds) * temperature

        # 创建标签,用于计算交叉熵损失
        labels = torch.arange(n, device = sim.device)

        # 返回文本和图像之间的相似性损失
        return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

.\lucidrains\gigagan-pytorch\gigagan_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# 根据参数设置获取优化器
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = True,
    group_wd_params = True,
    **kwargs
):
    # 根据参数是否需要梯度来过滤参数列表
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 如果需要对参数进行分组并应用权重衰减
    if group_wd_params and wd > 0:
        # 将参数分为需要权重衰减和不需要权重衰减的两个列表
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # 根据分组情况设置参数列表
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # 如果不需要权重衰减,则使用 Adam 优化器
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # 如果需要权重衰减,则使用 AdamW 优化器
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

.\lucidrains\gigagan-pytorch\gigagan_pytorch\unet_upsampler.py

# 从 math 模块中导入 log2 函数
from math import log2
# 从 functools 模块中导入 partial 函数
from functools import partial

# 导入 torch 库
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数,以及 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 从 gigagan_pytorch 模块中导入各个自定义类和函数
from gigagan_pytorch.attend import Attend
from gigagan_pytorch.gigagan_pytorch import (
    BaseGenerator,
    StyleNetwork,
    AdaptiveConv2DMod,
    TextEncoder,
    CrossAttentionBlock,
    Upsample
)

# 从 beartype 库中导入 beartype 函数和相关类型注解
from beartype import beartype
from beartype.typing import Optional, List, Union, Dict, Iterable

# 辅助函数

# 判断变量是否存在
def exists(x):
    return x is not None

# 返回默认值函数
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 将输入转换为元组
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# 返回输入本身的函数
def identity(t, *args, **kwargs):
    return t

# 判断一个数是否为2的幂
def is_power_of_two(n):
    return log2(n).is_integer()

# 生成无限循环的迭代器
def null_iterator():
    while True:
        yield None

# 小型辅助模块

# 像素混洗上采样类
class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)

        # 创建卷积层对象
        conv = nn.Conv2d(dim, dim_out * 4, 1)
        self.init_conv_(conv)

        # 定义网络结构
        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

    # 初始化卷积层权重
    def init_conv_(self, conv):
        o, *rest_shape = conv.weight.shape
        conv_weight = torch.empty(o // 4, *rest_shape)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 下采样函数
def Downsample(dim, dim_out = None):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

# RMS 归一化类
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    # 前向传播函数
    def forward(self, x):
        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

# 构建块模块

# 基础块类
class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        num_conv_kernels = 0
    ):
        super().__init__()
        self.proj = AdaptiveConv2DMod(dim, dim_out, kernel = 3, num_conv_kernels = num_conv_kernels)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    # 前向传播函数
    def forward(
        self,
        x,
        conv_mods_iter: Optional[Iterable] = None
    ):
        conv_mods_iter = default(conv_mods_iter, null_iterator())

        x = self.proj(
            x,
            mod = next(conv_mods_iter),
            kernel_mod = next(conv_mods_iter)
        )

        x = self.norm(x)
        x = self.act(x)
        return x

# ResNet 块类
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        groups = 8,
        num_conv_kernels = 0,
        style_dims: List = []
    ):
        super().__init__()
        style_dims.extend([
            dim,
            num_conv_kernels,
            dim_out,
            num_conv_kernels
        ])

        self.block1 = Block(dim, dim_out, groups = groups, num_conv_kernels = num_conv_kernels)
        self.block2 = Block(dim_out, dim_out, groups = groups, num_conv_kernels = num_conv_kernels)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # 前向传播函数
    def forward(
        self,
        x,
        conv_mods_iter: Optional[Iterable] = None
    ):
        h = self.block1(x, conv_mods_iter = conv_mods_iter)
        h = self.block2(h, conv_mods_iter = conv_mods_iter)

        return h + self.res_conv(x)

# 线性注意力类
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32
    # 初始化函数,设置缩放因子和头数
    def __init__(
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # 初始化 RMSNorm 层
        self.norm = RMSNorm(dim)
        # 创建卷积层,用于计算查询、键、值
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        # 创建输出层,包含卷积层和 RMSNorm 层
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    # 前向传播函数
    def forward(self, x):
        b, c, h, w = x.shape

        # 对输入进行归一化处理
        x = self.norm(x)

        # 将输入通过卷积层得到查询、键、值
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # 对查询和键进行 softmax 处理
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        # 对查询进行缩放
        q = q * self.scale

        # 计算上下文信息
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        # 计算输出
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        flash = False
    ):
        # 初始化注意力机制模块
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        # 归一化层
        self.norm = RMSNorm(dim)
        # 注意力计算
        self.attend = Attend(flash = flash)

        # 将输入转换为查询、键、值
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        # 输出转换
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape

        # 归一化输入
        x = self.norm(x)

        # 将输入转换为查询、键、值
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)

        # 注意力计算
        out = self.attend(q, k, v)

        # 重排输出形状
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# feedforward

def FeedForward(dim, mult = 4):
    # 前馈神经网络
    return nn.Sequential(
        RMSNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Conv2d(dim * mult, dim, 1)
    )

# transformers

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 1,
        flash_attn = True,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # 构建多层Transformer
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

class LinearTransformer(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 1,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # 构建多层LinearTransformer
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

# model

class UnetUpsampler(BaseGenerator):

    @beartype
    def __init__(
        self,
        dim,
        *,
        image_size,
        input_image_size,
        init_dim = None,
        out_dim = None,
        text_encoder: Optional[Union[TextEncoder, Dict]] = None,
        style_network: Optional[Union[StyleNetwork, Dict]] = None,
        style_network_dim = None,
        dim_mults = (1, 2, 4, 8, 16),
        channels = 3,
        resnet_block_groups = 8,
        full_attn = (False, False, False, True, True),
        cross_attn = (False, False, False, True, True),
        flash_attn = True,
        self_attn_dim_head = 64,
        self_attn_heads = 8,
        self_attn_dot_product = True,
        self_attn_ff_mult = 4,
        attn_depths = (1, 1, 1, 1, 1),
        cross_attn_dim_head = 64,
        cross_attn_heads = 8,
        cross_ff_mult = 4,
        mid_attn_depth = 1,
        num_conv_kernels = 2,
        resize_mode = 'bilinear',
        unconditional = True,
        skip_connect_scale = None
    ):
        # 初始化UnetUpsampler模型
        super().__init__()

    @property
    def allowable_rgb_resolutions(self):
        # 计算允许的RGB分辨率
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2 ** p, allowed_rgb_res_base)]

    @property
    def device(self):
        # 获取模型所在设备
        return next(self.parameters()).device

    @property
    def total_params(self):
        # 计算模型总参数数量
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        # 调整输入图像大小
        return F.interpolate(x, (size, size), mode = self.resize_mode)
    # 定义一个前向传播函数,接受低分辨率图像、风格、噪声、文本等参数,并返回RGB图像
    def forward(
        self,
        lowres_image,
        styles = None,
        noise = None,
        texts: Optional[List[str]] = None,
        global_text_tokens = None,
        fine_text_tokens = None,
        text_mask = None,
        return_all_rgbs = False,
        replace_rgb_with_input_lowres_image = True   # discriminator should also receive the low resolution image the upsampler sees
    ):
        # 将输入的低分辨率图像赋值给x
        x = lowres_image
        # 获取x的形状
        shape = x.shape
        # 获取批处理大小
        batch_size = shape[0]

        # 断言x的最后两个维度与输入图像大小相同
        assert shape[-2:] == ((self.input_image_size,) * 2)

        # 处理文本编码
        # 需要全局文本标记自适应选择主要贡献中的内核
        # 需要细节文本标记进行交叉注意力
        if not self.unconditional:
            if exists(texts):
                assert exists(self.text_encoder)
                global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(texts)
            else:
                assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))])
        else:
            assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])

        # 风格
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(noise, torch.randn((batch_size, self.style_network.dim), device = self.device))
            styles = self.style_network(noise, global_text_tokens)

        # 将风格投影到卷积调制
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
        conv_mods = iter(conv_mods)

        # 初始卷积
        x = self.init_conv(x)

        h = []

        # 下采样阶段
        for block1, block2, cross_attn, attn, downsample in self.downs:
            x = block1(x, conv_mods_iter = conv_mods)
            h.append(x)

            x = block2(x, conv_mods_iter = conv_mods)

            x = attn(x)

            if exists(cross_attn):
                x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, conv_mods_iter = conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter = conv_mods)

        # rgbs
        rgbs = []

        init_rgb_shape = list(x.shape)
        init_rgb_shape[1] = self.channels

        rgb = self.mid_to_rgb(x)
        rgbs.append(rgb)

        # 上采样阶段
        for upsample, upsample_rgb, to_rgb, block1, block2, cross_attn, attn in self.ups:

            x = upsample(x)
            rgb = upsample_rgb(rgb)

            res1 = h.pop() * self.skip_connect_scale
            res2 = h.pop() * self.skip_connect_scale

            fmap_size = x.shape[-1]
            residual_fmap_size = res1.shape[-1]

            if residual_fmap_size != fmap_size:
                res1 = self.resize_image_to(res1, fmap_size)
                res2 = self.resize_image_to(res2, fmap_size)

            x = torch.cat((x, res1), dim = 1)
            x = block1(x, conv_mods_iter = conv_mods)

            x = torch.cat((x, res2), dim = 1)
            x = block2(x, conv_mods_iter = conv_mods)

            if exists(cross_attn):
                x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

            x = attn(x)

            rgb = rgb + to_rgb(x)
            rgbs.append(rgb)

        x = self.final_res_block(x, conv_mods_iter = conv_mods)

        assert len([*conv_mods]) == 0

        rgb = rgb + self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        # 仅保留那些特征图大于要上采样的输入图像的rgbs
        rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs))

        # 并将原始输入图像作为最小的rgb返回
        rgbs = [lowres_image, *rgbs]

        return rgb, rgbs
posted @ 2024-06-28 14:04  绝不原创的飞龙  阅读(27)  评论(0)    收藏  举报