Lucidrains-系列项目源码解析-四十-

Lucidrains 系列项目源码解析(四十)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\sinkhorn-transformer\examples\enwik8_simple\train.py

# 导入所需的库和模块
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096

# 定义辅助函数

# 从 token 解码为字符
def decode_token(token):
    return str(chr(max(32, token)))

# 从 tokens 解码为字符串
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 实例化模型

model = SinkhornTransformerLM(
    num_tokens = 256,
    emb_dim = 128,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    bucket_size = 128,
    ff_chunks = 2,
    causal = True,
    reversible = True,
    attn_dropout = 0.1,
    n_local_attn_heads = 4
)

model = AutoregressiveWrapper(model)
model.cuda()

# 准备 enwik8 数据

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# 定义数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\sinkhorn-transformer\examples\increment_by_one\train.py

# 导入 torch 库
import torch
# 从 sinkhorn_transformer 库中导入 SinkhornTransformerLM 类
from sinkhorn_transformer.sinkhorn_transformer import SinkhornTransformerLM
# 从 sinkhorn_transformer 库中导入 AutoregressiveWrapper 类
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

# 定义批量大小
N_BATCH = 16
# 定义源序列长度
SRC_SEQ_LEN = 512
# 定义目标序列长度
TGT_SEQ_LEN = 512

# 创建 SinkhornTransformerLM 编码器对象
enc = SinkhornTransformerLM(
    num_tokens = 64,
    dim = 512,
    depth = 1,
    heads = 8,
    max_seq_len = SRC_SEQ_LEN,
    bucket_size = 64,
    return_embeddings = True
).cuda()

# 创建 SinkhornTransformerLM 解码器对象
dec = SinkhornTransformerLM(
    num_tokens = 64,
    dim = 512,
    depth = 2,
    heads = 8,
    max_seq_len = TGT_SEQ_LEN,
    bucket_size = 64,
    causal = True,
    receives_context = True
).cuda()

# 将解码器包装在 AutoregressiveWrapper 中,设置忽略索引和填充值
dec = AutoregressiveWrapper(dec, ignore_index = 0, pad_value = 0)
# 使用 Adam 优化器,传入编码器和解码器的参数,设置学习率
opt = torch.optim.Adam([*enc.parameters(), *dec.parameters()], lr=2e-4)

# 定义起始符、结束符和位置符
bos = 1 * torch.ones(N_BATCH, 1).long()
eos = 2 * torch.ones(N_BATCH, 1).long()
pos = 3 * torch.ones(N_BATCH, 1).long()

# 循环训练
for i in range(10000):
    # 生成随机训练序列
    train_seq_in = torch.randint(4, 63, (N_BATCH, SRC_SEQ_LEN-2)).long()
    # 目标序列为输入序列加一
    train_seq_out = train_seq_in + 1

    # 在序列开头和结尾添加起始符和结束符,并转移到 GPU
    x = torch.cat([bos, train_seq_in, eos], dim=1).cuda()
    y = torch.cat([bos, train_seq_out, eos], dim=1).cuda()

    # 编码输入序列,得到上下文信息
    context = enc(x)
    # 计算解码器的损失
    loss = dec(y, context = context, return_loss = True)
    # 反向传播计算梯度
    loss.backward()

    # 更新优化器参数
    opt.step()
    # 梯度清零
    opt.zero_grad()
    # 打印当前迭代次数和损失值
    print(i, loss.item())

Sinkhorn Transformer

PyPI version


This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.

This work also brings in reversible networks and feed forward chunking (concepts introduced from Reformer) to bring about further memory savings.

Open In Colab 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default the model will complain if given an input that is not a multiple of the bucket size. To avoid having to make the same padding calculations each time, you can use the helper Autopadder class. It will take care of the input_mask for you as well, if given. Contextual key/values and mask are supported as well.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)

Sinkhorn

This repository has diverged from the paper and is now using attention in place of the original sorting net + gumbel sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme allows me to generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only works for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than the rest of the parameters. (Recommended to be 1e-2)

You can follow the instructions here to set it correctly https://github.com/lucidrains/product-key-memory#learning-rates

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed length sequences, seems to have trouble decoding sequences from scratch, mainly due to the fact that the sorting net has trouble generalizing when the buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution. During training, for causal networks, randomly truncate the sequences and force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) for the AutoregressiveWrapper instance to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # (1, 8192, 20000)

I am open to suggestions if someone has found a better solution.

Causal sorting net

There is a potential problem with the causal sorting network, where the decision of which key/value buckets of the past sorts to a bucket is dependent only on the first token and not the rest (due to the bucketing scheme and preventing leakage of future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by bucket size - 1, thereby promoting the last token to be first. This is also the reason why the AutoregressiveWrapper defaults to left padding during training, to always make sure that the last token in the sequence have a say in what to retrieve.

If anyone has found a cleaner solution, please let me know in the issues.

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url.    = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title     ={Reducing Transformer Depth on Demand with Structured Dropout},
    author    ={Angela Fan and Edouard Grave and Armand Joulin},
    booktitle ={International Conference on Learning Representations},
    year      ={2020},
    url       ={https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

.\lucidrains\sinkhorn-transformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'sinkhorn_transformer',
  # 查找并排除示例文件夹以外的所有包
  packages = find_packages(exclude=['examples']),
  # 版本号
  version = '0.11.4',
  # 许可证类型
  license='MIT',
  # 描述信息
  description = 'Sinkhorn Transformer - Sparse Sinkhorn Attention',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/sinkhorn-transformer',
  # 关键词
  keywords = ['transformers', 'attention', 'artificial intelligence'],
  # 安装依赖项
  install_requires=[
      'axial-positional-embedding>=0.1.0',
      'local-attention',
      'product-key-memory',  
      'torch'
  ],
  # 分类标签
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\sinkhorn-transformer\sinkhorn_transformer\autopadder.py

# 导入数学库
import math
# 导入 PyTorch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F
# 从 sinkhorn_transformer.sinkhorn_transformer 中导入 SinkhornTransformer 和 SinkhornTransformerLM 类
from sinkhorn_transformer.sinkhorn_transformer import SinkhornTransformer, SinkhornTransformerLM

# 定义一个函数,用于查找指定类型的模块
def find_module(nn_module, type):
    # 遍历 nn_module 中的所有模块
    for module in nn_module.modules():
        # 如果模块是指定类型的实例,则返回该模块
        if isinstance(module, type):
            return module
    # 如果没有找到指定类型的模块,则返回 None
    return None

# 定义一个函数,用于将张量填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim=-1, pad_left = False):
    # 获取张量在指定维度上的长度
    seqlen = tensor.shape[dim]
    # 计算需要填充的长度
    m = seqlen / multiple
    # 如果 m 是整数,则不需要填充
    if m.is_integer():
        return tensor, 0

    # 计算填充前的偏移量
    pre_pad_offset = (0,) * (-1 - dim) * 2
    # 计算需要填充的长度
    padding = math.ceil(m) * multiple - seqlen
    # 根据填充方式进行填充
    offset = (padding, 0) if pad_left else (0, padding)
    # 对张量进行填充操作
    padded_tensor = F.pad(tensor, (*pre_pad_offset, *offset), value=0)
    return padded_tensor, padding

# 定义一个自动填充器类
class Autopadder(nn.Module):
    def __init__(self, net, pad_left=False):
        super().__init__()
        # 断言 net 是 SinkhornTransformer 或 SinkhornTransformerLM 类的实例
        assert isinstance(net, (SinkhornTransformer, SinkhornTransformerLM)), 'only modules SinkhornTransformer and SinkhornTransformerLM accepted'
        self.net = net

        # 判断 net 是否为 SinkhornTransformerLM 类的实例
        is_lm = isinstance(net, SinkhornTransformerLM)
        # 查找 net 中的 SinkhornTransformer 模块
        sinkhorn = find_module(net, SinkhornTransformer)
        # 获取填充到桶大小的值
        self.bucket_size = sinkhorn.pad_to_bucket_size
        # 获取上下文桶大小的值
        self.context_bucket_size = sinkhorn.context_bucket_size

        # 根据 net 的类型确定填充的维度
        self.pad_dim = -1 if is_lm else -2
        # 设置填充的方式
        self.pad_left = pad_left

    # 定义前向传播函数
    def forward(self, x, **kwargs):
        # 获取输入张��的 batch 大小和时间步长
        b, t, device = *x.shape[:2], x.device

        # 获取关键字参数中的上下文和输入掩码
        context = kwargs.get('context')
        input_mask = kwargs.get('input_mask')
        context_mask = kwargs.get('context_mask')

        # 如果输入掩码为空,则创建一个全为 True 的掩码张量
        if input_mask is None:
            input_mask = torch.full(x.shape[:2], True, device=x.device, dtype=torch.bool)

        # 如果存在上下文且上下文掩码为空,则创建一个全为 True 的上下文掩码张量
        if context is not None and context_mask is None:
            context_mask = torch.full(context.shape[0:2], True, device=x.device, dtype=torch.bool)

        # 对输入张量进行填充操作
        x, padding = pad_to_multiple(x, self.bucket_size, dim=self.pad_dim, pad_left=self.pad_left)

        # 如果有填充操作,则更新输入掩码
        if padding != 0:
            offset = (0, padding) if not self.pad_left else (padding, 0)
            new_mask = F.pad(input_mask, offset, value=False)
            kwargs.update(input_mask=new_mask)

        # 如果存在上下文,则对上下文进行填充操作
        if context is not None:
            context, context_padding = pad_to_multiple(context, self.context_bucket_size, dim=-2)

            # 如果有填充操作,则更新上下文掩码
            if context_padding != 0:
                new_mask = F.pad(context_mask, (0, context_padding), value=False)
                kwargs.update(context_mask=new_mask)

            # 更新关键字参数中的上下文
            kwargs.update(context=context)

        # 调用 net 的前向传播函数
        out = self.net(x, **kwargs)

        # 根据填充方式获取输出切片
        output_slice = slice(0, t) if not self.pad_left else slice(padding, None)
        return out[:, output_slice]

.\lucidrains\sinkhorn-transformer\sinkhorn_transformer\autoregressive_wrapper.py

# 导入必要的库
from functools import partial
import torch
from random import randint
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from sinkhorn_transformer.sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autopadder import Autopadder

# 定义一个函数,如果值为None,则返回默认值
def default(value, default):
    return value if value is not None else default

# 从logits中选择概率最高的元素,保留概率总和大于阈值的元素
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# 从logits中选择概率最高的k个元素,其余设置为负无穷
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 将序列左侧填充到相同长度
def pad_sequence_left(seqs, value):
    m = max([len(s) for s in seqs])
    return torch.stack([F.pad(s, (m - len(s), 0)) for s in seqs])

# 随机截断输入序列
def random_truncate_inputs(inputs, mask = None, pad_value=0):
    b, t, device, dtype = *inputs.shape, inputs.device, inputs.dtype
    mask = default(mask, torch.ones_like(inputs))
    rand_lengths = torch.randint(2, t, (b, 1))
    rand_mask = (torch.arange(t) < rand_lengths).to(device)
    target_seqs = [t.masked_select(mask) for mask, t in zip(rand_mask, inputs)]
    mask_seqs = [m.masked_select(mask) for mask, m in zip(rand_mask, rand_mask)]
    return pad_sequence_left(target_seqs, pad_value), pad_sequence_left(mask_seqs, False)

# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = None, pad_value = 0, pad_left = True):
        super().__init__()
        assert isinstance(net, SinkhornTransformerLM), 'generative trainer wrapper can only accept SinkhornTransformerLM class'
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = Autopadder(net, pad_left = pad_left)
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('input_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]
            logits = self.net(x, input_mask=input_mask, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)
            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out
    # 定义前向传播函数,接受输入 x,是否返回损失值,是否随机截断序列等参数
    def forward(self, x, return_loss = False, randomly_truncate_sequence = False, **kwargs):
        # 定义一个填充函数,用于在序列维度上进行填充,保证输入数据维度一致
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        # 如果不需要返回损失值
        if not return_loss:
            # 如果输入不是张量,则进行填充
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            # 返回网络处理后的结果
            return self.net(x, **kwargs)

        # 从参数中弹出输入掩码
        m = kwargs.pop('input_mask', None)

        # 如果需要随机截断序列
        if randomly_truncate_sequence:
            # 对输入进行随机截断
            x, m = random_truncate_inputs(x, m, pad_value = self.pad_value)

        # 如果输入是张量
        if isinstance(x, torch.Tensor):
            # 分别获取输入序列和目标序列
            xi, xo = x[:, :-1], x[:, 1:]
        else:
            # 对输入进行填充和截断,获取输入序列和目标序列
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        # 如果存在输入掩码
        if m is not None:
            # 确保输入掩码形状与输入序列的形状一致
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            # 更新参数中的输入掩码
            kwargs.update(input_mask = m[:, :-1])

        # 将输入序列传入网络得到输出
        out = self.net(xi, **kwargs)

        # 计算交叉熵损失
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        # 返回损失值
        return loss

.\lucidrains\sinkhorn-transformer\sinkhorn_transformer\reversible.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数

# 用于将参数路由到可逆层函数中的函数
def route_args(router, args, depth):
    # 初始化路由后的参数列表
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # 获取参数中与路由器匹配的键
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

# 根据概率丢弃层的函数
def layer_drop(layers, prob):
    to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob
    blocks = [block for block, drop in zip(layers, to_drop) if not drop]
    blocks = layers[:1] if len(blocks) == 0 else blocks
    return blocks

# 保存和设置随机数种子的类
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# 可逆块类,受启发于 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# 一旦多 GPU 工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

# 可逆函数类
class _ReversibleFunction(Function):
    @staticmethod
    # 前向传播函数,接收上下文对象 ctx,输入数据 x,模块列表 blocks 和参数列表 args
    def forward(ctx, x, blocks, args):
        # 将参数列表 args 存储到上下文对象 ctx 中
        ctx.args = args
        # 遍历模块列表 blocks 和参数列表 args,对输入数据 x 进行处理
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # 将处理后的数据 x 分离出来,并存储到上下文对象 ctx 中
        ctx.y = x.detach()
        # 将模块列表 blocks 存储到上下文对象 ctx 中
        ctx.blocks = blocks
        # 返回处理后的数据 x
        return x

    # 反向传播函数,接收上下文对象 ctx 和梯度 dy
    @staticmethod
    def backward(ctx, dy):
        # 获取上下文对象 ctx 中存储的处理后的数据 y 和参数列表 args
        y = ctx.y
        args = ctx.args
        # 反向遍历模块列表 blocks 和参数列表 args,对梯度 dy 进行处理
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # 调用模块的反向传播函数,更新梯度 dy 和数据 y
            y, dy = block.backward_pass(y, dy, **kwargs)
        # 返回更新后的梯度 dy
        return dy, None, None
class SequentialSequence(nn.Module):
    # 定义一个顺序执行的神经网络模块
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        # 断言每个参数路由映射的深度与顺序层的数量相同
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        # 根据参数路由和关键字参数获取参数
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if self.training and self.layer_dropout > 0:
            # 如果处于训练状态且存在层丢弃率,则执行层丢弃
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        for (f, g), (f_args, g_args) in layers_and_args:
            # 依次执行每个顺序层的前向传播
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
    # 定义一个可逆的序列神经网络模块
    def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
        super().__init__()
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        # 创建包含可逆块的模块列表
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # 在最后一个维度上连接输入张量的副本
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        # 根据参数路由和关键字参数获取参数
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        layers_and_args = list(zip(blocks, args))

        if self.training and self.layer_dropout > 0:
            # 如果处于训练状态且存在层丢弃率,则执行层丢弃
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
            blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1))

        # 调用自定义的可逆函数进行前向传播
        out =  _ReversibleFunction.apply(x, blocks, args)
        # 在最后一个维度上分割输出并求和
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)

.\lucidrains\sinkhorn-transformer\sinkhorn_transformer\sinkhorn_transformer.py

# 导入 math 模块
import math
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 子模块
from torch import nn
# 从 operator 模块中导入 mul 函数
from operator import mul
# 从 math 模块中导入 gcd 函数
from math import gcd
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 inspect 模块中导入 isfunction 函数
from inspect import isfunction
# 从 functools 模块中导入 partial, wraps, reduce 函数
from functools import partial, wraps, reduce

# 导入自定义模块
from local_attention import LocalAttention
from axial_positional_embedding import AxialPositionalEmbedding
from product_key_memory import PKM
from sinkhorn_transformer.reversible import ReversibleSequence, SequentialSequence

# 辅助函数

# 定义返回输入的函数
def identity(x, *args, **kwargs): return x

# 定义返回默认值的函数
def default(x, d):
    if x is None:
        return d if not isfunction(d) else d()
    return x

# 将输入转换为元组的函数
def cast_tuple(x):
    return x if isinstance(x, tuple) else (x,)

# 判断一个数是否能被另一个数整除的函数
def divisible_by(num, divisor):
    return num % divisor == 0

# 计算多个数的最小公倍数的函数
def lcm(*numbers):
    return int(reduce(lambda x, y: int((x * y) / gcd(x, y)), numbers, 1)

# 判断多个元素是否都为 None 的函数
def all_none(*arr):
    return all(el is None for el in arr)

# 缓存函数的装饰器
def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# 将张量向左旋转的函数
def rotate_left(t, n, dim=0):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(n, None))
    r = (*pre_slices, slice(0, n))
    return torch.cat((t[l], t[r]), dim=dim)

# 将张量向右旋转的函数
def rotate_right(t, n, dim=0):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(-n, None))
    r = (*pre_slices, slice(None, -n))
    return torch.cat((t[l], t[r]), dim=dim)

# 合并张量的维度的函数
def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

# 合并张量的头部的函数
def merge_heads(h, v):
    b, t, d = v.shape
    return v.view(b, t, h, -1).transpose(1, 2).reshape(b, h, t, -1)

# 分割张量的头部的函数
def split_heads(h, v):
    *_, t, d = v.shape
    return v.view(-1, h, t, d).transpose(1, 2).reshape(-1, t, d * h)

# 在指定索引处分割张量的函数
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# 将张量分桶的函数
def bucket(buckets, t, dim=1):
    shape = list(t.shape)
    shape[dim:dim+1] = [buckets, -1]
    return t.reshape(*shape)

# 将分桶后的张量还原的函数
def unbucket(t, dim=1):
    shape = list(t.shape)
    shape[dim:dim+2] = [-1]
    return t.reshape(*shape)

# 采样 Gumbel 分布的函数
def sample_gumbel(shape, device, dtype, eps=1e-6):
    u = torch.empty(shape, device=device, dtype=dtype).uniform_(0, 1)
    return -log(-log(u, eps), eps)

# Sinkhorn 排序算子的函数
def sinkhorn_sorting_operator(r, n_iters=8):
    n = r.shape[1]
    for _ in range(n_iters):
        r = r - torch.logsumexp(r, dim=2, keepdim=True)
        r = r - torch.logsumexp(r, dim=1, keepdim=True)
    return torch.exp(r)

# Gumbel Sinkhorn 函数
def gumbel_sinkhorn(r, n_iters=8, temperature=0.7):
    r = log(r)
    gumbel = sample_gumbel(r.shape, r.device, r.dtype)
    r = (r + gumbel) / temperature
    return sinkhorn_sorting_operator(r, n_iters)

# 重新排序分桶后的张量的函数
def reorder_buckets(t, r):
    return torch.einsum('buv,bvtd->butd', r, t)

# 对张量取对数的函数
def log(t, eps = 1e-6):
    return torch.log(t + eps)

# 获取张量最大负值的函数
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# 沿指定维度计算累积平均值的函数
def cumavg(t, dim):
    r = torch.arange(1, t.shape[dim] + 1, device=t.device, dtype=t.dtype)
    expand_slice = [None] * len(t.shape)
    expand_slice[dim] = slice(None, None)
    return t.cumsum(dim=dim) / r[tuple(expand_slice)]

# 批量索引选择的函数
def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

# 在指定维度扩展张量的函数
def expand_dim(t, dim, k):
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

# 扩展批次并合并头部的函数
def expand_batch_and_merge_head(b, t):
    shape = list(t.squeeze(0).shape)
    t = expand_dim(t, 0, b)
    shape[0] = shape[0] * b
    return t.reshape(*shape)

# 可微分的 Top-K 函数
def differentiable_topk(x, k, temperature=1.):
    *_, n, dim = x.shape
    topk_tensors = []
    # 遍历 k 次,每次生成一个 topk tensor
    for i in range(k):
        # 判断是否是最后一次循环
        is_last = i == (k - 1)
        # 对输入 x 进行 softmax 操作,然后取最大的值和对应的索引
        values, indices = (x / temperature).softmax(dim=-1).topk(1, dim=-1)
        # 根据索引和值生成一个新的 tensor,并替换原来的值
        topks = torch.zeros_like(x).scatter_(-1, indices, values)
        # 将生成的 topk tensor 添加到列表中
        topk_tensors.append(topks)
        # 如果不是最后一次循环,则将对应索引的值设为负无穷
        if not is_last:
            x.scatter_(-1, indices, float('-inf'))

    # 将所有生成的 topk tensor 拼接在一起
    topks = torch.cat(topk_tensors, dim=-1)
    # 将拼接后的 tensor 重新 reshape 成指定形状
    return topks.reshape(*_, k * n, dim)
# 定义一个名为 Chunk 的类,继承自 nn.Module
class Chunk(nn.Module):
    # 初始化方法,接受参数 chunks、fn 和 along_dim,默认值为 -1
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        # 设置对象属性
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    # 前向传播方法
    def forward(self, x):
        # 将输入 x 按照指定维度分块
        chunks = x.chunk(self.chunks, dim = self.dim)
        # 对每个分块应用函数 fn,并在指定维度上拼接结果
        return torch.cat([self.fn(c) for c in chunks], dim = self.dim)

# 定义一个名为 GELU_ 的类,继承自 nn.Module
class GELU_(nn.Module):
    # 前向传播方法
    def forward(self, x):
        # 计算 GELU 激活函数的值
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))

# 如果 nn 模块中存在 GELU 类,则使用 nn.GELU,否则使用 GELU_
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# 定义一个名为 FeedForward 的类,继承自 nn.Module
class FeedForward(nn.Module):
    # 初始化方法,接受参数 dim、mult、dropout、activation 和 glu,默认值为 False
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        # 设置对象属性
        activation = default(activation, GELU)

        self.glu = glu
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    # 前向传播方法
    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.w1(x)
            x = self.act(x)
        else:
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x

# 定义一个名为 ReZero 的类,继承自 nn.Module
class ReZero(nn.Module):
    # 初始化方法,接受参数 fn
    def __init__(self, fn):
        super().__init__()
        # 定义可学习参数 g
        self.g = nn.Parameter(torch.zeros(1))
        self.fn = fn

    # 前向传播方法
    def forward(self, x, **kwargs):
        # 返回 fn 函数的结果乘以可学习参数 g
        return self.fn(x, **kwargs) * self.g

# 定义一个名为 PreNorm 的类,继承自 nn.Module
class PreNorm(nn.Module):
    # 初始化方法,接受参数 norm_class、dim 和 fn
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        # 实例化规范化层对象
        self.norm = norm_class(dim)
        self.fn = fn

    # 前向传播方法
    def forward(self, x, **kwargs):
        # 对输入 x 进行规范化
        x = self.norm(x)
        # 返回 fn 函数的结果
        return self.fn(x, **kwargs)

# 定义一个名为 ProjectInOut 的类,继承自 nn.Module
class ProjectInOut(nn.Module):
    # 初始化方法,接受参数 fn、dim_in、dim_out 和 project_out,默认为 True
    def __init__(self, fn, dim_in, dim_out, project_out = True):
        super().__init__()
        # 设置对象属性
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity

    # 前向传播方法
    def forward(self, x, **kwargs):
        # 对输入 x 进行投影
        x = self.project_in(x)
        # 对投影后的结果应用函数 fn
        x = self.fn(x, **kwargs)
        # 对结果进行逆投影
        x = self.project_out(x)
        return x

# 定义一个名为 SimpleSortNet 的类,继承自 nn.Module
class SimpleSortNet(nn.Module):
    # 初始化方法,接受参数 heads、bucket_size、max_buckets、dim、non_permutative、temperature 和 sinkhorn_iter
    def __init__(self, heads, bucket_size, max_buckets, dim, non_permutative, temperature, sinkhorn_iter):
        super().__init__()
        # 设置对象属性
        self.dim = dim
        self.heads = heads
        self.max_buckets = max_buckets
        self.bucket_size = bucket_size
        self.non_permutative = non_permutative
        self.temperature = temperature
        self.sinkhorn_iter = sinkhorn_iter
        self.linear = nn.Parameter(torch.randn(1, heads, dim, max_buckets))
        self.act = nn.ReLU()

    # 前向传播方法
    def forward(self, q, k, topk=1):
        bh, t, _ = q.shape
        b = bh // self.heads
        buckets = t // self.bucket_size

        b_q, b_k = bucket(buckets, q), bucket(buckets, k)
        x = torch.cat((b_q.sum(dim=2), b_k.sum(dim=2)), dim=-1)

        W = expand_batch_and_merge_head(b, self.linear)
        R = self.act(x @ W)

        return differentiable_topk(R, k=topk, temperature=self.temperature) if self.non_permutative else gumbel_sinkhorn(R, self.sinkhorn_iter, self.temperature)

# 定义一个名为 AttentionSortNet 的类,继承自 nn.Module
class AttentionSortNet(nn.Module):
    # 初始化方法,接受参数 heads、bucket_size、kv_bucket_size、dim、non_permutative、temperature、sinkhorn_iter 和 n_sortcut,默认为 0
    def __init__(self, heads, bucket_size, kv_bucket_size, dim, non_permutative, temperature, sinkhorn_iter, n_sortcut = 0):
        super().__init__()
        # 设置对象属性
        self.heads = heads
        self.bucket_size = bucket_size
        self.kv_bucket_size = kv_bucket_size
        self.dim = dim
        self.non_permutative = non_permutative
        self.temperature = temperature
        self.sinkhorn_iter = sinkhorn_iter
        self.n_sortcut = n_sortcut
    # 定义一个前向传播函数,接受查询向量 q、键向量 k 和 topk 参数,默认为 1
    def forward(self, q, k, topk=1):
        # 解构赋值,获取查询向量 q 的形状信息
        bh, *_, bucket_size, kv_bucket_size, device, dtype, dim = *q.shape, self.bucket_size, self.kv_bucket_size, q.device, q.dtype, self.dim
        # 计算每个头部的批次大小
        b = bh // self.heads

        # 计算查询向量 q 的桶数
        buckets = q.shape[1] // bucket_size
        # 计算键向量 k 的桶数
        kv_buckets = k.shape[1] // kv_bucket_size

        # 将查询向量 q 分桶,如果 n_sortcut 为 0 则只有一个桶,否则按照桶大小分桶
        b_q = bucket(buckets, q) if self.n_sortcut == 0 else bucket(1, q)
        # 将键向量 k 分桶
        b_k = bucket(kv_buckets, k)

        # 计算查询向量 q 的均值
        sq = b_q.mean(dim=2)
        # 计算键向量 k 的均值
        sk = b_k.mean(dim=2)

        # 计算 R 矩阵,使用 einsum 函数计算点积并乘以缩放因子
        R = torch.einsum('bie,bje->bij', sq, sk).to(q) * (dim ** -0.5)

        # 如果是非排列不变的注意力机制
        if self.non_permutative:
            # 如果 n_sortcut 为 0,则返回前 k 个最大值,否则返回前 n_sortcut 个最大值
            k = topk if self.n_sortcut == 0 else self.n_sortcut
            return differentiable_topk(R, k=k)

        # 如果是排列不变的注意力机制,则使用 Gumbel Sinkhorn 进行计算
        return gumbel_sinkhorn(F.relu(R), self.sinkhorn_iter, self.temperature)
# 定义 SinkhornAttention 类,继承自 nn.Module
class SinkhornAttention(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(self, bucket_size, dim, dim_heads, heads, max_seq_len, temperature = 0.75, non_permutative = True, sinkhorn_iter = 7, n_sortcut = 0, dropout = 0., kv_bucket_size = None, use_simple_sort_net = False, n_top_buckets = 1):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化各个参数
        self.bucket_size = bucket_size
        # 如果 kv_bucket_size 为 None,则使用 bucket_size
        self.kv_bucket_size = default(kv_bucket_size, bucket_size)

        self.dim = dim
        self.heads = heads
        self.temperature = temperature
        self.non_permutative = non_permutative
        self.sinkhorn_iter = sinkhorn_iter
        self.n_sortcut = n_sortcut

        # 根据 use_simple_sort_net 的值选择不同的排序网络
        if use_simple_sort_net:
            self.sort_net = SimpleSortNet(heads, self.kv_bucket_size, max_seq_len // self.kv_bucket_size, dim_heads * 2, non_permutative = non_permutative, temperature = temperature, sinkhorn_iter = sinkhorn_iter)
        else:
            self.sort_net = AttentionSortNet(heads, self.bucket_size, self.kv_bucket_size, dim_heads, non_permutative = non_permutative, temperature = temperature, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut)

        self.n_top_buckets = n_top_buckets
        # 初始化一个 dropout 层
        self.dropout = nn.Dropout(dropout)
    # 定义前向传播函数,接受查询、键、值以及查询和键值的掩码作为输入
    def forward(self, q, k, v, q_mask = None, kv_mask = None):
        # 解包变量,获取批次大小、头数、序列长度、隐藏维度、前n个桶、维度、头数、温度、桶大小、键值桶大小、设备
        b, h, t, d_h, n_top, d, heads, temperature, bucket_size, kv_bucket_size, device = *q.shape, self.n_top_buckets, self.dim, self.heads, self.temperature, self.bucket_size, self.kv_bucket_size, q.device

        # 计算批次头数
        bh = b * h
        # 计算查询的桶数和键值的桶数
        buckets = q.shape[2] // bucket_size
        kv_buckets = k.shape[2] // kv_bucket_size
        # 确保前n个桶不超过键值桶数
        n_top = min(n_top, kv_buckets)

        # 合并批次和头维度
        merge_batch_head = partial(merge_dims, 0, 1)
        q, k, v = map(merge_batch_head, (q, k, v))

        # 桶化查询、键、值
        b_q = bucket(buckets, q)
        b_k, b_v = map(partial(bucket, kv_buckets), (k, v))

        bsz = b_k.shape[2]

        # 使用简单排序网络计算重新排序矩阵R
        R = self.sort_net(q, k, topk=n_top)
        R = R.type_as(q).to(q)

        # 拼接重新排序后的桶
        b_k_r = reorder_buckets(b_k, R)
        b_v_r = reorder_buckets(b_v, R)

        # 选择前n个排名的桶作为所有查询桶的顶部n个桶
        if self.n_sortcut > 0:
            b_k_r = b_k_r[:, 0:self.n_sortcut].reshape(bh, 1, -1, d_h)
            b_v_r = b_v_r[:, 0:self.n_sortcut].reshape(bh, 1, -1, d_h)
            b_k_r = expand_dim(b_k_r, 1, buckets)
            b_v_r = expand_dim(b_v_r, 1, buckets)
        else:
            b_k_r = b_k_r.reshape(bh, buckets, -1, d_h)
            b_v_r = b_k_r.reshape(bh, buckets, -1, d_h)

        # 拼接查询桶和键值桶
        b_k = torch.cat((b_k_r, b_k), dim=2) if buckets == kv_buckets else b_k_r
        b_v = torch.cat((b_v_r, b_v), dim=2) if buckets == kv_buckets else b_v_r

        # 计算点积
        dots = torch.einsum('buie,buje->buij', b_q, b_k) * (d_h ** -0.5)

        # 掩码
        mask_value = max_neg_value(dots)

        # 如果查询和键值掩码不全为空
        if not all_none(q_mask, kv_mask):
            q_mask = default(q_mask, lambda: torch.ones((b, t), device=device).bool())
            kv_mask = default(kv_mask, q_mask)
            mq, mk = bucket(buckets, q_mask), bucket(kv_buckets, kv_mask)
            expand_head_and_merge_into_batch = lambda x: merge_dims(0, 1, expand_dim(x.unsqueeze(1), 1, h))
            mq, mk = map(expand_head_and_merge_into_batch, (mq, mk))

            mk_r = batched_index_select(mk, R.abs().argmax(dim=-1))

            if self.n_sortcut > 0:
                mk_r = mk_r[:, 0:self.n_sortcut].reshape(-1, 1, bsz * self.n_sortcut)
                mk_r = expand_dim(mk_r, 1, buckets)
            else:
                mk_r = mk_r.reshape(bh, buckets, -1)

            mk = torch.cat((mk_r, mk), dim=2) if buckets == kv_buckets else mk_r
            mask = mq[:, :, :, None] * mk[:, :, None, :]
            dots.masked_fill_(~mask, mask_value)
            del mask            

        # 注意力
        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)

        out = torch.einsum('buij,buje->buie', dots, b_v)
        out = unbucket(out)

        out = out.reshape(b, h, t, d_h)
        return out
# 定义函数,生成一个掩码矩阵,用于重新排序
def mask_reordering_matrix(R, topk, temperature):
    # 获取矩阵的列数,即桶的数量
    buckets = R.shape[1]

    # 获取矩阵中的最大值,用于生成掩码
    mask_value = max_neg_value(R)
    # 创建一个与 R 相同形状的全零张量,用于存储掩码
    mask = torch.zeros(R.shape, device=R.device).bool()
    # 获取上三角矩阵的索引
    i, j = torch.triu_indices(buckets, buckets)
    # 将掩码应用到 R 上,将指定位置的值替换为 mask_value
    mask[:, i, j + topk] = True

    # 使用掩码将 R 中的值替换为 mask_value
    R.masked_fill_(mask, mask_value)
    # 返回经过不同iable_topk 函数处理后的结果
    return differentiable_topk(R, topk, temperature)

# 定义一个简单的排序网络模型
class CausalSimpleSortNet(nn.Module):
    def __init__(self, heads, bucket_size, max_buckets, n_top_buckets, dim, temperature):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.bucket_size = bucket_size
        self.max_buckets = max_buckets
        self.n_top_buckets = n_top_buckets
        self.temperature = temperature
        # 初始化线性层参数
        self.linear = nn.Parameter(torch.randn(1, heads, dim, max_buckets + n_top_buckets))
        # 初始化激活函数
        self.act = nn.LeakyReLU()

    # 前向传播函数
    def forward(self, q, k, topk=1):
        # 获取张量的维度信息
        bh, *_, h, max_buckets = *q.shape, self.heads, self.max_buckets
        b = bh // h
        # 计算桶的数量
        buckets = k.shape[1] // self.bucket_size

        # 对 k 进行处理,将其累积平均后进行桶化
        k_r = torch.cat((cumavg(k, dim=1), k), dim=-1)
        k_r = bucket(buckets, k_r)

        # 对于因果排序网络,取每个桶的第一个标记以防止未来信息泄漏到过去
        x = k_r[:, :, 0]

        # 扩展线性层参数并合并头部
        W = expand_batch_and_merge_head(b, self.linear)
        R = self.act(x @ W)
        R = R[:, 0:buckets, 0:(buckets + self.n_top_buckets)]

        # 返回经过 mask_reordering_matrix 函数处理后的结果
        return mask_reordering_matrix(R, topk, self.temperature)

# 定义一个因果注意力排序网络模型
class CausalAttentionSortNet(nn.Module):
    def __init__(self, heads, bucket_size, dim, temperature):
        super().__init__()
        self.heads = heads
        self.bucket_size = bucket_size
        self.dim = dim
        self.temperature = temperature

    # 前向传播函数
    def forward(self, q, k, topk=1):
        bh, *_, h, dim = *q.shape, self.heads, self.dim

        b = bh // h
        buckets = q.shape[1] // self.bucket_size
        kv_buckets = k.shape[1] // self.bucket_size

        q_r = bucket(buckets, cumavg(q, dim=1))
        k_r = bucket(kv_buckets, cumavg(k, dim=1))

        sq = q_r[:, :, 0]
        sk = k_r.sum(dim=2)
        sk = F.pad(sk, (0, 0, topk, 0))

        R = torch.einsum('bie,bje->bij', sq, sk) * (dim ** -0.5)
        # 返回经过 mask_reordering_matrix 函数处理后的结果
        return mask_reordering_matrix(R, topk, self.temperature)

# 在指定索引处对张量进行分割,并对分割后的部分应用函数后再拼接
def apply_fn_after_split_ind(dim, ind, fn, t):
    l, r = split_at_index(dim, ind, t)
    return torch.cat((l, fn(r)), dim=dim)

# 定义 Sinkhorn 因果注意力模型
class SinkhornCausalAttention(nn.Module):
    def __init__(self, bucket_size, dim, dim_heads, heads, max_seq_len, dropout = 0., kv_bucket_size = None, use_simple_sort_net = False, n_top_buckets = 2, temperature = 1.):
        super().__init__()
        assert kv_bucket_size is None or bucket_size == kv_bucket_size, 'different bucketing for key/values for causal reordering not supported yet'

        self.dim = dim
        self.heads = heads
        self.bucket_size = bucket_size

        # 用于第一个桶的学习到的空键/值(过去没有内容需要排序)
        self.null_keys = nn.Parameter(torch.randn(heads, 1, dim_heads))
        self.null_values = nn.Parameter(torch.randn(heads, 1, dim_heads))

        # 根据 use_simple_sort_net 参数选择不同的排序网络模型
        if use_simple_sort_net:
            self.sort_net = CausalSimpleSortNet(heads, bucket_size, max_seq_len // bucket_size, n_top_buckets, dim_heads * 2, temperature)
        else:
            self.sort_net = CausalAttentionSortNet(heads, bucket_size, dim_heads, temperature)

        self.n_top_buckets = n_top_buckets
        self.dropout = nn.Dropout(dropout)
    # 定义一个前向传播函数,接受查询(q)、键(k)、值(v)以及查询掩码(q_mask)和键值掩码(kv_mask)作为输入参数
    def forward(self, q, k, v, q_mask = None, kv_mask = None):
        # 获取输入张量的形状信息
        b, h, t, d_h, n_top, d, bsz, device = *q.shape, self.n_top_buckets, self.dim, self.bucket_size, q.device

        # 计算一些常用的值
        bh = b * h
        hh = h // 2
        buckets = t // bsz
        n_top = min(n_top, buckets)

        # 定义一个切片,用于获取后半部分的头信息
        hh_slice = (slice(None), slice(hh, None))

        # 定义一个部分函数,用于对输入张量进行旋转操作
        rotate_fn = partial(apply_fn_after_split_ind, 1, hh, lambda t: rotate_left(t, bsz-1, dim=2))
        q, k, v = map(rotate_fn, (q, k, v))

        # 合并批次和头信息
        merge_batch_head = partial(merge_dims, 0, 1)
        q, k, v = map(merge_batch_head, (q, k, v))

        # 对查询、键、值进行分桶操作
        b_q, b_k, b_v = map(partial(bucket, buckets), (q, k, v))

        # 计算排序矩阵R
        R = self.sort_net(q, k, topk=n_top)
        R = R.type_as(q).to(q)

        # 添加空键/值
        b_null_k = self.null_keys[None, :, None, :, :].expand(b, h, n_top, bsz, -1).reshape(bh, n_top, bsz, -1).to(k)
        b_null_v = self.null_values[None, :, None, :, :].expand(b, h, n_top, bsz, -1).reshape(bh, n_top, bsz, -1).to(v)

        b_k_r = torch.cat((b_null_k, b_k), dim=1)
        b_v_r = torch.cat((b_null_v, b_v), dim=1)

        # 重新排序桶以便进行本地注意力计算
        b_k_r = reorder_buckets(b_k_r, R)
        b_v_r = reorder_buckets(b_v_r, R)

        b_k_r = b_k_r.reshape(bh, buckets, bsz * n_top, -1)
        b_v_r = b_v_r.reshape(bh, buckets, bsz * n_top, -1)

        # 将原始桶本身连接到重新排序的桶中
        b_k = torch.cat((b_k_r, b_k), dim=2)
        b_v = torch.cat((b_v_r, b_v), dim=2)

        # 计算点积
        dots = torch.einsum('buie,buje->buij', b_q, b_k) * (d_h ** -0.5)

        # 定义掩码值
        mask_value = max_neg_value(q)

        # 如果存在查询掩码和键值掩码,则进行掩码操作
        if not all_none(q_mask, kv_mask):
            q_mask = default(q_mask, lambda: torch.ones((b, t), device=device).bool())
            kv_mask = default(kv_mask, q_mask)

            expand_head = lambda x: x.unsqueeze(1).repeat(1, h, 1)
            q_mask, kv_mask = map(expand_head, (q_mask, kv_mask))

            q_mask[hh_slice] = rotate_left(q_mask[hh_slice], bsz-1, dim=2)
            kv_mask[hh_slice] = rotate_left(kv_mask[hh_slice], bsz-1, dim=2)

            q_mask, kv_mask = map(lambda x: merge_dims(0, 1, x), (q_mask, kv_mask))
            mq, mk = bucket(buckets, q_mask), bucket(buckets, kv_mask)

            mk_with_null = F.pad(mk, (0, 0, 2, 0), value=True)
            mk_r = batched_index_select(mk_with_null, R.abs().argmax(dim=-1))

            mk_r = mk_r.reshape(bh, buckets, -1)
            mk = torch.cat((mk_r, mk), dim=2)
            mask = mq[:, :, :, None] * mk[:, :, None, :]
            dots.masked_fill_(~mask, mask_value)
            del mask

        # 为半头旋转进行掩码操作
        shift = n_top * bsz
        total_shift = shift + bsz

        mask = torch.ones((b, h, buckets, bsz, total_shift), device=device).bool()
        i, j = torch.triu_indices(bsz, bsz, 1)
        mask[:, :, :, i, j + shift] = False
        mask[:, hh:, -1, 0:shift, 0:shift+1] = False
        mask[:, hh:, -1, 0, 0:shift+1] = True
        mask = mask.reshape(b * h, buckets, bsz, total_shift)

        dots.masked_fill_(~mask, mask_value)
        del mask

        # 注意力计算
        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)

        out = torch.einsum('buij,buje->buie', dots, b_v)
        out = unbucket(out)

        out = out.reshape(b, h, t, d_h)
        out = apply_fn_after_split_ind(1, hh, lambda t: rotate_right(t, bsz-1, dim=2), out)
        return out
# 定义 SinkhornSelfAttention 类,继承自 nn.Module
class SinkhornSelfAttention(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(self, dim, bucket_size, max_seq_len, heads = 8, dim_head = None, kv_bucket_size = None, causal = False, non_permutative = True, sinkhorn_iter = 5, n_sortcut = 0, temperature = 0.75, attn_dropout = 0., dropout = 0., context_only = False, use_simple_sort_net = False, n_local_attn_heads = 0, n_top_buckets = 1):
        # 调用父类的初始化函数
        super().__init__()
        # 断言确保 dim_head 不为空或者 dim 可以被 heads 整除
        assert dim_head or divisible_by(dim, heads), f'If dim_head is None, dimension {dim} must be divisible by the number of heads {heads}'
        # 断言确保 sortcut 只能用于非因果注意力
        assert not (causal and n_sortcut > 0), 'sortcut can only be used for non causal attention'
        # 断言确保 context_only 自注意力层不能是因果的
        assert not (causal and context_only), 'context only self attention layer cannot be causal'
        # 断言确保本地注意力头数不超过总头数
        assert n_local_attn_heads <= heads, 'number of local attention heads cannot exceed total heads'

        # 如果 dim_head 为空,则设置为 dim 除以 heads
        dim_head = default(dim_head, dim // heads)
        # 计算 dim_heads
        dim_heads = dim_head * heads
        self.dim_head = dim_head

        self.heads = heads
        self.bucket_size = bucket_size
        self.kv_bucket_size = default(kv_bucket_size, bucket_size)

        self.context_only = context_only
        # 将输入转换为查询向量
        self.to_q = nn.Linear(dim, dim_heads, bias=False)
        # 如果不是仅上下文自注意力,则将输入转换为键值对
        self.to_kv = nn.Linear(dim, dim_heads * 2, bias=False) if not context_only else None

        # 将输出转换为线性层
        self.to_out = nn.Linear(dim_heads, dim)

        self.n_local_attn_heads = n_local_attn_heads
        # 创建本地注意力对象
        self.local_attention = LocalAttention(bucket_size, causal, dropout = attn_dropout, look_forward=(1 if not causal else 0))

        # 计算 Sinkhorn 注意力头数
        sink_heads = heads - n_local_attn_heads

        # 如果是因果的,则创建 SinkhornCausalAttention 对象,否则创建 SinkhornAttention 对象
        if causal:
            attn = SinkhornCausalAttention(bucket_size, dim, dim_head, sink_heads, max_seq_len, dropout = attn_dropout, kv_bucket_size = kv_bucket_size, use_simple_sort_net = use_simple_sort_net, n_top_buckets = n_top_buckets, temperature = temperature)
        else:
            attn = SinkhornAttention(bucket_size, dim, dim_head, sink_heads, max_seq_len, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature, dropout = attn_dropout, kv_bucket_size = kv_bucket_size, use_simple_sort_net = use_simple_sort_net, n_top_buckets = n_top_buckets)

        # 设置 Sinkhorn 注意力对象
        self.sinkhorn_attention = attn

        # 创建丢弃层
        self.dropout = nn.Dropout(dropout)

    # 前向传播函数,接受输入 x、输入掩码 input_mask、上下文 context 和上下文掩码 context_mask
    def forward(self, x, input_mask = None, context = None, context_mask = None):
        # 获取输入 x 的形状信息
        b, t, d, h, dh, l_h = *x.shape, self.heads, self.dim_head, self.n_local_attn_heads
        # 断言确保序列 t 可以被 bucket_size 整除
        assert divisible_by(t, self.bucket_size), f'sequence {t} needs to be divisible by bucket size {self.bucket_size}'
        # 断言确保如果是仅上下文自注意力,则必须提供上下文键/值
        assert not (self.context_only and context is None), 'context key / values must be supplied if context self attention layer'
        # 断言确保如果提供上下文,则上下文的批次和维度与解码器相同
        assert not (context is not None and (context.shape[0], context.shape[2]) !=  (b, d)), 'contextual key / values must have the same batch and dimensions as the decoder'

        # 将输入转换为查询向量
        q = self.to_q(x)

        # 如果不是仅上下文自注意力,则将输入转换为键值对,并根据维度切分
        kv = self.to_kv(x).chunk(2, dim=-1) if not self.context_only else (context, context)
        kv_mask = input_mask if not self.context_only else context_mask

        # 断言确保键/值序列可以被键/值 bucket_size 整除
        assert divisible_by(kv[0].shape[1], self.kv_bucket_size), 'key/value sequences need to be divisible by key/value bucket size'

        # 将查询向量和键值对合并
        qkv = (q, *kv)
        merge_heads_fn = partial(merge_heads, h)
        q, k, v = map(merge_heads_fn, qkv)

        # 部分函数,用于在特定索引处切分张量
        split_index_fn = partial(split_at_index, 1, l_h)
        (lq, q), (lk, k), (lv, v) = map(split_index_fn, (q, k, v))
        # 检查是否存在本地和 Sinkhorn 注意力
        has_local, has_sinkhorn = map(lambda x: x.shape[1] > 0, (lq, q))

        out = []

        # 如果存在本地注意力,则将结果添加到输出列表中
        if has_local > 0:
            out.append(self.local_attention(lq, lk, lv, input_mask = input_mask))

        # 如果存在 Sinkhorn 注意力,则将结果添加到输出列表中
        if has_sinkhorn > 0:
            out.append(self.sinkhorn_attention(q, k, v, q_mask = input_mask, kv_mask = kv_mask))

        # 在指定维度上连接输出列表中的张量
        out = torch.cat(out, dim=1)
        # 将输出张量按头数拆分
        out = split_heads(h, out)
        # 将输出转换为指定维度
        out = self.to_out(out)
        # 应用丢弃层
        out = self.dropout(out)
        return out

# 定义 SinkhornTransformer 类,继承自 nn.Module
class SinkhornTransformer(nn.Module):
    # 初始化函数,设置模型的各种参数
    def __init__(self, dim, depth, max_seq_len = None, causal = False, heads = 8, dim_head = None, bucket_size = 64, kv_bucket_size = None, context_bucket_size = None, non_permutative = False, sinkhorn_iter = 5, n_sortcut = 0, temperature = 0.75, reversible = False, ff_chunks = 1, ff_dropout = 0., attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., weight_tie = False, ff_glu = False, use_simple_sort_net = None, receives_context = False, context_n_sortcut = 2, n_local_attn_heads = 0, use_rezero = False, n_top_buckets = 1,  pkm_layers = tuple(), pkm_num_keys = 128):
        # 调用父类的初始化函数
        super().__init__()
        # 创建空的模型层列表
        layers = nn.ModuleList([])

        # 设置默认的 kv_bucket_size 和 context_bucket_size
        kv_bucket_size = default(kv_bucket_size, bucket_size)
        context_bucket_size = default(context_bucket_size, bucket_size)

        # 定义获取注意力层、前馈层和 PKM 层的 lambda 函数
        get_attn = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, causal = causal, heads = heads, dim_head = dim_head, kv_bucket_size = kv_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature, attn_dropout = attn_dropout, dropout = attn_layer_dropout, use_simple_sort_net = use_simple_sort_net, n_local_attn_heads = n_local_attn_heads, n_top_buckets = n_top_buckets)
        get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, glu = ff_glu), along_dim=1)
        get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

        # 定义获取上下文注意力层和上下文前馈层的 lambda 函数
        get_attn_context = lambda: SinkhornSelfAttention(dim, bucket_size, max_seq_len, context_only = True, heads = heads, dim_head = dim_head, kv_bucket_size = context_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = context_n_sortcut, temperature = temperature, attn_dropout = attn_dropout, dropout = attn_layer_dropout, n_top_buckets = n_top_buckets)
        get_ff_context = lambda: FeedForward(dim, dropout = ff_dropout, glu = ff_glu)

        # 如果权重共享为真,则缓存获取注意力层和前馈层的函数
        if weight_tie:
            get_attn, get_attn_context, get_ff, get_ff_context = map(cache_fn, (get_attn, get_attn_context, get_ff, get_ff_context))

        # 根据是否使用 PKM 层,选择获取并行函数
        for ind in range(depth):
            layer_num = ind + 1
            use_pkm = layer_num in pkm_layers

            get_parallel_fn = get_ff if not use_pkm else get_pkm

            # 将注意力层和并行函数添加到模型层列表中
            layers.append(nn.ModuleList([
                fn_wrapper(get_attn()),
                fn_wrapper(get_parallel_fn())
            ]))

            # 如果不接收上下文,则继续下一个循环
            if not receives_context:
                continue

            # 将上下文注意力层和上下文前馈层添加到模型层列表中
            layers.append(nn.ModuleList([
                fn_wrapper(get_attn_context()),
                fn_wrapper(get_ff_context())
            ]))

        # 根据是否可逆选择执行类型
        execute_type = ReversibleSequence if reversible else SequentialSequence

        # 设置上下文路由和注意力路由
        attn_context_layer = ((True, False),) if receives_context else tuple()
        route_attn = ((True, False), *attn_context_layer) * depth
        route_context = ((False, False), *attn_context_layer) * depth

        context_route_map = {'context': route_context, 'context_mask': route_context} if receives_context else {}
        attn_route_map = {'input_mask': route_attn}

        # 创建模型层序列,设置参数路由和层丢弃率
        self.layers = execute_type(layers, args_route = {**context_route_map, **attn_route_map}, layer_dropout = layer_dropout)
        self.receives_context = receives_context

        # 设置最大序列长度、填充到桶大小、上下文桶大小和是否固定长度
        self.max_seq_len = max_seq_len
        self.pad_to_bucket_size = lcm(bucket_size, kv_bucket_size)
        self.context_bucket_size = context_bucket_size
        self.is_fixed_length = use_simple_sort_net and not causal

        # 如果不使用注意力排序且不是因果的,强制固定序列长度
        assert not (self.is_fixed_length and self.max_seq_len is None), 'maximum sequence length must be specified if length is fixed'
    # 定义一个前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 如果模型要求输入是固定长度的序列,并且输入 x 的长度不等于最大序列长度,则抛出断言错误
        assert not (self.is_fixed_length and x.shape[1] != self.max_seq_len), f'you must supply a sequence of length {self.max_seq_len}'
        # 如果关键字参数中包含 'context' 且模型接收上下文信息,则通过,否则抛出断言错误
        assert ('context' not in kwargs or self.receives_context), 'needs to be initted with receives_context True if passing contextual key / values'
        # 调用模型的 layers 方法进行前向传播,并返回结果
        return self.layers(x, **kwargs)
class SinkhornTransformerLM(nn.Module):
    # 定义 SinkhornTransformerLM 类,继承自 nn.Module
    def __init__(self, num_tokens, dim, max_seq_len, depth, heads = 8, dim_head = None, bucket_size = 64, kv_bucket_size = None, context_bucket_size = None, causal = False, non_permutative = True, sinkhorn_iter = 5, n_sortcut = 0, temperature = 0.75, reversible = False, ff_chunks = 1, ff_glu = False, return_embeddings = False, ff_dropout = 0., attn_dropout = 0., attn_layer_dropout = 0., layer_dropout = 0., emb_dropout = 0., weight_tie = False, emb_dim = None, use_simple_sort_net = None, receives_context = False, context_n_sortcut = 0, n_local_attn_heads = 0, use_rezero = False, n_top_buckets = 2, pkm_layers = tuple(), pkm_num_keys = 128):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数

        emb_dim = default(emb_dim, dim)
        # 如果 emb_dim 为 None,则使用 dim

        self.max_seq_len = max_seq_len
        # 设置最大序列长度

        self.to_token_emb = nn.Embedding(num_tokens, emb_dim)
        # 创建一个嵌入层,将输入的 token 映射为嵌入向量
        self.axial_pos_emb = AxialPositionalEmbedding(emb_dim, axial_shape = (max_seq_len // bucket_size, bucket_size))
        # 创建轴向位置编码层
        self.emb_dropout = nn.Dropout(emb_dropout)
        # 创建一个丢弃层,用于嵌入向量的丢弃

        self.sinkhorn_transformer = SinkhornTransformer(dim, depth, max_seq_len = max_seq_len, causal = causal, heads = heads, dim_head = dim_head, bucket_size = bucket_size, kv_bucket_size = kv_bucket_size, context_bucket_size = context_bucket_size, non_permutative = non_permutative, sinkhorn_iter = sinkhorn_iter, n_sortcut = n_sortcut, temperature = temperature, reversible = reversible, ff_chunks = ff_chunks, ff_dropout = ff_dropout, attn_dropout = attn_dropout, attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, weight_tie = weight_tie, ff_glu = ff_glu, use_simple_sort_net = use_simple_sort_net, receives_context = receives_context, context_n_sortcut = context_n_sortcut, n_local_attn_heads = n_local_attn_heads, use_rezero = use_rezero, n_top_buckets = n_top_buckets,  pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
        # 创建 SinkhornTransformer 模型

        if emb_dim != dim:
            # 如果嵌入维度不等于 dim
            self.sinkhorn_transformer = ProjectInOut(self.sinkhorn_transformer, emb_dim, dim, project_out =(not return_embeddings))
            # 使用 ProjectInOut 对象将嵌入维度转换为 dim

        self.norm = nn.LayerNorm(emb_dim)
        # 创建一个 LayerNorm 层,用于归一化
        self.to_logits = identity if return_embeddings else nn.Linear(emb_dim, num_tokens)
        # 如果 return_embeddings 为真,则使用 identity 函数,否则使用线性层将嵌入向量映射为输出 logits

    def forward(self, x, **kwargs):
        # 前向传播函数,接受输入 x 和关键字参数 kwargs
        _, t, device = *x.shape, x.device
        # 获取输入 x 的形状和设备信息
        assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'
        # 断言序列长度不超过最大序列长度

        x = self.to_token_emb(x)
        # 将输入 x 映射为嵌入向量
        x = self.axial_pos_emb(x) + x
        # 添加轴向位置编码到嵌入向量上
        x = self.emb_dropout(x)
        # 对嵌入向量进行丢弃
        x = self.sinkhorn_transformer(x, **kwargs)
        # 使用 SinkhornTransformer 处理嵌入向量
        x = self.norm(x)
        # 对处理后的向量进行归一化
        return self.to_logits(x)
        # 返回最终的 logits

.\lucidrains\sinkhorn-transformer\sinkhorn_transformer\__init__.py

# 从 sinkhorn_transformer 模块中导入 SinkhornTransformer, SinkhornTransformerLM, SinkhornSelfAttention 类
# 以及 AutoregressiveWrapper, Autopadder 类
from sinkhorn_transformer.sinkhorn_transformer import SinkhornTransformer, SinkhornTransformerLM, SinkhornSelfAttention
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
from sinkhorn_transformer.autopadder import Autopadder

SIREN in Pytorch

PyPI version

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Function

Install

$ pip install siren-pytorch

Usage

A SIREN based multi-layered neural network

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3) <- rgb value

One SIREN layer

import torch
from siren_pytorch import Siren

neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)

Sine activation (just a wrapper around torch.sin)

import torch
from siren_pytorch import Sine

act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)

Wrapper to train on a specific image of specified height and width from a given SirenNet, and then to subsequently generate.

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything

pred_img = wrapper() # (1, 3, 256, 256)

Modulation with Latent Code

A new paper proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is elementwise multiplied with the corresponding layer of the Siren.

You can use this simply by setting an extra keyword latent_dim, on the SirenWrapper

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything

pred_img = wrapper(latent = latent) # (1, 3, 256, 256)

Citations

@misc{sitzmann2020implicit,
    title   = {Implicit Neural Representations with Periodic Activation Functions},
    author  = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year    = {2020},
    eprint  = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mehta2021modulated,
    title   = {Modulated Periodic Activations for Generalizable Local Functional Representations}, 
    author  = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year    = {2021},
    eprint  = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\siren-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'siren-pytorch',
  # 查找所有包
  packages = find_packages(),
  # 版本号
  version = '0.1.7',
  # 许可证
  license='MIT',
  # 描述
  description = 'Implicit Neural Representations with Periodic Activation Functions',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/siren-pytorch',
  # 关键词
  keywords = ['artificial intelligence', 'deep learning'],
  # 安装依赖
  install_requires=[
      'einops',
      'torch'
  ],
  # 分类
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\siren-pytorch\siren_pytorch\siren_pytorch.py

# 导入数学库和PyTorch库
import math
import torch
# 从torch库中导入神经网络模块
from torch import nn
# 从torch.nn.functional中导入函数F
import torch.nn.functional as F
# 从einops库中导入rearrange函数
from einops import rearrange

# 辅助函数

# 判断值是否存在的函数
def exists(val):
    return val is not None

# 将值转换为元组的函数
def cast_tuple(val, repeat = 1):
    return val if isinstance(val, tuple) else ((val,) * repeat)

# 正弦激活函数

class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0
    def forward(self, x):
        return torch.sin(self.w0 * x)

# Siren层

class Siren(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        w0 = 1.,
        c = 6.,
        is_first = False,
        use_bias = True,
        activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out =  F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out

# Siren网络

class SirenNet(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0 = 1.,
        w0_initial = 30.,
        use_bias = True,
        final_activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            layer = Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout = dropout
            )

            self.layers.append(layer)

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)

# 调制前馈

class Modulator(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.ReLU()
            ))

    def forward(self, z):
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            x = torch.cat((x, z))

        return tuple(hiddens)

# 包装器

class SirenWrapper(nn.Module):
    # 初始化函数,接受神经网络、图像宽度、图像高度和潜在维度作为参数
    def __init__(self, net, image_width, image_height, latent_dim = None):
        # 调用父类的初始化函数
        super().__init__()
        # 断言网络类型为 SirenNet
        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

        # 初始化网络、图像宽度和图像高度
        self.net = net
        self.image_width = image_width
        self.image_height = image_height

        # 初始化调制器为 None,如果传入了潜在维度,则创建 Modulator 对象
        self.modulator = None
        if exists(latent_dim):
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )

        # 创建坐标张量
        tensors = [torch.linspace(-1, 1, steps = image_height), torch.linspace(-1, 1, steps = image_width)]
        mgrid = torch.stack(torch.meshgrid(*tensors, indexing = 'ij'), dim=-1)
        mgrid = rearrange(mgrid, 'h w c -> (h w) c')
        # 将坐标张量注册为缓冲区
        self.register_buffer('grid', mgrid)

    # 前向传播函数,接受图像或潜在向量作为参数
    def forward(self, img = None, *, latent = None):
        # 判断是否需要调制
        modulate = exists(self.modulator)
        # 断言只有在初始化时传入了潜在向量才能提供潜在向量
        assert not (modulate ^ exists(latent)), 'latent vector must be only supplied if `latent_dim` was passed in on instantiation'

        # 如果需要调制,则计算调制结果
        mods = self.modulator(latent) if modulate else None

        # 复制坐标张量并设置为需要梯度
        coords = self.grid.clone().detach().requires_grad_()
        # 将坐标张量输入网络得到输出
        out = self.net(coords, mods)
        out = rearrange(out, '(h w) c -> () c h w', h = self.image_height, w = self.image_width)

        # 如果提供了图像,则计算均方误差损失
        if exists(img):
            return F.mse_loss(img, out)

        # 返回输出结果
        return out

.\lucidrains\siren-pytorch\siren_pytorch\__init__.py

# 从 siren_pytorch.siren_pytorch 模块中导入 Sine, Siren, SirenNet, SirenWrapper 类
from siren_pytorch.siren_pytorch import Sine, Siren, SirenNet, SirenWrapper

Slot Attention

Implementation of Slot Attention from the paper 'Object-Centric Learning with Slot Attention' in Pytorch. Here is a video that describes what this network can do.

Update: The official repository has been released here

Install

$ pip install slot_attention

Usage

import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)

After training, the network is reported to be able to generalize to slightly different number of slots (clusters). You can override the number of slots used by the num_slots keyword in forward.

slot_attn(inputs, num_slots = 8) # (2, 8, 512)

Citation

@misc{locatello2020objectcentric,
    title = {Object-Centric Learning with Slot Attention},
    author = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year = {2020},
    eprint = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\slot-attention\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'slot_attention',  # 包名
  packages = find_packages(),  # 查找所有包
  version = '1.1.2',  # 版本号
  license='MIT',  # 许可证
  description = 'Implementation of Slot Attention in Pytorch',  # 描述
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/slot-attention',  # 项目链接
  keywords = ['attention', 'artificial intelligence'],  # 关键词
  install_requires=[
      'torch'  # 安装依赖
  ],
  classifiers=[
      'Development Status :: 4 - Beta',  # 开发状态
      'Intended Audience :: Developers',  # 预期受众
      'Topic :: Scientific/Engineering :: Artificial Intelligence',  # 主题
      'License :: OSI Approved :: MIT License',  # 许可证
      'Programming Language :: Python :: 3.6',  # 编程语言
  ],
)

.\lucidrains\slot-attention\slot_attention\slot_attention.py

import torch
from torch import nn
from torch.nn import init

class SlotAttention(nn.Module):
    # 定义 SlotAttention 类,继承自 nn.Module
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        # 初始化函数,接受 num_slots(槽的数量)、dim(维度)、iters(迭代次数,默认为3)、eps(小数值,默认为1e-8)、hidden_dim(隐藏层维度,默认为128)
        super().__init__()
        # 调用父类的初始化函数

        self.num_slots = num_slots
        # 设置槽的数量
        self.iters = iters
        # 设置迭代次数
        self.eps = eps
        # 设置小数值
        self.scale = dim ** -0.5
        # 计算缩放因子

        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化槽的均值参数
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        # 初始化槽的对数标准差参数
        init.xavier_uniform_(self.slots_logsigma)
        # 使用 Xavier 初始化方法初始化槽的对数标准差参数

        self.to_q = nn.Linear(dim, dim)
        # 创建线性层,用于将输入转换为查询向量
        self.to_k = nn.Linear(dim, dim)
        # 创建线性层,用于将输入转换为键向量
        self.to_v = nn.Linear(dim, dim)
        # 创建线性层,用于将输入转换为值向量

        self.gru = nn.GRUCell(dim, dim)
        # 创建 GRU 单元,用于更新槽的状态

        hidden_dim = max(dim, hidden_dim)
        # 计算隐藏层维度

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )
        # 创建多层感知机模型,用于更新槽的状态

        self.norm_input  = nn.LayerNorm(dim)
        # 创建 LayerNorm 层,用于对输入进行归一化
        self.norm_slots  = nn.LayerNorm(dim)
        # 创建 LayerNorm 层,用于对槽的状态进行归一化
        self.norm_pre_ff = nn.LayerNorm(dim)
        # 创建 LayerNorm 层,用于对前馈网络的输出进行归一化

    def forward(self, inputs, num_slots = None):
        # 前向传播函数,接受输入和槽的数量(可选)
        b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
        # 获取输入的形状、设备和数据类型
        n_s = num_slots if num_slots is not None else self.num_slots
        # 设置槽的数量为给定值或默认值

        mu = self.slots_mu.expand(b, n_s, -1)
        # 复制槽的均值参数以匹配批次大小和槽的数��
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
        # 计算槽的标准差并复制以匹配批次大小和槽的数量

        slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)
        # 生成服从正态分布的槽的状态

        inputs = self.norm_input(inputs)
        # 对输入进行归一化
        k, v = self.to_k(inputs), self.to_v(inputs)
        # 将输入转换为键和值

        for _ in range(self.iters):
            # 迭代更新槽的状态
            slots_prev = slots
            # 保存上一次的槽状态

            slots = self.norm_slots(slots)
            # 对槽的状态进行归一化
            q = self.to_q(slots)
            # 将槽的状态转换为查询向量

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            # 计算查询向量和键向量的点积,并乘以缩放因子
            attn = dots.softmax(dim=1) + self.eps
            # 对点积结果进行 softmax 操作,并加上小数值

            attn = attn / attn.sum(dim=-1, keepdim=True)
            # 归一化注意力权重

            updates = torch.einsum('bjd,bij->bid', v, attn)
            # 根据注意力权重更新值向量

            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )
            # 使用 GRU 单元更新槽的状态

            slots = slots.reshape(b, -1, d)
            # 重新调整槽的状态的形状
            slots = slots + self.mlp(self.norm_pre_ff(slots))
            # 使用多层感知机更新槽的状态

        return slots
        # 返回更新后的槽的状态

.\lucidrains\slot-attention\slot_attention\slot_attention_experimental.py

import torch
from torch import nn
from torch.nn import init

class WeightedAttention(nn.Module):
    def __init__(self, dim, eps = 1e-8, softmax_dim = 1, weighted_mean_dim = 2):
        super().__init__()
        self.norm_input = nn.LayerNorm(dim)  # 对输入进行归一化
        self.norm_context = nn.LayerNorm(dim)  # 对上下文进行归一化

        self.to_q = nn.Linear(dim, dim)  # 线性变换,将输入转换为查询向量
        self.to_k = nn.Linear(dim, dim)  # 线性变换,将上下文转换为键向量
        self.to_v = nn.Linear(dim, dim)  # 线性变换,将上下文转换为值向量

        self.eps = eps  # 用于稳定softmax计算的小值
        self.scale = dim ** -0.5  # 缩放因子
        self.softmax_dim = softmax_dim  # softmax计算的维度
        self.weighted_mean_dim = weighted_mean_dim  # 加权平均的维度

    def forward(self, inputs, context):

        inputs = self.norm_input(inputs)  # 对输入进行归一化
        context = self.norm_context(context)  # 对上下文进行归一化

        q = self.to_q(inputs)  # 计算查询向量
        k = self.to_k(context)  # 计算键向量
        v = self.to_v(context)  # 计算值向量

        dots = torch.einsum('bid,bjd->bij', q, k) * self.scale  # 计算点积
        attn = dots.softmax(dim = self.softmax_dim) + self.eps  # 计算注意力权重
        attn = attn / attn.sum(dim = self.weighted_mean_dim, keepdim=True)  # 计算加权平均

        updates = torch.einsum('bjd,bij->bid', v, attn)  # 计算更新
        return updates

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        return x + self.fn(x)

class GatedResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)  # GRU单元
        self.fn = fn
    def forward(self, *args):
        inputs = args[0]
        b, _, d = inputs.shape

        updates = self.fn(*args)

        inputs = self.gru(
            updates.reshape(-1, d),
            inputs.reshape(-1, d)
        )
        return inputs.reshape(b, -1, d)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        hidden_dim = max(dim, hidden_dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),  # 线性变换
            nn.ReLU(inplace = True),  # ReLU激活函数
            nn.Linear(hidden_dim, dim)  # 线性变换
        )
        self.norm = nn.LayerNorm(dim)  # 对输出进行归一化

    def forward(self, x):
        x = self.norm(x)  # 对输入进行归一化
        return self.net(x)

class SlotAttentionExperimental(nn.Module):
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        scale = dim ** -0.5
        self.num_slots = num_slots
        self.iters = iters

        self.norm_inputs = nn.LayerNorm(dim)  # 对输入进行归一化

        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))  # 槽的均值参数

        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))  # 槽的对数标准差参数
        init.xavier_uniform_(self.slots_logsigma)  # 初始化槽的对数标准差参数

        self.slots_to_inputs_attn = GatedResidual(dim, WeightedAttention(dim, eps = eps))  # 槽到输入的注意力机制
        self.slots_ff = GatedResidual(dim, FeedForward(dim, hidden_dim))  # 槽的前馈网络

        self.inputs_to_slots_attn = GatedResidual(dim, WeightedAttention(dim, eps = eps, softmax_dim = 2, weighted_mean_dim = 1))  # 输入到槽的注意力机制
        self.inputs_ff = GatedResidual(dim, FeedForward(dim, hidden_dim))  # 输入的前馈网络

    def forward(self, inputs, num_slots = None):
        b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
        n_s = num_slots if num_slots is not None else self.num_slots

        mu = self.slots_mu.expand(b, n_s, -1)  # 扩展槽的均值参数
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)  # 扩展槽的对数标准差参数

        slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)  # 生成槽

        inputs = self.norm_inputs(inputs)  # 对输入进行归一化

        for _ in range(self.iters):
            slots = self.slots_to_inputs_attn(slots, inputs)  # 槽到输入的注意力机制
            slots = self.slots_ff(slots)  # 槽的前馈网络

            inputs = self.inputs_to_slots_attn(inputs, slots)  # 输入到槽的注意力机制
            inputs = self.inputs_ff(inputs)  # 输入的前馈网络

        return slots, inputs  # 返回槽和输入

.\lucidrains\slot-attention\slot_attention\__init__.py

# 从slot_attention模块中导入SlotAttention类
from slot_attention.slot_attention import SlotAttention
# 从slot_attention_experimental模块中导入SlotAttentionExperimental类
from slot_attention.slot_attention_experimental import SlotAttentionExperimental

.\lucidrains\soft-moe-pytorch\assert.py

# 导入必要的库
import os
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from soft_moe_pytorch.soft_moe import Experts, FeedForward as Expert
from soft_moe_pytorch.distributed import all_gather_variable_dim

# 设置初始化函数,用于初始化分布式进程组
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# 清理函数,用于销毁进程组
def cleanup():
    dist.destroy_process_group()

# 主函数,启动分布式训练
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    num_experts,
    tokens_per_expert,
    dim,
):
    # 初始化分布式进程组
    setup(rank, world_size)

    # 创建专家网络
    net = Experts([Expert(dim) for _ in range(num_experts)])

    # 根据是否变长批次设置批次大小
    if batch_size_var_len:
        batch_size = batch_size + rank

    # 生成随机输入序列
    seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

    # 分布式训练

    # 使用分布式数据并行包装模型
    model = DDP(net)
    out = model(seq)
    out.mean().backward()

    # 所有进程收集输出
    ddp_all_out, _ = all_gather_variable_dim(out)

    # 单设备上

    # 所有进程收集输入
    all_inputs, _ = all_gather_variable_dim(seq)
    copied_net = deepcopy(net)

    # 在单设备上进行前向传播
    single_out = copied_net(
        all_inputs,
        is_distributed=False
    )

    single_out.mean().backward()

    if rank == 0:
        # 验证输出是否相同
        # 如果在单台机器上和多台机器上进行

        assert torch.allclose(single_out, ddp_all_out), 'output is not the same'

        # 验证梯度和grad是否相同

        get_first_expert_grad = lambda t: t.experts[0][0].weight.grad

        assert torch.allclose(
            get_first_expert_grad(net),
            get_first_expert_grad(copied_net),
            atol=1e-2
        ), 'grad is not the same'

        print('✅')

    # 清理进程组
    cleanup()

if __name__ == '__main__':
    # 设置参数
    world_size = 9
    num_experts = 8
    batch_size = 2
    batch_size_var_len = False

    seq_len = 32
    dim = 8

    # 多进程启动
    mp.spawn(
        start,
        args=(
            world_size,
            batch_size,
            batch_size_var_len,
            num_experts,
            seq_len,
            dim
        ),
        nprocs=world_size,
        join=True
    )

Soft MoE - Pytorch

Implementation of Soft MoE (Mixture of Experts), proposed by Brain's Vision team, in Pytorch.

This MoE has only been made to work with non-autoregressive encoder. However, some recent text-to-image models have started using MoE with great results, so may be a fit there.

If anyone has any ideas for how to make it work for autoregressive, let me know (through email or discussions). I meditated on it but can't think of a good way. The other issue with the slot scheme is that the routing suffers the quadratic as sequence length increases (much like attention)

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • Einops for making my life easy

Install

$ pip install soft-moe-pytorch

Usage

import torch
from soft_moe_pytorch import SoftMoE

moe = SoftMoE(
    dim = 512,         # model dimensions
    seq_len = 1024,    # max sequence length (will automatically calculate number of slots as seq_len // num_experts) - you can also set num_slots directly
    num_experts = 4    # number of experts - (they suggest number of experts should be high enough that each of them get only 1 slot. wonder if that is the weakness of the paper?)
)

x = torch.randn(1, 1024, 512)

out = moe(x) + x # (1, 1024, 512) - add in a transformer in place of a feedforward at a certain layer (here showing the residual too)

For an improvised variant that does dynamic slots so that number of slots ~= sequence length, just import DynamicSlotsSoftMoe instead

import torch
from soft_moe_pytorch import DynamicSlotsSoftMoE

# sequence length or number of slots need not be specified

moe = DynamicSlotsSoftMoE(
    dim = 512,         # model dimensions
    num_experts = 4,   # number of experts
    geglu = True
)

x = torch.randn(1, 1023, 512)

out = moe(x) + x # (1, 1023, 512)

Todo

Citations

@misc{puigcerver2023sparse,
    title 	= {From Sparse to Soft Mixtures of Experts}, 
    author 	= {Joan Puigcerver and Carlos Riquelme and Basil Mustafa and Neil Houlsby},
    year 	= {2023},
    eprint 	= {2308.00951},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}

.\lucidrains\soft-moe-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置软件包的元数据
setup(
  name = 'soft-moe-pytorch', # 软件包的名称
  packages = find_packages(exclude=[]), # 查找所有包
  version = '0.1.7', # 版本号
  license='MIT', # 许可证
  description = 'Soft MoE - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  long_description_content_type = 'text/markdown', # 长描述内容类型
  url = 'https://github.com/lucidrains/soft-moe-pytorch', # 项目链接
  keywords = [
    'artificial intelligence', # 关键词:人工智能
    'deep learning', # 关键词:深度学习
    'mixture of experts' # 关键词:专家混合
  ],
  install_requires=[
    'einops>=0.6.1', # 安装所需的依赖项:einops 版本大于等于 0.6.1
    'torch>=2.0' # 安装所需的依赖项:torch 版本大于等于 2.0
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # 分类器:开发状态为 Beta
    'Intended Audience :: Developers', # 分类器:面向的受众为开发者
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类器:主题为科学/工程 - 人工智能
    'License :: OSI Approved :: MIT License', # 分类器:许可证为 MIT 许可证
    'Programming Language :: Python :: 3.6', # 分类器:编程语言为 Python 3.6
  ],
)

.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\distributed.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function

# 从 torch.distributed 模块中导入 dist 对象
import torch.distributed as dist

# 从 einops 库中导入 rearrange, pack, unpack 函数
from einops import rearrange, pack, unpack

# 定义函数,判断变量是否存在
def exists(val):
    return val is not None

# 定义函数,返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数,判断两个数是否整除
def divisible_by(num, den):
    return (num % den) == 0

# 定义函数,将张量在指定维度上进行填充
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length)

# 定义函数,对所有进程进行相同维度的全局收集
def all_gather_same_dim(t):
    world_size = dist.get_world_size()
    t = t.contiguous()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# 定义函数,收集指定维度的大小信息
def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

# 定义函数,判断张量是否只有一个值
def has_only_one_value(t):
    return (t == t[0]).all()

# 定义函数,对变量维度进行全局收集
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes

# 定义一个继承自 Function 的类 AllGatherFunction
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# 定义一个继承自 nn.Module 的类 AllGather
class AllGather(nn.Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

# 定义函数,根据进程排名拆分张量
def split_by_rank(x):
    rank = dist.get_rank()
    out = x[rank]
    return out

.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\soft_moe.py

# 导入 torch 库
import torch
# 从 torch.nn 中导入 Module 类
from torch.nn import Module
# 从 torch.nn.functional 中导入 F
import torch.nn.functional as F
# 从 torch.distributed 中导入 dist
import torch.distributed as dist
# 从 torch 中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor

# 从 einops 中导入 rearrange, pack, unpack
from einops import rearrange, pack, unpack

# 从 soft_moe_pytorch.distributed 中导入 AllGather, split_by_rank, gather_sizes, has_only_one_value
from soft_moe_pytorch.distributed import (
    AllGather,
    split_by_rank,
    gather_sizes,
    has_only_one_value
)

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 判断一个数是否可以被另一个数整除
def divisible_by(num, den):
    return (num % den) == 0

# 将一个数均匀分成若干份
def chunk_num(num, chunks):
    num_per_chunk, remainder = divmod(num, chunks)

    out = []
    for i in range(chunks):
        n = num_per_chunk
        out.append(n + int(i < remainder))

    return out

# 将一个张量按照指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将一个打包后的张量按照指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = - 1)

# 计算张量的累积和(exclusive)
def cumsum_exclusive(t, dim = -3):
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

# 计算张量的对数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 归一化

# LayerNorm 类
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# RMSNorm 类
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return l2norm(x) * self.scale * self.gamma

# expert

# 创建 FeedForward 网络
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

# GEGLU 类
class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# 创建 GLUFeedForward 网络
def GLUFeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.Linear(dim, dim_hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

# experts

# 专家类
class Experts(nn.Module):
    def __init__(
        self,
        experts,
        is_distributed = None,
        offload_unused_experts_to_cpu = True
    ):
        super().__init__()
        self.num_experts = len(experts)
        self.experts = nn.ModuleList(experts)

        self.is_distributed = is_distributed
        if not exists(self.is_distributed):
            self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

        # 是否将未使用的专家转移到 CPU,需要优化器处理梯度转换到正确设备
        self.offload_unused_experts_to_cpu = offload_unused_experts_to_cpu

        self.all_gather = AllGather()
        self.register_buffer('dummy', torch.ones(1), persistent = False)

    @property
    def device(self):
        return self.dummy.device

    # 将所有专家转移到 CPU,除了指定的专家
    def all_experts_to_cpu_besides(self, selection):
        if not self.offload_unused_experts_to_cpu:
            return

        if isinstance(selection, int):
            experts = [self.experts[selection]]
        if isinstance(selection, slice):
            experts = self.experts[selection]
        else:
            experts = selection

        experts_set = set(experts)

        for expert in self.experts:
            device = self.device if expert in experts_set else 'cpu'
            expert.to(device)

    def forward(
        self,
        x,
        is_distributed = None
        """
        einops notation:
        b - batch
        r - rank (device / machines)
        e - experts
        n - sequence (number of tokens per expert)
        d - feature dimension
        """

        # 检查是否为分布式环境,默认为 self.is_distributed
        is_distributed = default(is_distributed, self.is_distributed)
        # 获取输入张量 x 的形状和专家数量
        shape, num_experts = x.shape, self.num_experts

        # 如果是分布式环境,则在批次维度上进行全局收集,暂时简单处理,后续优化
        if is_distributed:
            # 收集每个专家的序列大小
            seq_sizes = gather_sizes(x, dim=-2)
            assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'

            # 在批次维度上进行全局收集
            x, batch_sizes = self.all_gather(x)
            total_batch_size = x.shape[0]

            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        # 在当前 rank 上使用的专家

        if is_distributed:
            if world_size <= num_experts:
                num_experts_across_ranks = chunk_num(num_experts, world_size)
                start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim=-1)

                num_experts_per_rank = num_experts_across_ranks[rank]
                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)

                expert_start_index = start_indices[rank].item()
            else:
                num_batch_chunks = world_size // num_experts
                total_ranks_in_use = num_batch_chunks * num_experts

                expert_start_index = rank // num_batch_chunks

                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
                num_experts_batches_across_ranks = batch_splits * num_experts

                # 目前,剩余的机器不处理任何内容

                remain_ranks = world_size % num_experts
                num_experts_batches_across_ranks += (0,) * remain_ranks

                num_experts_per_rank = int(rank < total_ranks_in_use)

            assert len(num_experts_batches_across_ranks) == world_size

            expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
        else:
            num_experts_per_rank = num_experts
            expert_slice = slice(0, num_experts)

        # 如果是分布式的,每台机器只处理专家和批次的子集

        # 重新排列输入张量 x 的维度
        x = rearrange(x, 'b e n d -> e b n d')

        if is_distributed:
            # 打包 x,获取打包后的形状
            x, expert_batch_packed_shape = pack_one(x, '* n d')
            x = x.split(num_experts_batches_across_ranks, dim=0)
            x = split_by_rank(x)

            if num_experts_per_rank > 0:
                x = rearrange(x, '(e b) n d -> e b n d', e=num_experts_per_rank)
            else:
                x = x.reshape(num_experts, *x.shape)

        # 获取正在使用的专家

        self.all_experts_to_cpu_besides(expert_slice)

        experts = self.experts[expert_slice]

        # 将标记路由到适当的专家

        outs = []
        for expert, expert_input in zip(experts, x):
            out = expert(expert_input)
            outs.append(out)

        if len(outs) > 0:
            outs = torch.stack(outs)
        else:
            outs = torch.empty_like(x).requires_grad_()

        # 在合并的专家批次维度上进行全局收集,然后将批次维度拆分回来

        if is_distributed:
            outs = rearrange(outs, 'e b n d -> (e b) n d')
            outs, _ = self.all_gather(outs)
            outs = unpack_one(outs, expert_batch_packed_shape, '* n d')

        outs = rearrange(outs, 'e b n d -> b e n d')

        if is_distributed:
            outs = outs.split(batch_sizes.tolist())
            outs = split_by_rank(outs)

        assert outs.shape == shape
        return outs
# 主类 SoftMoE
class SoftMoE(Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        seq_len = None,
        num_experts = 4,
        num_slots = None,
        expert_mult = 4,
        dropout = 0.,
        geglu = False,
        is_distributed = None,
        offload_unused_experts_to_cpu = True,
        use_layernorm = False
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 断言语句,确保 seq_len 或 num_slots 必须传入 SoftMoE
        assert exists(seq_len) ^ exists(num_slots), 'either seq_len, or num_slots must be passed into SoftMoE'

        # 如果 num_slots 为 None,则计算默认值
        num_slots = default(num_slots, seq_len // num_experts)

        # 根据 use_layernorm 的值选择不同的归一化类
        norm_klass = LayerNorm if use_layernorm else RMSNorm
        # 初始化 norm 层
        self.norm = norm_klass(dim)

        # 初始化 slot_norm 层
        self.slot_norm = norm_klass(dim)
        # 初始化 slot_embeds 参数
        self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_slots, dim))

        # 根据 geglu 的值选择不同的 FeedForward 类
        expert_klass = GLUFeedForward if geglu else FeedForward

        # 初始化 experts 层
        self.experts = Experts(
            experts = [expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)],
            is_distributed = is_distributed,
            offload_unused_experts_to_cpu = offload_unused_experts_to_cpu
        )

    # 前向传播函数
    def forward(self, x, mask = None, add_noise = False, noise_mult = 1.):
        """
        einstein notation
        b - batch
        n - sequence length
        e - number of experts
        s - number of slots per expert
        d - feature dimension
        """

        # 判断输入是否为单个 token
        is_single_token = x.ndim == 2
        # 判断输入是否为图像
        is_image = x.ndim == 4

        # 如果输入为图像,则重新排列维度
        if is_image:
            x = rearrange(x, 'b d h w -> b h w d')
            x, ps = pack([x], 'b * d')
        # 如果输入为单个 token,则重新排列维度
        elif is_single_token:
            x = rearrange(x, 'b d -> b 1 d')

        # 对输入进行归一化
        x = self.norm(x)
        # 对 slot_embeds 进行归一化
        slot_embeds = self.slot_norm(self.slot_embeds)

        # 计算 logits
        logits = einsum('b n d, e s d -> b n e s', x, slot_embeds)

        # 添加噪音到 dispatch 和 combine gate logits,如果需要则进行退火
        if add_noise:
            noise = gumbel_noise(logits) * noise_mult
            logits = logits + noise

        # 处理 key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)

        # 获取 dispatch 和 combine 权重(在正确的维度上进行 softmax)
        dispatch_weights = logits.softmax(dim = 1)

        combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
        combine_weights = combine_weights.softmax(dim = -1)

        # 通过使用上面的 dispatch 权重对输入 token 进行加权平均,得到 slots
        slots = einsum('b n d, b n e s -> b e s d', x, dispatch_weights)

        # 将每个专家的 slots 路由到每个专家
        out = self.experts(slots)

        # 合并输出
        out = rearrange(out, ' b e s d -> b (e s) d')
        out = einsum('b s d, b n s -> b n d', out, combine_weights)

        # 如果输入为图像,则恢复原始维度
        if is_image:
            out, = unpack(out, ps, 'b * d')
            out = rearrange(out, 'b h w d -> b d h w')
        # 如果输入为单个 token,则恢复原始维度
        elif is_single_token:
            out = rearrange(out, 'b 1 d -> b d')

        return out

.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\soft_moe_with_dynamic_slots.py

# 导入数学库
import math

# 导入 PyTorch 库
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch import nn, einsum, Tensor

# 导入 einops 库中的函数
from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 对输入张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# 将张量填充到指定的倍数
def pad_to_multiple(
    tensor,
    multiple,
    dim = -1,
    value = 0
):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple

    if m.is_integer():
        return False, tensor

    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

# 归一化模块

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return l2norm(x) * self.scale * self.gamma

# 专家模块

# 创建前馈神经网络
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

# GEGLU 激活函数
class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# 创建 GLU 前馈神经网络
def GLUFeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.Linear(dim, dim_hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

# 主类

class DynamicSlotsSoftMoE(Module):
    def __init__(
        self,
        dim,
        *,
        num_experts = 4,
        expert_mult = 4,
        dropout = 0.,
        geglu = False
    ):
        super().__init__()
        self.norm = RMSNorm(dim)

        self.num_experts = num_experts

        # 将输入映射到槽位嵌入
        self.to_slot_embeds = nn.Sequential(
            nn.Linear(dim, dim * num_experts, bias = False),
            Rearrange('b n (e d) -> b e n d', e = num_experts),
            RMSNorm(dim)
        )

        # 根据是否使用 GEGLU 创建专家模块
        expert_klass = GLUFeedForward if geglu else FeedForward

        # 创建多个专家模块
        self.experts = nn.ModuleList([
            expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)
        ])
    # 定义前向传播函数,接受输入 x 和 mask(可选)
    def forward(self, x, mask = None):
        """
        einstein notation
        b - batch
        n - sequence length
        e - number of experts
        s - number of slots per expert
        d - feature dimension
        """

        # 获取输入 x 的序列长度、是否为图像、专家数量等信息
        seq_len, is_image, num_experts = x.shape[-2], x.ndim == 4, self.num_experts

        # 如果输入为图像,则重新排列维度
        if is_image:
            x = rearrange(x, 'b d h w -> b h w d')
            x, ps = pack([x], 'b * d')

        # 对输入进行归一化处理
        x = self.norm(x)

        # 动态槽嵌入
        # 首先对连续的令牌进行平均,然后将每个位置投影到相应数量的专家槽令牌
        # 槽的数量应该约等于序列长度,就像通常的具有 1 个专家的 MoE 一样

        # 检查是否需要填充,对输入进行填充
        is_padded, x = pad_to_multiple(x, num_experts, dim = -2)

        # 如果需要填充,且没有提供 mask,则创建一个全为 True 的 mask
        if is_padded:
            if not exists(mask):
                mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)

            _, mask = pad_to_multiple(mask, num_experts, dim = -1, value = False)

        # 对输入进行分段处理
        x_segmented = rearrange(x, 'b (n e) d -> b n e d', e = num_experts)

        # 如果存在 mask,则根据 mask 进行填充
        if exists(mask):
            segmented_mask = rearrange(mask, 'b (n e) -> b n e', e = num_experts)
            x_segmented = x_segmented.masked_fill(~rearrange(segmented_mask, '... -> ... 1'), 0.)

        # 执行带有 mask 的均值计算
        if exists(mask):
            num = reduce(x_segmented, 'b n e d -> b n d', 'sum')
            den = reduce(segmented_mask.float(), 'b n e -> b n 1', 'sum').clamp(min = 1e-5)
            x_consecutive_mean = num / den
            slots_mask = segmented_mask.any(dim = -1)
        else:
            x_consecutive_mean = reduce(x_segmented, 'b n e d -> b n d', 'mean')

        # 投影以获取动态槽嵌入
        slot_embeds = self.to_slot_embeds(x_consecutive_mean)

        logits = einsum('b n d, b e s d -> b n e s', x, slot_embeds)

        # 考虑键填充 mask

        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            slots_mask = rearrange(slots_mask, 'b s -> b 1 1 s')

            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
            logits = logits.masked_fill(~slots_mask, -torch.finfo(logits.dtype).max)

        # 获取分发权重和组合权重(在正确的维度上进行 softmax)

        dispatch_weights = logits.softmax(dim = 1)

        combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
        combine_weights = combine_weights.softmax(dim = -1)

        # 通过使用上述分发权重对输入令牌进行加权平均,得到槽
        slots = einsum('b n d, b n e s -> e b s d', x, dispatch_weights)

        # 将每个专家的槽路由到每个专家

        out = []
        for slots_per_expert, expert in zip(slots, self.experts):
            out.append(expert(slots_per_expert))

        out = torch.stack(out)

        # 合并输出

        out = rearrange(out, 'e b s d -> b (e s) d')
        out = einsum('b s d, b n s -> b n d', out, combine_weights)

        # 如果输入为图像,则恢复原始维度
        if is_image:
            out, = unpack(out, ps, 'b * d')
            out = rearrange(out, 'b h w d -> b d h w')

        return out[:, :seq_len]

.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\__init__.py

# 从 soft_moe_pytorch 软件包中导入 SoftMoE 类
# 从 soft_moe_pytorch 软件包中导入 DynamicSlotsSoftMoE 类
from soft_moe_pytorch.soft_moe import SoftMoE
from soft_moe_pytorch.soft_moe_with_dynamic_slots import DynamicSlotsSoftMoE

Soundstorm - Pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch.

They basically applied MaskGiT to the residual vector quantized codes from Soundstream. The transformer architecture they chose to use is one that fits well with the audio domain, named Conformer

Project Page

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • Lucas Newman for numerous contributions, including the initial training code, acoustic prompting logic, per-level quantizer decoding!

  • 🤗 Accelerate for providing a simple and powerful solution for training

  • Einops for the indispensable abstraction that makes building neural networks fun, easy, and uplifting

  • Steven Hillis for submitting the correct masking strategy and for verifying that the repository works! 🙏

  • Lucas Newman for basically training a small working Soundstorm with models across multiple repositories, showing it all works end-to-end. Models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here.

  • @Jiang-Stan for identifying a critical bug in the iterative demasking!

Install

$ pip install soundstorm-pytorch

Usage

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

model = SoundStorm(
    conformer,
    steps = 18,          # 18 steps, as in original maskgit paper
    schedule = 'cosine'  # currently the best schedule is cosine
)

# get your pre-encoded codebook ids from the soundstream from a lot of raw audio

codes = torch.randint(0, 1024, (2, 1024, 12)) # (batch, seq, num residual VQ)

# do the below in a loop for a ton of data

loss, _ = model(codes)
loss.backward()

# model can now generate in 18 steps. ~2 seconds sounds reasonable

generated = model.generate(1024, batch_size = 2) # (2, 1024)

To directly train on raw audio, you need to pass in your pretrained SoundStream into SoundStorm. You can train your own SoundStream at audiolm-pytorch.

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 12,
    attn_window_size = 128,
    attn_depth = 2
)

model = SoundStorm(
    conformer,
    soundstream = soundstream   # pass in the soundstream
)

# find as much audio you'd like the model to learn

audio = torch.randn(2, 10080)

# course it through the model and take a gazillion tiny steps

loss, _ = model(audio)
loss.backward()

# and now you can generate state-of-the-art speech

generated_audio = model.generate(seconds = 30, batch_size = 2)  # generate 30 seconds of audio (it will calculate the length in seconds based off the sampling frequency and cumulative downsamples in the soundstream passed in above)

Complete text-to-speech will rely on a trained TextToSemantic encoder / decoder transformer. You will then load the weights and pass it into the SoundStorm as spear_tts_text_to_semantic

This is a work-in-progress, as spear-tts-pytorch only has the model architecture complete, and not the pretraining + pseudo-labeling + backtranslation logic.

from spear_tts_pytorch import TextToSemantic

text_to_semantic = TextToSemantic(
    dim = 512,
    source_depth = 12,
    target_depth = 12,
    num_text_token_ids = 50000,
    num_semantic_token_ids = 20000,
    use_openai_tokenizer = True
)

# load the trained text-to-semantic transformer

text_to_semantic.load('/path/to/trained/model.pt')

# pass it into the soundstorm

model = SoundStorm(
    conformer,
    soundstream = soundstream,
    spear_tts_text_to_semantic = text_to_semantic
).cuda()

# and now you can generate state-of-the-art speech

generated_speech = model.generate(
    texts = [
        'the rain in spain stays mainly in the plain',
        'the quick brown fox jumps over the lazy dog'
    ]
) # (2, n) - raw waveform decoded from soundstream

Todo

Citations

@misc{borsos2023soundstorm,
    title   = {SoundStorm: Efficient Parallel Audio Generation}, 
    author  = {Zalán Borsos and Matt Sharifi and Damien Vincent and Eugene Kharitonov and Neil Zeghidour and Marco Tagliasacchi},
    year    = {2023},
    eprint  = {2305.09636},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\soundstorm-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'soundstorm-pytorch', # 包的名称
  packages = find_packages(exclude=[]), # 查找所有包
  version = '0.4.2', # 版本号
  license='MIT', # 许可证
  description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  long_description_content_type = 'text/markdown', # 长描述内容类型
  url = 'https://github.com/lucidrains/soundstorm-pytorch', # URL
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'audio generation'
  ],
  install_requires=[ # 安装依赖列表
    'accelerate',
    'audiolm-pytorch>=1.2.8',
    'beartype',
    'classifier-free-guidance-pytorch>=0.1.5',
    'gateloop-transformer>=0.1.1',
    'einops>=0.6.1',
    'spear-tts-pytorch>=0.4.0',
    'torch>=1.6',
  ],
  classifiers = [ # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\attend.py

# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# 定义一个命名元组EfficientAttentionConfig,包含三个布尔类型的参数
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 保证函数只执行一次的装饰器
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用once装饰的print函数,确保只打印一次
print_once = once(print)

# 主要类

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定用于cuda和cpu的高效注意力配置

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # 生成掩码
    def get_mask(self, i, j, device):
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    # Flash Attention函数
    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # 单头键/值

        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # 检查掩码是否存在并扩展到兼容的形状
        # 掩码是B L,因此必须扩展为B H N L

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # 检查是否有兼容的设备用于Flash Attention

        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # 处理注意力偏置

        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2
            causal_mask = self.get_mask(q_len, k_len, device)
            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            mask = attn_bias
            causal = False

        # 使用torch.backends.cuda.sdp_kernel(**config._asdict())来调用Flash Attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = causal
            )

        return out
    # 定义一个前向传播函数,接受查询(q)、键(k)、值(v)、掩码(mask)和注意力偏置(attn_bias)作为参数
    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取查询(q)和键(k)的序列长度以及设备信息
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # 计算缩放因子
        scale = q.shape[-1] ** -0.5

        # 根据键(k)的维度确定 einsum 的等式
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # 如果启用了 flash 模式,则调用 flash_attn 函数进行注意力计算
        if self.flash:
            assert not exists(attn_bias)
            return self.flash_attn(q, k, v, mask = mask)

        # 计算相似度
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # 添加注意力偏置
        if exists(attn_bias):
            sim = sim + attn_bias

        # 如果启用了因果模式,则获取因果掩码
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 如果存在掩码,则根据掩码进行填充
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 计算注意力权重
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 聚合值
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\soundstorm.py

import math
from random import random, randrange  # 导入随机数生成相关函数
from functools import wraps  # 导入wraps装饰器
from contextlib import nullcontext  # 导入nullcontext上下文管理器
from collections import namedtuple  # 导入namedtuple命名元组
from pathlib import Path  # 导入Path路径操作模块

import torch  # 导入PyTorch深度学习库
from torch.cuda.amp import autocast  # 导入自动混合精度计算
from torch import Tensor, nn, einsum  # 导入张量、神经网络、einsum函数
import torch.nn.functional as F  # 导入PyTorch中的函数模块

from einops import rearrange, reduce, repeat, unpack, pack  # 导入einops库中的函数
from einops.layers.torch import Rearrange, EinMix  # 导入einops库中的层函数

from beartype import beartype  # 导入beartype类型检查库
from beartype.door import is_bearable  # 导入is_bearable函数
from beartype.typing import Union, Dict, Optional, List, Optional  # 导入beartype中的类型注解

from soundstorm_pytorch.attend import Attend  # 导入Attend模块

from spear_tts_pytorch import TextToSemantic  # 导入TextToSemantic模块

from audiolm_pytorch import SoundStream  # 导入SoundStream模块
from audiolm_pytorch import HubertWithKmeans, FairseqVQWav2Vec  # 导入HubertWithKmeans和FairseqVQWav2Vec模块

from gateloop_transformer import SimpleGateLoopLayer as GateLoop  # 导入SimpleGateLoopLayer模块

from tqdm import tqdm  # 导入tqdm进度条模块

# helpers

def exists(val):
    return val is not None  # 判断值是否存在

def default(val, d):
    return val if exists(val) else d  # 如果值存在则返回值,否则返回默认值

def divisible_by(numer, denom):
    return (numer % denom) == 0  # 判断是否可以整除

def calc_same_padding(kernel_size):
    pad = kernel_size // 2  # 计算padding值
    return (pad, pad - (kernel_size + 1) % 2)  # 返回padding元组

def eval_decorator(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()  # 设置模型为评估模式
        out = fn(model, *args, **kwargs)  # 调用函数
        model.train(was_training)  # 恢复模型训练模式
        return out
    return inner

# sampling helpers

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])  # 计算top-k值
    val, ind = logits.topk(k, dim = -1)  # 获取top-k值和索引
    probs = torch.full_like(logits, float('-inf'))  # 创建与logits相同形状的全为负无穷的张量
    probs.scatter_(2, ind, val)  # 根据索引填充top-k值
    return probs  # 返回top-k值

def log(t, eps = 1e-10):
    return torch.log(t + eps)  # 计算对数

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)  # 生成均匀分布的噪声
    return -log(-log(noise))  # 计算Gumbel噪声

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)  # 计算Gumbel采样

# prob helpers

def sample_prob(prob):
    return random() < prob  # 根据概率进行采样

def coin_flip():
    return sample_prob(0.5)  # 以0.5的概率进行翻转

# tensor helpers

@beartype
def get_mask_subset_prob(
    mask: Tensor,
    prob: Union[float, Tensor],
    min_mask: int = 0
):
    batch, seq, device = *mask.shape, mask.device  # 获取批次大小、序列长度和设备信息

    if isinstance(prob, Tensor):
        prob = rearrange(prob, 'b -> b 1')  # 重排概率张量的维度

    num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)  # 计算要屏蔽的数量
    logits = torch.rand((batch, seq), device = device)  # 生成随机数张量
    logits = logits.masked_fill(~mask, -1)  # 根据mask进行填充

    randperm = logits.argsort(dim = -1).argsort(dim = -1).float()  # 对logits进行排序

    num_padding = (~mask).sum(dim = -1, keepdim = True)  # 计算填充数量
    randperm -= num_padding  # 减去填充数量

    subset_mask = randperm < num_to_mask  # 生成子集mask
    subset_mask.masked_fill_(~mask, False)  # 根据mask进行填充
    return subset_mask  # 返回子集mask

# schedules

def linear_schedule(t):
    return 1 - t  # 线性调度函数

def cosine_schedule(t):
    """ https://arxiv.org/abs/2202.04200 """
    return torch.cos(t * math.pi / 2)  # 余弦调度函数

# rotary embedding

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # 计算频率
        self.register_buffer("inv_freq", inv_freq, persistent = False)  # 注册缓冲区

    @property
    def device(self):
        return next(self.buffers()).device  # 获取设备信息

    @autocast(enabled = False)
    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)  # 生成序列长度张量
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)  # 计算频率
        freqs = torch.cat((freqs, freqs), dim = -1)  # 拼接频率
        return freqs  # 返回频率

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)  # 将张量分成两部分
    return torch.cat((-x2, x1), dim=-1)  # 拼接张量

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())  # 应用旋转位置嵌入

# t5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale = 1.,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化缩放因子
        self.scale = scale
        # 初始化桶的数量
        self.num_buckets = num_buckets
        # 初始化最大距离
        self.max_distance = max_distance
        # 创建相对注意力偏置的嵌入层
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        # 初始化返回值
        ret = 0
        # 计算相对位置的负值
        n = -relative_position

        # 将桶的数量减半
        num_buckets //= 2
        # 根据n是否小于0来更新ret
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        # 计算最大精确值
        max_exact = num_buckets // 2
        # 判断n是否小于最大精确值
        is_small = n < max_exact

        # 计算大值时的结果
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()

        # 将大值结果限制在桶的范围内
        val_if_large = torch.min(
            val_if_large,
            torch.full_like(val_if_large, num_buckets - 1)
        )

        # 根据is_small选择n或者val_if_large
        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        # 返回参数的设备信息
        return next(self.parameters()).device

    def forward(self, n):
        # 生成长度为n的张量
        pos = torch.arange(n, device = self.device).long()
        # 计算相对位置
        rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')

        # 计算相对位置的桶
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        # 获取相对注意力偏置的值
        values = self.relative_attention_bias(rp_bucket)

        # 重排values的维度
        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
# 定义 Swish 激活函数模块
class Swish(nn.Module):
    # 前向传播函数
    def forward(self, x):
        return x * x.sigmoid()

# 定义 GLU 模块
class GLU(nn.Module):
    # 初始化函数
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    # 前向传播函数
    def forward(self, x):
        # 将输入张量按维度分割成两部分
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

# 定义 DepthWiseConv1d 模块
class DepthWiseConv1d(nn.Module):
    # 初始化函数
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        # 创建深度卷积层
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    # 前向传播函数
    def forward(self, x, mask=None):
        # 如果存在掩码,则将掩码应用到输入张量上
        if exists(mask):
            mask = rearrange(mask, 'b n -> b 1 n')
            x = x.masked_fill(~mask, 0.)

        # 对输入张量进行填充
        x = F.pad(x, self.padding)
        # 进行卷积操作
        out = self.conv(x)

        # 如果存在掩码,则将掩码应用到输出张量上
        if exists(mask):
            out = out.masked_fill(~mask, 0.)

        return out

# 定义 Scale 模块
class Scale(nn.Module):
    # 初始化函数
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    # 前向传播函数
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# 定义 ChanLayerNorm 模块
class ChanLayerNorm(nn.Module):
    # 初始化函数
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1))

    # 前向传播函数
    def forward(self, x):
        eps = 1e-6 if x.dtype == torch.float32 else 1e-4
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * var.clamp(min=eps).rsqrt() * self.gamma

# 定义 PreNorm 模块
class PreNorm(nn.Module):
    # 初始化函数
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    # 前向传播函数
    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 定义 Attention 模块
class Attention(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        flash=True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = Attend(
            flash=flash,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    # 前向传播函数
    def forward(
        self,
        x,
        context=None,
        mask=None,
        rotary_emb=None,
        attn_bias=None
    ):
        n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        out = self.attend(q, k, v, mask=mask, attn_bias=attn_bias)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义 FeedForward 模块
class FeedForward(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义 ConformerConvModule 模块
class ConformerConvModule(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        causal=False,
        expansion_factor=2,
        kernel_size=31,
        dropout=0.
    # 定义一个类,继承自 nn.Module
    ):
        # 调用父类的构造函数
        super().__init__()

        # 计算内部维度
        inner_dim = dim * expansion_factor
        # 计算填充大小,如果是因果卷积则填充为 (kernel_size - 1, 0)
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        # 定义网络结构 net1,包括 LayerNorm、Rearrange、Conv1d 和 GLU 激活函数
        self.net1 = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1)
        )

        # 定义深度卷积层 ds_conv
        self.ds_conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding)

        # 定义网络结构 net2,包括 Swish 激活函数、ChanLayerNorm、Conv1d、Rearrange 和 Dropout
        self.net2 = nn.Sequential(
            Swish(),
            ChanLayerNorm(inner_dim),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    # 定义前向传播函数
    def forward(self, x, mask = None):
        # 使用 net1 进行前向传播
        x = self.net1(x)
        # 使用 ds_conv 进行前向传播
        x = self.ds_conv(x, mask = mask)
        # 使用 net2 进行前向传播
        return self.net2(x)
# Conformer Block

# 定义 ConformerBlock 类
class ConformerBlock(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,  # 维度
        dim_head = 64,  # 头的维度
        heads = 8,  # 头的数量
        ff_mult = 4,  # FeedForward 层的倍数
        conv_expansion_factor = 2,  # 卷积扩展因子
        conv_kernel_size = 31,  # 卷积核大小
        attn_dropout = 0.,  # 注意力机制的 dropout
        attn_flash = True,  # 是否使用闪存注意力
        ff_dropout = 0.,  # FeedForward 层的 dropout
        conv_dropout = 0.,  # 卷积层的 dropout
        conv_causal = False,  # 是否是因果卷积
        use_gateloop_layers = False  # 是否使用门循环层
    ):
        super().__init__()
        # 创建第一个 FeedForward 层
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # 如果使用门循环层,则创建 GateLoop 层
        self.gateloop = GateLoop(dim) if use_gateloop_layers else None

        # 创建注意力机制层
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash)
        # 创建 ConformerConvModule 层
        self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        # 创建第二个 FeedForward 层
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # 对注意力机制层进行预归一化
        self.attn = PreNorm(dim, self.attn)
        # 对第一个 FeedForward 层进行预归一化
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        # 对第二个 FeedForward 层进行预归一化
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        # 创建 LayerNorm 层
        self.post_norm = nn.LayerNorm(dim)

    # 前向传播函数
    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        # 第一个 FeedForward 层
        x = self.ff1(x) + x

        # 如果存在门循环层,则应用门循环层
        if exists(self.gateloop):
            x = self.gateloop(x) + x

        # 注意力机制层
        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
        # 卷积层
        x = self.conv(x, mask = mask) + x
        # 第二个 FeedForward 层
        x = self.ff2(x) + x
        # LayerNorm 层
        x = self.post_norm(x)
        return x

# Conformer

# 定义 Conformer 类
class Conformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        depth,  # 深度
        dim_head = 64,  # 头的维度
        heads = 8,  # 头的数量
        ff_mult = 4,  # FeedForward 层的倍数
        conv_expansion_factor = 2,  # 卷积扩展因子
        conv_kernel_size = 31,  # 卷积核大小
        attn_dropout = 0.,  # 注意力机制的 dropout
        ff_dropout = 0.,  # FeedForward 层的 dropout
        conv_dropout = 0.,  # 卷积层的 dropout
        conv_causal = False,  # 是否是因果卷积
        attn_flash = True,  # 是否使用闪存注意力
        t5_rel_pos_bias = False,  # 是否使用 T5 相对位置偏置
        use_gateloop_layers = True  # 是否使用门循环层
    ):
        super().__init__()

        # 断言,确保闪存注意力和学习偏置不兼容
        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'

        self.dim = dim
        self.layers = nn.ModuleList([])

        # 如果不使用 T5 相对位置偏置,则创建 RotaryEmbedding 层
        self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None
        # 如果使用 T5 相对位置偏置,则创建 T5RelativePositionBias 层
        self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

        # 根据深度循环创建 ConformerBlock 层
        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal,
                attn_flash = attn_flash,
                use_gateloop_layers = use_gateloop_layers
            ))

    # 前向传播函数
    def forward(self, x, mask = None):
        seq_len = x.shape[-2]

        # 如果存在 RotaryEmbedding 层,则创建旋转嵌入
        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
        # 如果存在 T5RelativePositionBias 层,则创建注意力偏置
        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None

        # 遍历 ConformerBlock 层进行前向传播
        for block in self.layers:
            x = block(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                attn_bias = attn_bias
            )

        return x

# conformer with sum reduction across quantized tokens at the beginning, along with heads

# 定义 ConformerWrapper 类
class ConformerWrapper(nn.Module):

    @beartype
    # 初始化函数
    def __init__(
        self,
        *,
        codebook_size,  # 代码本大小
        num_quantizers,  # 量化器数量
        conformer: Union[Conformer, Dict[str, any]],  # Conformer 模型
        grouped_quantizers = 1  # 分组量化器数量
        ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化属性conformer
        self.conformer = conformer

        # 如果conformer是字典类型,则使用Conformer类初始化self.conformer
        if isinstance(conformer, dict):
            self.conformer = Conformer(**self.conformer)

        # 获取conformer的维度
        dim = self.conformer.dim

        # 根据grouped_quantizers的值判断是否需要进行embedding投影
        self.embedding_proj = nn.Sequential(
            nn.Linear(dim * grouped_quantizers, dim),
            nn.LayerNorm(dim)
        ) if grouped_quantizers > 1 else nn.Identity()

        # 计算带有mask的量化器代码数量
        num_codes_with_mask = codebook_size + 1
        num_effective_quantizers = num_quantizers * grouped_quantizers

        # 初始化代码嵌入层
        self.code_embeds = nn.Embedding(num_codes_with_mask * num_effective_quantizers, dim)

        # 注册缓冲区,存储量化器偏移和mask标记
        self.register_buffer('quantizer_offsets', torch.arange(num_effective_quantizers) * num_codes_with_mask, persistent=False)
        self.register_buffer('mask_tokens', self.quantizer_offsets + num_codes_with_mask, persistent=False)

        # 初始化其他属性
        self.dim = dim
        self.codebook_size = codebook_size
        self.num_codes_with_mask = num_codes_with_mask
        self.num_quantizers = num_quantizers
        self.grouped_quantizers = grouped_quantizers

        # 初始化头部
        self.heads = nn.Sequential(
            nn.Linear(dim, dim * num_effective_quantizers),
            Rearrange('b n (h d) -> b (n h) d', h=num_effective_quantizers)
        )

        # 每个量化器代码本都需要自己的logits权重和偏置矩阵
        # 使用EinMix和einops实现
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b (n gq) d -> b n gq d', gq=num_effective_quantizers),
            EinMix(
                'b n gq d -> b n gq l',
                weight_shape='gq d l',
                bias_shape='gq l',
                gq=num_effective_quantizers,
                l=codebook_size,
                d=dim
            ),
            Rearrange('b ... d -> b (...) d')
        )

    def forward(
        self,
        x,
        *,
        mask=None,
        cond=None,
        sum_embeds=None,
        return_embeddings=False,
        return_logits_and_embeddings=False
    ):
        """
        einops notation:
        b - batch
        n - sequence
        g - groups
        q - quantizers
        d - feature dimension
        """

        # 获取x的维度信息
        n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
        assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'

        # 重排x的维度
        x = rearrange(x, 'b (n gq) -> b n gq', gq=g * q)
        x = x + self.quantizer_offsets

        # 对x进行代码嵌入
        x = self.code_embeds(x)

        # 对x进行降维操作
        x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g=g)

        # 对x进行嵌入投影
        x = self.embedding_proj(x)

        # 如果存在sum_embeds,则将其加到x上
        if exists(sum_embeds):
            x = x + sum_embeds

        # 如果存在cond,则将其加到x上
        if exists(cond):
            if cond.ndim == 2:
                cond = rearrange(cond, 'b d -> b 1 d')

            x = x + cond

        # 对x进行Conformer处理
        x = self.conformer(x, mask=mask)
        embeds = self.heads(x)

        # 如果需要返回嵌入向量或者没有to_logits,则返回embeds
        if return_embeddings or not exists(self.to_logits):
            return embeds

        # 获取logits
        logits = self.to_logits(embeds)

        # 如果需要返回logits和嵌入向量,则返回logits和embeds
        if return_logits_and_embeddings:
            return logits, embeds

        return logits
# 定义 LogitHead 类,用于处理主要的 logits 以及自我 token 评论
class LogitHead(nn.Module):
    def __init__(
        self,
        net: ConformerWrapper,
        logit_dim
    ):
        super().__init__()
        self.net = net
        dim = net.dim
        self.to_logits = nn.Linear(dim, logit_dim)

    def forward(self, x):
        # 获取网络的嵌入表示
        embed = self.net(x, return_embeddings = True)
        return self.to_logits(embed)

# 定义 LossBreakdown 命名元组,包含生成器损失和评论家损失
LossBreakdown = namedtuple('LossBreakdown', ['generator_loss', 'critic_loss'])

# 定义 SoundStorm 类,用于处理声音数据
class SoundStorm(nn.Module):

    @beartype
    def __init__(
        self,
        net: ConformerWrapper,
        *,
        soundstream: Optional[SoundStream] = None,
        spear_tts_text_to_semantic: Optional[TextToSemantic] = None,
        wav2vec: Optional[Union[HubertWithKmeans, FairseqVQWav2Vec]] = None,
        steps = 18,
        self_cond = False,
        self_cond_train_prob = 0.75,
        no_replace_prob = 0.15,          # 原始 MLM 论文中指定的一定比例的 tokens 会保持不变
        random_token_prob = 0.1,         # 原始 MLM 论文中指定的一定比例的 tokens 会被替换为随机 token
        schedule = 'linear',
        can_mask_prev_unmasked = False,  # 当解除 mask 时,是否可以重新 mask 之前未 mask 的 tokens
        self_token_critic = False,       # 是否使用自我 token 评论家
        critic_loss_weight = 1.,
        num_semantic_token_ids = None,
        semantic_pad_id = -1,
        pad_id = None,
        wav2vec_target_sample_hz = None,
        wav2vec_downsample_factor = None,
        codec_target_sample_hz = None,
        codec_downsample_factor = None,
    @property
    def device(self):
        return next(self.net.parameters()).device

    def load(self, path, strict = True):
        # 加载模型参数
        # 返回 pkg,以便如果此函数从 Trainer 函数调用中调用,则 Trainer 也可以访问从检查点加载的 package
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'], strict = strict)
        return pkg

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        num_latents = None,
        *,
        mask = None,
        texts: Optional[Union[List[str], Tensor]] = None,
        cond_semantic_token_ids = None,
        prompt_acoustic_token_ids = None,
        seconds = None,
        batch_size = None,
        start_temperature = 1.,
        filter_thres = 0.7,
        noise_level_scale = 1.,
        num_full_sampling_levels = 1,
        text_to_semantic_generate_kwargs: dict = {},
        spec_decode = False,
        spec_decode_gamma = 5,
        **kwargs
    # 定义一个方法,用于获取条件信息
    def maybe_get_condition(self, token_ids = None, length = None):
        # 断言条件:如果传入的 token_ids 存在,则应该开启文本条件化,反之亦然
        assert not (exists(token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

        # 如果 token_ids 不存在,则返回 None
        if not exists(token_ids):
            return None

        # 根据是否存在文本到语义的映射,选择是否开启 torch 的无梯度上下文
        context = torch.no_grad if exists(self.text_to_semantic) else nullcontext

        # 在上下文中执行以下代码块
        with context():
            # 创建一个 mask,用于过滤掉语义填充标记
            mask = token_ids != self.semantic_pad_id

            # 如果存在文本到语义的映射,并且自动设置了 eos 语义标记 id
            if exists(self.text_to_semantic) and self.text_to_semantic.autoset_eos_id['speech']:
                # 进一步过滤掉 eos 语义标记 id
                mask &= token_ids != self.num_semantic_token_ids

            # 将不符合 mask 的 token_ids 替换为 0
            token_ids = token_ids.masked_fill(~mask, 0)

            # 获取语义标记的嵌入
            semantic_tokens = self.semantic_token_emb(token_ids)
            # 将语义标记转换为模型维度的条件 tokens
            cond_tokens = self.semantic_cond_to_model_dim(semantic_tokens)

            # 将填充部分的值设为 0,让网络学习处理
            cond_tokens = cond_tokens.masked_fill(~rearrange(mask, '... -> ... 1'), 0.)

        # 需要插值条件 tokens,以使语义和向量量化 tokens 在时间上对齐
        cond_length = cond_tokens.shape[-2]

        # 计算目标条件长度
        target_cond_length = math.ceil(cond_length * (self.wav2vec_downsample_factor / self.wav2vec_target_sample_hz) / (self.codec_downsample_factor / self.codec_target_sample_hz))

        # 由于 PyTorch 不支持 1D 插值,将数据转换为 2D 进行插值
        if cond_length != target_cond_length:
            cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
            cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
            cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')

        # 根据长度是否存在,决定是截断还是填充条件 tokens
        cond_length = cond_tokens.shape[-2]

        if exists(length):
            if cond_length < length:
                cond_tokens = F.pad(cond_tokens, (0, 0, 0, length - cond_length), value = 0.)
            elif cond_length > length:
                cond_tokens = cond_tokens[:, :length]

        # 返回处理后的条件 tokens
        return cond_tokens

    # 定义前向传播方法
    def forward(
        self,
        x,
        *,
        mask = None,
        cond_semantic_token_ids = None,
        only_train_generator = False,
        only_train_critic = False,
        generator_sample_temperature = None,
        **kwargs
posted @ 2024-06-28 14:07  绝不原创的飞龙  阅读(78)  评论(0)    收藏  举报