Lucidrains-系列项目源码解析-三十六-
Lucidrains 系列项目源码解析(三十六)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Reformer with Deepspeed for Enwik8
DeepSpeed is the framework Microsoft used to train what was, at the time of writing, the world's largest attention model (Turing-NLG, 17B parameters). It is open source, and it works with Reformer Pytorch!
-
First install Deepspeed following instructions from their official repository https://github.com/microsoft/DeepSpeed
-
Run the following command from this folder:
$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json
.\lucidrains\reformer-pytorch\examples\enwik8_deepspeed\train.py
# 导入 deepspeed 库
import deepspeed
# 从 reformer_pytorch 库中导入 ReformerLM 类
from reformer_pytorch import ReformerLM
# 从 reformer_pytorch 库中导入 TrainingWrapper 类
from reformer_pytorch.generative_tools import TrainingWrapper
# 导入 argparse 库
import argparse
# 导入 random 库
import random
# 导入 tqdm 库
import tqdm
# 导入 gzip 库
import gzip
# 导入 numpy 库
import numpy as np
# 导入 torch 库
import torch
# 从 torch 中导入 optim 模块
import torch.optim as optim
# 从 torch.nn 中导入 functional 模块
from torch.nn import functional as F
# 从 torch.utils.data 中导入 DataLoader 和 Dataset 类
from torch.utils.data import DataLoader, Dataset
def add_argument():
    """Build and parse the command line for the enwik8 DeepSpeed example.

    Registers the script's own flags and then lets
    ``deepspeed.add_config_arguments`` attach DeepSpeed's configuration flags
    (e.g. ``--deepspeed`` / ``--deepspeed_config`` used in the launch command).

    NOTE(review): ``--with_cuda``, ``--use_ema``, ``--batch_size`` and
    ``--epochs`` are parsed but not read anywhere else in this script.

    Returns:
        argparse.Namespace holding the parsed arguments.
    """
    cli = argparse.ArgumentParser(description='enwik8')

    # Boolean switches: False unless the flag is present.
    cli.add_argument('--with_cuda', default=False, action='store_true',
                     help='use CPU in case there\'s no GPU support')
    cli.add_argument('--use_ema', default=False, action='store_true',
                     help='whether use exponential moving average')

    # Integer hyper-parameters.
    cli.add_argument('-b', '--batch_size', default=32, type=int,
                     help='mini-batch size (default: 32)')
    cli.add_argument('-e', '--epochs', default=30, type=int,
                     help='number of total epochs (default: 30)')

    # Supplied by the distributed launcher.
    cli.add_argument('--local_rank', type=int, default=-1,
                     help='local rank passed from distributed launcher')

    return deepspeed.add_config_arguments(cli).parse_args()
# Training / sampling schedule for the DeepSpeed enwik8 example.
EPOCHS = 20                     # passes over the training set (the --epochs CLI flag is not read)
GRADIENT_ACCUMULATE_EVERY = 4
VALIDATE_EVERY = 100            # validate every N steps of an epoch
GENERATE_EVERY = 500            # sample text every N steps of an epoch
GENERATE_LENGTH = 1024          # number of tokens generated per sample
SEQ_LEN = 4096                  # model context length


# Helpers


def decode_token(token):
    """Map one byte token id to a printable character.

    Ids below 32 (ASCII control characters) are rendered as a space.
    """
    code_point = 32 if token < 32 else token
    return chr(code_point)


def decode_tokens(tokens):
    """Decode an iterable of byte token ids into a string."""
    return ''.join(decode_token(tok) for tok in tokens)
# Instantiate the model: a byte-level (num_tokens = 256) causal Reformer LM.
model = ReformerLM(
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 4,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    n_local_attn_heads = 4,
    use_full_attn = False # set this to true for comparison with full attention
)

# Wrap for training (return_loss=True forward) and sampling (generate), then move to GPU.
model = TrainingWrapper(model)
model.cuda()

# Prepare the enwik8 data: read up to the first 95M bytes and split at 90M
# into training / validation byte streams.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring's binary mode is deprecated; np.frombuffer is its replacement.
    # frombuffer returns a read-only view of the bytes object, so copy it to get a
    # writable array (torch.from_numpy warns on non-writable arrays).
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# Random-window sampler over a 1-D token tensor.
class TextSamplerDataset(Dataset):
    """Serve random fixed-length windows from a 1-D token tensor.

    ``__getitem__`` ignores ``index`` and returns a freshly sampled window of
    ``seq_len + 1`` consecutive tokens as ``int64`` (presumably the extra token
    lets the training wrapper form next-token targets).

    Args:
        data: 1-D tensor of token ids; must hold at least ``seq_len + 1`` tokens.
        seq_len: model context length.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # randint's upper bound is exclusive, so starts range over
        # 0 .. len(data) - (seq_len + 1) inclusive, i.e. the last window is reachable.
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq

    def __len__(self):
        # Nominal epoch size: number of non-overlapping seq_len chunks.
        return self.data.size(0) // self.seq_len
# Build the training / validation datasets.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# DeepSpeed setup: parse the CLI (including --deepspeed_config), then let
# deepspeed.initialize return the engine, optimizer and training DataLoader.
cmd_args = add_argument()
model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(), training_data=train_dataset)

# Training.
for _ in range(EPOCHS):
    for i, data in enumerate(trainloader):
        model_engine.train()
        data = data.to(model_engine.local_rank)
        loss = model_engine(data, return_loss = True)
        model_engine.backward(loss)
        model_engine.step()
        # DeepSpeed applies any gradient-accumulation scaling inside backward() without
        # modifying `loss`, so this is already the per-micro-batch loss; print it unscaled.
        print(f'training loss: {loss.item()}')

        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                # One random validation window (drop the last token, as in generation below).
                inp = random.choice(val_dataset)[:-1]
                loss = model(inp[None, :].cuda(), return_loss = True)
                print(f'validation loss: {loss.item()}')

        if i % GENERATE_EVERY == 0:
            model.eval()
            inp = random.choice(val_dataset)[:-1]
            prime = decode_tokens(inp)
            # Show the prime text followed by a separator line.
            print('%s \n\n %s' % (prime, '*' * 100))
            # Sampling needs no autograd graph.
            with torch.no_grad():
                sample = model.generate(inp.cuda(), GENERATE_LENGTH)
            output_str = decode_tokens(sample)
            print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\reformer-pytorch\examples\enwik8_simple\train.py
# 导入所需的库
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Hyper-parameters for the single-GPU enwik8 example.
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4   # micro-batches whose gradients are summed per optimizer step
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096


# Helpers


def cycle(loader):
    """Yield the items of ``loader`` forever, restarting it each time it is exhausted."""
    while True:
        yield from loader


def decode_token(token):
    """Return the character for byte token ``token``; ids below 32 become a space."""
    return chr(token if token > 32 else 32)


def decode_tokens(tokens):
    """Decode an iterable of byte token ids into a string via ``decode_token``."""
    chars = [decode_token(tok) for tok in tokens]
    return ''.join(chars)
# Instantiate the model: a byte-level (num_tokens = 256) causal Reformer LM.
model = ReformerLM(
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 4,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    n_local_attn_heads = 4,
    use_full_attn = False # set this to true for comparison with full attention
)
# Wrap for training (return_loss=True forward) and sampling (generate), then move to GPU.
model = TrainingWrapper(model)
model.cuda()

# Prepare the enwik8 data: read up to the first 95M bytes, split at 90M.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring's binary mode is deprecated; frombuffer gives a read-only view,
    # so copy to obtain a writable array for torch.from_numpy.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# Random-window sampler over a 1-D token tensor.
class TextSamplerDataset(Dataset):
    """Serve random windows of ``seq_len + 1`` tokens, moved to ``device``.

    ``__getitem__`` ignores ``index`` and samples a fresh window each call.

    Args:
        data: 1-D tensor of token ids; must hold at least ``seq_len + 1`` tokens.
        seq_len: model context length.
        device: where returned windows are placed (default ``'cuda'``, as before).
    """

    def __init__(self, data, seq_len, device='cuda'):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # Exclusive upper bound len - seq_len makes the final window reachable.
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.to(self.device)

    def __len__(self):
        # Nominal epoch size: number of non-overlapping seq_len chunks.
        return self.data.size(0) // self.seq_len
# Data loaders: `cycle` turns each finite DataLoader into an endless stream of batches.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer (named `optimizer` so it does not shadow the `torch.optim as optim` import).
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Accumulate gradients over several micro-batches before each optimizer step.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # Show the prime text followed by a separator line.
        print('%s \n\n %s' % (prime, '*' * 100))
        # Sampling needs no autograd graph.
        with torch.no_grad():
            sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
Pretraining
The goal of this script is to provide a method for creating pretrained Reformer Models to later be used for transfer learning.
I downloaded the data from https://dumps.wikimedia.org/ and extracted it to JSON objects using https://github.com/attardi/wikiextractor
.\lucidrains\reformer-pytorch\pretraining\self-supervised.py
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from reformer_pytorch import Reformer, ReformerLM
from transformers import BertTokenizer, PreTrainedTokenizer
from fairseq.optim.adafactor import Adafactor
import os
import json
import logging
from datetime import datetime
# Dataset over a directory of WikiExtractor JSON-lines output files.
class WikiDataset(Dataset):
    """One dataset item per regular file in ``path``.

    ``__getitem__`` parses the file as JSON lines (blank lines are skipped) and
    returns a list with the ``'text'`` field of every record, with every run of
    whitespace (including newlines) collapsed to a single space.

    Args:
        path: directory containing the extracted files; sub-directories are ignored.
        prefix: unused; kept for backward compatibility.

    Raises:
        NotADirectoryError: if ``path`` is not an existing directory.
    """

    # \s also matches newlines, so one substitution both turns '\n' into ' '
    # and collapses whitespace runs.
    _WHITESPACE = re.compile(r'\s+')

    def __init__(self, path="", prefix="train"):
        if not os.path.isdir(path):
            raise NotADirectoryError(f'WikiDataset path is not a directory: {path!r}')
        self.documents = []
        # os.listdir order is arbitrary; sort for a reproducible index -> file mapping.
        for file in sorted(os.listdir(path)):
            path_to_file = os.path.join(path, file)
            if not os.path.isfile(path_to_file):
                continue
            self.documents.append(path_to_file)

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

    def __getitem__(self, idx):
        """Return the cleaned 'text' of every JSON record in document ``idx``."""
        items = []
        with open(self.documents[idx], encoding="utf-8") as source:
            for line in source:
                # json.loads rejects empty lines (e.g. a trailing blank line).
                if not line.strip():
                    continue
                text = json.loads(line)['text']
                items.append(self._WHITESPACE.sub(' ', text))
        return items
# Reformer masked-LM pretraining / evaluation helper.
class ReformerTrainer(object):
    """Pretraining and evaluation helper for a Reformer masked-language model.

    NOTE(review): the ``__main__`` block of this script calls ``trainer.train(...)``,
    but this class as shown defines no ``train`` method. Confirm it exists
    upstream before running the script.
    """

    def __init__(self,
                 dataset,
                 model,
                 tokenizer,
                 device=None,
                 train_batch_size=8,
                 eval_batch_size=None,
                 tb_writer=True,
                 tb_dir='./tb_logs',
                 log_dir='./logs'):
        """
        Provides an easy to use class for pretraining and evaluating a Reformer Model.
        :param dataset: (torch.utils.data.Dataset) containing all of the data you wish to utilize during training.
        :param model: (reformer_pytorch.Reformer)
        :param tokenizer: (transformers.PreTrainedTokenizer) if None, BertTokenizer ('bert-base-cased') is loaded.
        :param device: provide manual device placement. If None, will default to cuda:0 if available, else cpu.
        :param train_batch_size: (int) batch size of the training DataLoader.
        :param eval_batch_size: (int) batch size of the eval DataLoader; defaults to train_batch_size if None.
        :param tb_writer: (bool) Whether to write to tensorboard or not.
        :param tb_dir: (str) Where to write TB logs to.
        :param log_dir: (str) Where to write generic logs to; created if it does not exist.
        """
        self.dataset = dataset
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.tb_writer = tb_writer
        self.log_dir = log_dir

        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        if device is None:
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        if eval_batch_size is None:
            self.eval_batch_size = train_batch_size

        if tb_writer:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=tb_dir)

        # basicConfig(filename=...) opens the file immediately and fails if the
        # directory does not exist.
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(filename=f'{log_dir}/{datetime.now().date()}.log', level=logging.INFO)

    def build_dataloaders(self, train_test_split=0.1, train_shuffle=True, eval_shuffle=True):
        """
        Builds the Training and Eval DataLoaders
        :param train_test_split: Fraction of the dataset held out for evaluation (eval_len = int(len * split)).
        :param train_shuffle: (bool) True if you wish to shuffle the train_dataset.
        :param eval_shuffle: (bool) True if you wish to shuffle the eval_dataset.
        :return: train dataloader and evaluation dataloader.
        """
        dataset_len = len(self.dataset)
        eval_len = int(dataset_len * train_test_split)
        train_len = dataset_len - eval_len
        train_dataset, eval_dataset = random_split(self.dataset, (train_len, eval_len))
        train_loader = DataLoader(train_dataset, batch_size=self.train_batch_size, shuffle=train_shuffle)
        eval_loader = DataLoader(eval_dataset, batch_size=self.eval_batch_size, shuffle=eval_shuffle)
        logging.info(f'''train_dataloader size: {len(train_loader.dataset)} | shuffle: {train_shuffle}
                         eval_dataloader size: {len(eval_loader.dataset)} | shuffle: {eval_shuffle}''')
        return train_loader, eval_loader

    def mask_tokens(self, inputs: torch.Tensor, mlm_probability=0.15, pad=True):
        """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        :param inputs: (torch.Tensor) token ids, one row per sequence. Modified in place.
        :param mlm_probability: probability that a non-special, non-pad token is selected.
        :param pad: if True, right-pad inputs with pad_token_id and labels with -100 up to tokenizer.max_len.
        :return: (inputs, labels); labels hold the original id at selected positions and -100 elsewhere.
        """
        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, mlm_probability)
        # Special tokens are never selected.
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        # Padding is never selected.
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # -100 marks positions excluded from the loss (evaluate() filters on it; it is
        # also nn.CrossEntropyLoss's default ignore_index).
        labels[~masked_indices] = -100

        # 80% of the selected tokens are replaced with [MASK].
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Half of the remaining 20% (10% overall) get a random token id.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        # The last 10% of selected tokens keep their original id.

        if pad:
            # NOTE(review): if a row is longer than tokenizer.max_len the pad amount is
            # negative and F.pad crops instead of padding. Confirm that inputs never exceed max_len.
            input_pads = self.tokenizer.max_len - inputs.shape[-1]
            label_pads = self.tokenizer.max_len - labels.shape[-1]
            inputs = F.pad(inputs, pad=(0, input_pads), value=self.tokenizer.pad_token_id)
            # Pad labels with -100 (not pad_token_id) so padded positions are not scored.
            labels = F.pad(labels, pad=(0, label_pads), value=-100)
        return inputs, labels

    def _tokenize_input_ids(self, input_ids: list, pad_to_max_length: bool = True):
        """
        Helper function to clean up the train and eval functions
        :param input_ids: items passed to tokenizer.encode, one per output row (presumably raw texts).
        :param pad_to_max_length: Whether you want to pad the inputs to the tokenizer.max_len
        :return: Tensor containing training data, rows concatenated along dim 0.
        """
        return torch.cat([
            self.tokenizer.encode(
                text,
                add_special_tokens=True,
                max_length=self.tokenizer.max_len,
                pad_to_max_length=pad_to_max_length,
                return_tensors='pt'
            )
            for text in input_ids
        ])

    def evaluate(self, dataloader):
        """
        Runs through the provided dataloader with torch.no_grad() and logs the mean
        masked-LM loss and perplexity (also to TensorBoard when enabled).
        :param dataloader: (torch.utils.data.DataLoader) Evaluation DataLoader; each element
            of a batch is passed to _tokenize_input_ids.
        :return: None
        """
        loss_fn = nn.CrossEntropyLoss()

        if self.n_gpu > 1 and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        self.model.eval()
        eval_loss = 0.0
        perplexity = 0.0
        eval_steps = 0

        logging.info(f'{datetime.now()} | Evaluating...')
        for step, batch in tqdm(enumerate(dataloader), desc='Evaluating', leave=True, total=len(dataloader)):
            for data in batch:
                inputs = self._tokenize_input_ids(data, pad_to_max_length=True)
                inputs, labels = self.mask_tokens(inputs)
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                with torch.no_grad():
                    output = self.model(inputs)

                # Score only the selected (masked) positions.
                loss_mx = labels != -100
                output_ids = output[loss_mx].view(-1, self.tokenizer.vocab_size)
                labels = labels[loss_mx].view(-1)
                tmp_eval_loss = loss_fn(output_ids, labels)
                tmp_perplexity = torch.exp(tmp_eval_loss)

                if self.n_gpu > 1:
                    tmp_eval_loss = tmp_eval_loss.mean()

                eval_loss += tmp_eval_loss.item()
                perplexity += tmp_perplexity.item()
                eval_steps += 1

        if eval_steps == 0:
            # Nothing was evaluated; avoid dividing by zero.
            logging.warning(f'{datetime.now()} | Evaluation skipped: dataloader produced no data.')
            return None

        eval_loss /= eval_steps
        # Mean of the per-step perplexities (not exp of the mean loss).
        perplexity /= eval_steps

        if self.tb_writer:
            self.writer.add_scalar('Eval/Loss', eval_loss, eval_steps)
            self.writer.add_scalar('Perplexity', perplexity, eval_steps)
            # flush() persists events but keeps the writer open; calling close() between
            # writes would make the next add_scalar open a fresh event file.
            self.writer.flush()

        logging.info(f'{datetime.now()} | Step: {step} | Eval Loss: {eval_loss} | Perplexity: {perplexity}')
        return None
# Entry point: pretrain a Reformer MLM on a local WikiExtractor dump.
if __name__ == '__main__':
    dataset = WikiDataset(path='D:/data/enwiki')

    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    # Sequences are encoded/padded to this length, and the model's max_seq_len matches it.
    tokenizer.max_len = 128

    # NOTE(review): causal=True restricts attention to left context, which is unusual
    # for a masked-LM objective. Confirm this is intended.
    model = ReformerLM(
        num_tokens=tokenizer.vocab_size,
        dim=512,
        depth=6,
        heads=8,
        max_seq_len=tokenizer.max_len,
        causal=True
    )

    trainer = ReformerTrainer(dataset, model, tokenizer, train_batch_size=32, eval_batch_size=32)
    # train_test_split is the fraction held out for evaluation (eval_len = len * split),
    # so 0.1 keeps 90% of the documents for training (0.90 would train on only 10%).
    train_dataloader, eval_dataloader = trainer.build_dataloaders(train_test_split=0.1)
    # NOTE(review): ReformerTrainer as shown defines no `train` method. Confirm it exists.
    model = trainer.train(epochs=3,
                          train_dataloader=train_dataloader,
                          eval_dataloader=eval_dataloader,
                          log_steps=10,
                          ckpt_steps=100,
                          ckpt_dir='./ckpts',
                          gradient_accumulation_steps=1)
    # Make sure the checkpoint directory exists before saving the final model.
    os.makedirs('./ckpts', exist_ok=True)
    torch.save(model, './ckpts/model.bin')
Reformer, the Efficient Transformer, in Pytorch
This is a Pytorch implementation of Reformer https://openreview.net/pdf?id=rkgNKkHtvB
It includes LSH attention, reversible network, and chunking. It has been validated with an auto-regressive task (enwik8).
81k tokens with half precision
Install
$ pip install reformer_pytorch
Usage
A simple Reformer language model
# should fit in ~ 5gb - 8k tokens
# Example: a fully configured causal ReformerLM; each keyword is annotated inline.
import torch
from reformer_pytorch import ReformerLM
model = ReformerLM(
num_tokens= 20000,
dim = 1024,
depth = 12,
max_seq_len = 8192,
heads = 8,
lsh_dropout = 0.1,
ff_dropout = 0.1,
post_attn_dropout = 0.1,
layer_dropout = 0.1, # layer dropout from 'Reducing Transformer Depth on Demand' paper
causal = True, # auto-regressive or not
bucket_size = 64, # average size of qk per bucket, 64 was recommended in paper
n_hashes = 4, # 4 is permissible per author, 8 is the best but slower
emb_dim = 128, # embedding factorization for further memory savings
dim_head = 64, # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
ff_chunks = 200, # number of chunks for feedforward layer, make higher if there are memory issues
attn_chunks = 8, # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
num_mem_kv = 128, # persistent learned memory key values, from all-attention paper
full_attn_thres = 1024, # use full attention if context length is less than set value
reverse_thres = 1024, # turn off reversibility for 2x speed for sequence lengths shorter or equal to the designated value
use_scale_norm = False, # use scale norm from 'Transformers without tears' paper
use_rezero = False, # remove normalization and use rezero from 'ReZero is All You Need'
one_value_head = False, # use one set of values for all heads from 'One Write-Head Is All You Need'
weight_tie = False, # tie parameters of each layer for no memory per additional depth
weight_tie_embedding = False, # use token embedding for projection of output, some papers report better results
n_local_attn_heads = 2, # many papers suggest mixing local attention heads aids specialization and improves on certain tasks
pkm_layers = (4,7), # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
pkm_num_keys = 128, # defaults to 128, but can be increased to 256 or 512 as memory allows
use_full_attn = False # only turn on this flag to override and turn on full attention for all sequence lengths. for comparison with LSH to show that it is working
).cuda()
# one sequence of 8192 random token ids in [0, 20000)
x = torch.randint(0, 20000, (1, 8192)).long().cuda()
y = model(x) # (1, 8192, 20000)
The Reformer (just a stack of reversible LSH attention)
# should fit in ~ 5gb - 8k embeddings
# Example: the bare Reformer stack takes float embeddings (last axis = dim), not token ids.
import torch
from reformer_pytorch import Reformer
model = Reformer(
dim = 512,
depth = 12,
heads = 8,
lsh_dropout = 0.1,
causal = True
).cuda()
x = torch.randn(1, 8192, 512).cuda()
y = model(x) # (1, 8192, 512)
Self Attention with LSH
# Example: a single LSH self-attention layer; the input's last axis equals dim (128).
import torch
from reformer_pytorch import LSHSelfAttention
attn = LSHSelfAttention(
dim = 128,
heads = 8,
bucket_size = 64,
n_hashes = 8,
causal = False
)
x = torch.randn(10, 1024, 128)
y = attn(x) # (10, 1024, 128)
LSH (locality sensitive hashing) Attention
# Example: the raw LSH attention op, called with a query/key tensor (qk) and values (v).
import torch
from reformer_pytorch import LSHAttention
attn = LSHAttention(
bucket_size = 64,
n_hashes = 16,
causal = True
)
qk = torch.randn(10, 1024, 128)
v = torch.randn(10, 1024, 128)
# note: this rebinds `attn` from the module to the returned attention weights
out, attn, buckets = attn(qk, v) # (10, 1024, 128)
# attn contains the unsorted attention weights, provided return_attn is set to True (costly otherwise)
# buckets will contain the bucket number (post-argmax) of each token of each batch
Masking
This repository supports masks on the input sequence input_mask (b x i_seq), the context sequence context_mask (b x c_seq), as well as the rarely used full attention matrix itself input_attn_mask (b x i_seq x i_seq), all made compatible with LSH attention. Masks are made of booleans where False denotes masking out prior to the softmax.
The causal triangular mask is all taken care of for you if you set causal = True.
# Example: input / context masks (boolean; False masks a position out before the softmax).
import torch
from reformer_pytorch import ReformerLM
CONTEXT_LEN = 512
SEQ_LEN = 8192
model = ReformerLM(
num_tokens= 20000,
dim = 1024,
depth = 1,
max_seq_len = SEQ_LEN,
ff_chunks = 8,
causal = True
)
# context sequence passed through `keys`, with last axis = dim
c = torch.randn(1, CONTEXT_LEN, 1024)
x = torch.randint(0, 20000, (1, SEQ_LEN)).long()
# all-True masks: nothing is masked out in this example
i_mask = torch.ones(1, SEQ_LEN).bool()
c_mask = torch.ones(1, CONTEXT_LEN).bool()
y = model(x, keys = c, input_mask = i_mask, context_mask = c_mask)
# masking done correctly in LSH attention
Positional Embeddings
The default positional embedding uses rotary embeddings.
However, Aran has informed me that the Reformer team used axial position embeddings with great results on longer sequences.
You can turn on axial positional embedding and adjust the shape and dimension of the axial embeddings by following the instructions below.
# Example: axial positional embeddings instead of the default rotary embeddings.
import torch
from reformer_pytorch import ReformerLM
model = ReformerLM(
num_tokens= 20000,
dim = 1024,
depth = 12,
max_seq_len = 8192,
ff_chunks = 8,
attn_chunks = 2,
causal = True,
axial_position_emb = True, # set this to True
axial_position_shape = (128, 64), # the shape must multiply up to the max_seq_len (128 x 64 = 8192)
)
x = torch.randint(0, 20000, (1, 8192)).long()
y = model(x) # (1, 8192, 20000)
If you would rather use absolute positional embeddings, you can turn it on with absolute_position_emb = True flag on initialization.
Training
Since version 0.17.0, and some corrections to the reversible network, Reformer Pytorch is compatible with Microsoft's DeepSpeed! If you have multiple local GPUs, you can follow the instructions / example in `examples/enwik8_deepspeed`.
Examples
A full Reformer sequence → sequence, say translation
# Example: encoder / decoder ReformerLMs; the encoder's embeddings become the decoder's keys.
# NOTE(review): `fixed_position_emb` is not among the positional-embedding options described
# in the Positional Embeddings section (rotary default, axial_position_emb, absolute_position_emb).
# Confirm it is still accepted by ReformerLM.
import torch
from reformer_pytorch import ReformerLM
DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096
encoder = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
depth = 12,
heads = 8,
max_seq_len = DE_SEQ_LEN,
fixed_position_emb = True,
return_embeddings = True # return output of last attention layer
).cuda()
decoder = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
depth = 12,
heads = 8,
max_seq_len = EN_SEQ_LEN,
fixed_position_emb = True,
causal = True
).cuda()
x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
yi = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()
enc_keys = encoder(x) # (1, 4096, 1024)
yo = decoder(yi, keys = enc_keys) # (1, 4096, 20000)
A full Reformer image → caption
# Example: ResNet feature map -> Reformer encoder -> captioning ReformerLM decoder.
import torch
from torch.nn import Sequential
from torchvision import models
from reformer_pytorch import Reformer, ReformerLM
resnet = models.resnet50(pretrained=True)
# drop the last four children of the ResNet, keeping an intermediate feature map
resnet = Sequential(*list(resnet.children())[:-4])
SEQ_LEN = 4096
encoder = Reformer(
dim = 512,
depth = 6,
heads = 8,
max_seq_len = 4096
)
decoder = ReformerLM(
num_tokens = 20000,
dim = 512,
depth = 6,
heads = 8,
max_seq_len = SEQ_LEN,
causal = True
)
x = torch.randn(1, 3, 512, 512)
yi = torch.randint(0, 20000, (1, SEQ_LEN)).long()
# presumably (1, 512, 64, 64) for this input, so c matches dim = 512 and h * w = 4096. Verify.
visual_emb = resnet(x)
b, c, h, w = visual_emb.shape
visual_emb = visual_emb.view(1, c, h * w).transpose(1, 2) # nchw to nte
enc_keys = encoder(visual_emb)
yo = decoder(yi, keys = enc_keys) # (1, 4096, 20000)
Reformer Encoder Decoder Architecture
There is a bug in versions < 0.21.0. Please upgrade to at least the version specified for the working encoder / decoder Reformer.
By popular demand, I have coded up a wrapper that removes a lot of the manual work in writing up a generic Reformer encoder / decoder architecture. To use it, import the ReformerEncDec class. Encoder keyword arguments are passed with an enc_ prefix and decoder keyword arguments with dec_. The model dimension (dim) must be prefix-free and will be shared between encoder and decoder. The framework will also take care of passing the encoder input mask to the decoder context mask, unless explicitly overridden.
# Example: ReformerEncDec wrapper (encoder kwargs prefixed enc_, decoder kwargs dec_).
import torch
from reformer_pytorch import ReformerEncDec
DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096
enc_dec = ReformerEncDec(
dim = 512,
enc_num_tokens = 20000,
enc_depth = 6,
enc_max_seq_len = DE_SEQ_LEN,
dec_num_tokens = 20000,
dec_depth = 6,
dec_max_seq_len = EN_SEQ_LEN
).cuda()
train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()
input_mask = torch.ones(1, DE_SEQ_LEN).bool().cuda()
loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()
# learn
# evaluate with the following
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
eval_seq_out_start = torch.tensor([[0.]]).long().cuda() # assume 0 is id of start token
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, n) with n presumably <= EN_SEQ_LEN (generation can stop early at eos_token); decode the tokens
Product Key Memory
To see the benefits of using PKM, the learning rate of the values must be set higher than the rest of the parameters. (Recommended to be 1e-2)
You can follow the instructions here to set it correctly https://github.com/lucidrains/product-key-memory#learning-rates
Customizing Feedforward
By default, the activation function is GELU. If you would like an alternative activation function, you can pass in the class to the keyword ff_activation.
# Example: customizing the feedforward; ff_activation takes the activation class, not an instance.
import torch
from reformer_pytorch import ReformerLM
from torch import nn
model = ReformerLM(
num_tokens= 20000,
dim = 512,
depth = 6,
max_seq_len = 8192,
ff_chunks = 8,
ff_dropout = 0.1,
ff_mult = 6,
ff_activation = nn.LeakyReLU,
ff_glu = True # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
)
x = torch.randint(0, 20000, (1, 8192)).long()
y = model(x) # (1, 8192, 20000)
Research
To access the attention weights and bucket distribution, simply wrap the instantiated model with the Recorder wrapper class.
# Example: wrap a model with Recorder to capture attention weights and buckets.
import torch
from reformer_pytorch import Reformer, Recorder
model = Reformer(
dim = 512,
depth = 12,
max_seq_len = 8192,
heads = 8,
lsh_dropout = 0.1,
causal = True
).cuda()
model = Recorder(model)
x = torch.randn(1, 8192, 512).cuda()
y = model(x)
model.recordings[0] # a list of attention weights and buckets for the first forward pass
model.turn_off() # stop recording
model.turn_on() # start recording
model.clear() # clear the recordings
model = model.eject() # recover the original model and remove all listeners
Additional Helpers
Reformer comes with a slight drawback that the sequence must be neatly divisible by the bucket size * 2. I have provided a small helper tool that can help you auto-round the sequence length to the next best multiple.
# Example: Autopadder rounds odd sequence / memory / keys lengths up so LSH bucketing works.
import torch
from reformer_pytorch import ReformerLM, Autopadder
model = ReformerLM(
    num_tokens= 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = True,
    bucket_size = 63, # odd bucket size
    num_mem_kv = 77 # odd memory key length
).cuda()
model = Autopadder(model)
SEQ_LEN = 7777 # odd sequence length
# keys must be on the same device as the model and x, which both live on CUDA here
keys = torch.randn(1, 137, 1024).cuda() # odd keys length
x = torch.randint(0, 20000, (1, SEQ_LEN)).long().cuda()
y = model(x, keys = keys) # (1, 7777, 20000)
Helpers for training auto-regressive models
A lot of users are only interested in an auto-regressive language model (like GPT-2). Here is a training wrapper that makes it easy to both train and evaluate on encoded token sequences of arbitrary length. You will have to take care of the encoding and decoding yourself.
# Example: TrainingWrapper for auto-regressive training and sampling on token-id sequences.
import torch
from torch import randint
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper
model = ReformerLM(
num_tokens= 20000,
dim = 1024,
depth = 12,
max_seq_len = 4096,
lsh_dropout = 0.1,
causal = True,
full_attn_thres = 1024
)
# 0 is used for padding and no loss to be calculated on it
model = TrainingWrapper(model, ignore_index = 0, pad_value = 0)
# NOTE: randint(0, ...) below can also emit id 0, which this wrapper treats as padding;
# real data should reserve id 0 for padding only.
# the wrapper can handle evenly packed sequences
x_train = randint(0, 20000, (3, 357))
# or if you have a list of uneven sequences, it will be padded for you
# (this reassignment replaces the tensor above; both forms are accepted)
x_train = [
randint(0, 20000, (120,)),
randint(0, 20000, (253,)),
randint(0, 20000, (846,))
]
# when training, set return_loss equal to True
model.train()
loss = model(x_train, return_loss = True)
loss.backward()
# when evaluating, just use the generate function, which will default to top_k sampling with temperature of 1.
initial = torch.tensor([[0]]).long() # assume 0 is start token
sample = model.generate(initial, 100, temperature=1., filter_thres = 0.9, eos_token = 1) # assume end token is 1, or omit and it will sample up to 100
print(sample.shape) # (1, <=100) token ids
Issues
Andrea has uncovered that using O2 optimization level when training with mixed precision can lead to instability. Please use O1 instead, which can be set with the amp_level in Pytorch Lightning, or opt_level in Nvidia's Apex library.
Alternatives
- Routing Transformer - https://github.com/lucidrains/routing-transformer
- Sinkhorn Transformer - https://github.com/lucidrains/sinkhorn-transformer
- Performer - https://github.com/lucidrains/performer-pytorch
- Linear Transformer - https://github.com/lucidrains/linear-attention-transformer/
- Compressive Transformer - https://github.com/lucidrains/compressive-transformer-pytorch
Citations
@inproceedings{kitaev2020reformer,
title = {Reformer: The Efficient Transformer},
author = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@article{DBLP:journals/corr/abs-1907-01470,
author = {Sainbayar Sukhbaatar and
Edouard Grave and
Guillaume Lample and
Herv{\'{e}} J{\'{e}}gou and
Armand Joulin},
title = {Augmenting Self-attention with Persistent Memory},
journal = {CoRR},
volume = {abs/1907.01470},
year = {2019},
url = {http://arxiv.org/abs/1907.01470}
}
@article{1910.05895,
author = {Toan Q. Nguyen and Julian Salazar},
title = {Transformers without Tears: Improving the Normalization of Self-Attention},
year = {2019},
eprint = {arXiv:1910.05895},
doi = {10.5281/zenodo.3525484},
}
@inproceedings{fan2020reducing,
title = {Reducing Transformer Depth on Demand with Structured Dropout},
author = {Angela Fan and Edouard Grave and Armand Joulin},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=SylO2yStDr}
}
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
@misc{shazeer2020glu,
title = {GLU Variants Improve Transformer},
author = {Noam Shazeer},
year = {2020},
url = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
title = {Efficient Content-Based Sparse Attention with Routing Transformers},
author = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
year = {2020},
url = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@misc{bachlechner2020rezero,
title = {ReZero is All You Need: Fast Convergence at Large Depth},
author = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
year = {2020},
url = {https://arxiv.org/abs/2003.04887}
}
@misc{lample2019large,
title = {Large Memory Layers with Product Keys},
author = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
year = {2019},
eprint = {1907.05242},
archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
title = {Low-Rank Bottleneck in Multi-head Attention Models},
author = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
year = {2020},
eprint = {2002.07028}
}
@misc{dong2021attention,
title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
year = {2021},
eprint = {2103.03404}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@misc{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year = {2017},
eprint = {1706.03762},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
.\lucidrains\reformer-pytorch\reformer_pytorch\autopadder.py
# 导入数学库和 PyTorch 库
import math
import torch
from torch import nn
import torch.nn.functional as F
# 导入自定义模块
from reformer_pytorch.reformer_pytorch import Reformer, ReformerLM, LSHSelfAttention
# 定义函数,用于将张量填充到指定的倍数
def pad_to_multiple(tensor, seqlen, multiple, dim=-1):
    """Right-pad ``tensor`` along ``dim`` so that ``seqlen`` rounds up to ``multiple``.

    ``seqlen`` is the logical length used for the divisibility check. It may exceed
    ``tensor.shape[dim]`` (the caller can count extra positions that are appended
    later); only the amount needed to make ``seqlen`` divisible is added.

    Args:
        tensor: tensor to pad.
        seqlen: logical length to round up.
        multiple: target multiple (positive int).
        dim: dimension to pad; both negative and non-negative indices are accepted.

    Returns:
        ``tensor`` itself if ``seqlen`` is already a multiple, otherwise a new tensor
        zero-padded at the end of ``dim``.
    """
    # exact integer arithmetic instead of float division + is_integer()
    remainder = -seqlen % multiple
    if remainder == 0:
        return tensor

    # F.pad takes (left, right) pairs starting from the LAST dimension, so emit a zero
    # pair for every dimension after ``dim``. Normalising to a negative index makes
    # non-negative dims address the intended axis.
    if dim >= 0:
        dim -= tensor.dim()
    pad_offset = (0, 0) * (-1 - dim)
    return F.pad(tensor, (*pad_offset, 0, remainder), value=0)
# 定义自动填充器类
class Autopadder(nn.Module):
    """Wrap an LSHSelfAttention / Reformer / ReformerLM and pad its input automatically.

    When the total attended length (input + memory kv + keys) exceeds the wrapped
    module's ``full_attn_thres``, the input is padded so that length becomes a
    multiple of ``2 * bucket_size``. Masks are padded to match (padding is masked
    out), and the output is trimmed back to the original number of positions.
    """

    def __init__(self, net):
        super().__init__()
        assert isinstance(net, (LSHSelfAttention, Reformer, ReformerLM)), 'only modules LSHSelfAttention, Reformer, ReformerLM accepted'
        self.net = net

        is_lm = isinstance(net, ReformerLM)
        # ReformerLM keeps its hyper-parameters on the inner ``reformer`` module
        source = net.reformer if is_lm else net
        # ReformerLM input is padded on its last dim, other modules on dim -2
        self.pad_dim = -1 if is_lm else -2

        self.bucket_size = source.bucket_size
        self.num_mem_kv = source.num_mem_kv
        self.full_attn_thres = source.full_attn_thres

    def forward(self, x, **kwargs):
        batch, length = x.shape[:2]

        keys = kwargs.get('keys')
        input_mask = kwargs.get('input_mask')
        input_attn_mask = kwargs.get('input_attn_mask')

        key_len = keys.shape[1] if keys is not None else 0
        total_len = length + self.num_mem_kv + key_len

        if total_len > self.full_attn_thres:
            if input_mask is None:
                # padding has to be masked out, so start from an all-True mask
                input_mask = torch.full((batch, length), True, device=x.device, dtype=torch.bool)
            x = pad_to_multiple(x, total_len, self.bucket_size * 2, dim=self.pad_dim)

        padded_len = x.shape[1]

        if input_mask is not None:
            grown = padded_len - input_mask.shape[1]
            kwargs.update(input_mask=F.pad(input_mask, (0, grown), value=False))

        if input_attn_mask is not None:
            grown = padded_len - input_attn_mask.shape[1]
            kwargs.update(input_attn_mask=F.pad(input_attn_mask, (0, grown, 0, grown), value=False))

        result = self.net(x, **kwargs)
        # drop the outputs that belong to padding positions
        return result[:, 0:length]
.\lucidrains\reformer-pytorch\reformer_pytorch\generative_tools.py
# 导入必要的库
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from reformer_pytorch.reformer_pytorch import ReformerLM
from reformer_pytorch.autopadder import Autopadder
# 定义函数用于根据概率阈值选择最高概率的元素
def top_p(logits, thres = 0.9):
    """Nucleus-style filter: keep the most probable logits and set the rest to ``-inf``.

    Following the same convention as ``top_k`` in this module, ``thres`` is the fraction
    to *discard*. Tokens are kept (in descending probability) until their cumulative
    probability exceeds ``1 - thres``, and the token that crosses that boundary is
    kept too. The most probable token is always kept.

    Works on the last dimension, so any number of leading (batch) dims is accepted.

    Returns:
        A new tensor shaped like ``logits``, in the original (unsorted) order.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    # shift right by one so the token that crosses the threshold is itself kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    # undo the sort: sorted_indices is a permutation, so every slot gets overwritten
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)
# 定义函数用于根据概率阈值选择前k个元素
def top_k(logits, thres = 0.9):
    """Keep the ``k`` largest logits along the last dim and set the rest to ``-inf``.

    ``k = int((1 - thres) * vocab_size)``, clamped to at least 1. Without the clamp, a
    small vocabulary (or a ``thres`` close to 1) yields ``k == 0``. Every entry then
    becomes ``-inf``, softmax returns NaN and sampling fails.

    Any number of leading (batch) dims is accepted.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k, dim=-1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
# 定义一个包装类,用于训练模型
class TrainingWrapper(nn.Module):
    """Autoregressive training / sampling wrapper around a ``ReformerLM``.

    The model is wrapped in an ``Autopadder``, so inputs of any length are padded for
    LSH bucketing automatically.
    """

    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        assert isinstance(net, ReformerLM), 'generative trainer wrapper can only accept ReformerLM class'
        self.pad_value = pad_value
        self.ignore_index = ignore_index
        self.net = Autopadder(net)
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to ``seq_len`` tokens that continue ``start_tokens``.

        Args:
            start_tokens: ``(t,)`` or ``(b, t)`` prompt tensor.
            seq_len: maximum number of tokens to generate.
            eos_token: stop early once every row samples this token.
            temperature: divisor applied to the filtered logits before softmax.
            filter_logits_fn, filter_thres: logit filter and its threshold.
            **kwargs: forwarded to the model; an ``input_mask`` entry is consumed here.

        Returns:
            Only the generated tokens (prompt removed), with the prompt's rank.
        """
        was_training = self.net.training
        unbatched = start_tokens.dim() == 1
        if unbatched:
            start_tokens = start_tokens[None, :]
        _, prompt_len = start_tokens.shape

        self.net.eval()
        tokens = start_tokens

        input_mask = kwargs.pop('input_mask', None)
        if input_mask is None:
            input_mask = torch.full_like(tokens, True, dtype=torch.bool, device=tokens.device)

        for _ in range(seq_len):
            # condition on at most the last max_seq_len tokens
            window = tokens[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]

            last_logits = self.net(window, input_mask=input_mask, **kwargs)[:, -1, :]
            scaled = filter_logits_fn(last_logits, thres = filter_thres) / temperature
            next_token = torch.multinomial(F.softmax(scaled, dim=-1), 1)

            tokens = torch.cat((tokens, next_token), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (next_token == eos_token).all():
                break

        generated = tokens[:, prompt_len:]
        if unbatched:
            generated = generated.squeeze(0)

        self.net.train(was_training)
        return generated

    def forward(self, x, return_loss = False, **kwargs):
        """Run the model, or return the next-token cross-entropy when ``return_loss``.

        ``x`` may be a ``(b, t)`` tensor or a list of 1-D tensors of varying lengths,
        which are right-padded with ``pad_value``.
        """
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        if isinstance(x, torch.Tensor):
            xi, xo = x[:, :-1], x[:, 1:]
        else:
            xi = pad([seq[:-1] for seq in x])
            xo = pad([seq[1:] for seq in x])

        logits = self.net(xi, **kwargs)
        # cross_entropy wants the class dim second; assumes logits are (b, t, num_tokens)
        return F.cross_entropy(logits.transpose(1, 2), xo, ignore_index = self.ignore_index)
.\lucidrains\reformer-pytorch\reformer_pytorch\recorder.py
# 导入需要的模块
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHAttention, LSHSelfAttention
from collections import defaultdict
# 定义 Recorder 类,继承自 nn.Module
class Recorder(nn.Module):
    """Record attention matrices and bucket ids produced inside a network.

    While recording is on, every ``LSHAttention`` module gets ``_return_attn = True``
    and every ``LSHSelfAttention`` module gets ``callback = self.record``. Each forward
    call's recordings are stored in ``self.recordings[self.iter]`` as dicts
    ``{'attn': ..., 'buckets': ...}`` (detached and moved to CPU).
    """

    def __init__(self, net):
        super().__init__()
        self.iter = 0
        self.recordings = defaultdict(list)
        self.net = net
        self.on = True
        self.ejected = False

    def eject(self):
        """Disable the recorder permanently, drop recordings and return the bare network."""
        self.ejected = True
        self.clear()
        self.unwire()
        return self.net

    def wire(self):
        """Install the recording hooks on all attention modules of ``net``."""
        for module in self.net.modules():
            if isinstance(module, LSHAttention):
                module._return_attn = True
            if isinstance(module, LSHSelfAttention):
                module.callback = self.record

    def unwire(self):
        """Remove the recording hooks from all attention modules of ``net``."""
        for module in self.net.modules():
            if isinstance(module, LSHAttention):
                module._return_attn = False
            if isinstance(module, LSHSelfAttention):
                module.callback = None

    def turn_on(self):
        self.on = True

    def turn_off(self):
        self.on = False

    def clear(self):
        """Forget all recordings and reset the iteration counter."""
        del self.recordings
        self.recordings = defaultdict(list)
        self.iter = 0

    def record(self, attn, buckets):
        """Callback target: store detached CPU copies of ``attn`` and ``buckets``."""
        if not self.on:
            return
        data = {'attn': attn.detach().cpu(), 'buckets': buckets.detach().cpu()}
        self.recordings[self.iter].append(data)

    def forward(self, x, **kwargs):
        assert not self.ejected, 'Recorder has already been ejected and disposed'

        if self.on:
            self.wire()

        try:
            out = self.net(x, **kwargs)
        finally:
            # always remove the hooks, even if the forward pass raises, so the network is
            # not left returning attention or calling back into this recorder
            self.unwire()

        self.iter += 1
        return out
.\lucidrains\reformer-pytorch\reformer_pytorch\reformer_enc_dec.py
# 导入 re 模块,用于正则表达式操作
import re
# 从 torch 模块中导入 nn 类
from torch import nn
# 从 reformer_pytorch 模块中导入 ReformerLM 类
from reformer_pytorch.reformer_pytorch import ReformerLM
# 从 reformer_pytorch 模块中导入 TrainingWrapper 类
from reformer_pytorch.generative_tools import TrainingWrapper
# 定义编码器前缀
ENC_PREFIX = 'enc_'
# 定义解码器前缀
DEC_PREFIX = 'dec_'
# 根据条件将字典按键分组
def group_dict_by_key(cond, d):
    """Split ``d`` into ``(matching, non_matching)`` dicts by the truthiness of ``cond(key)``."""
    matching, rest = {}, {}
    for key in d.keys():
        bucket = matching if cond(key) else rest
        bucket[key] = d[key]
    return matching, rest
# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    """Return True if the string ``str`` starts with the literal text ``prefix``.

    A plain prefix check is used instead of an unescaped regex. Prefixes that contain
    regex metacharacters (``.``, ``*``, ``+``, ...) are therefore matched literally.
    """
    # NOTE: the parameter name shadows the builtin ``str``; kept for interface compatibility.
    return str.startswith(prefix)
# 根据键前缀将字典分组
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(keys starting with prefix, other keys)``; keys are left intact."""
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)
# 根据键前缀并移除前缀将字典分组
def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Split ``d`` by key prefix and strip ``prefix`` from the matching keys.

    Returns:
        ``(matching_with_prefix_removed, remaining)``.
    """
    with_prefix, remaining = group_dict_by_key(lambda key: string_begins_with(prefix, key), d)
    cut = len(prefix)
    stripped = {key[cut:]: value for key, value in with_prefix.items()}
    return stripped, remaining
# 提取编码器和解码器的关键字参数
def extract_enc_dec_kwargs(kwargs):
    """Route ``enc_*`` and ``dec_*`` kwargs to encoder / decoder dicts (prefixes stripped).

    Returns:
        ``(enc_kwargs, dec_kwargs, remaining_kwargs)``.
    """
    enc_kwargs, rest = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, rest = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, rest)
    return enc_kwargs, dec_kwargs, rest
# 提取并设置编码器和解码器的关键字参数
def extract_and_set_enc_dec_kwargs(kwargs):
    """Like ``extract_enc_dec_kwargs``, plus mask propagation to the decoder.

    If the encoder received an ``input_mask``, it is reused as the decoder's
    ``context_mask`` unless one was given explicitly.
    """
    enc_kwargs, dec_kwargs, rest = extract_enc_dec_kwargs(kwargs)
    if 'input_mask' in enc_kwargs and 'context_mask' not in dec_kwargs:
        dec_kwargs['context_mask'] = enc_kwargs['input_mask']
    return enc_kwargs, dec_kwargs, rest
# 定义 ReformerEncDec 类,继承自 nn.Module 类
class ReformerEncDec(nn.Module):
    """Encoder-decoder built from two ``ReformerLM`` models.

    Constructor kwargs prefixed with ``enc_`` / ``dec_`` configure the encoder /
    decoder. The encoder is forced to return embeddings, which the (always causal)
    decoder receives as ``keys``. ``dim`` is shared by both and must be passed once.
    """

    def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        # The encoder's return_embeddings flag is forced to True below. Guard against
        # the real keyword name; the previous check tested 'return_embedding' (no "s"),
        # never fired, and silently overrode the user's value.
        assert 'return_embeddings' not in enc_kwargs, 'you cannot manually set the return embeddings flag for the encoder'
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'dim is shared by encoder and decoder; pass it once as `dim`, not as enc_dim / dec_dim'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['return_embeddings'] = True
        dec_kwargs['causal'] = True

        # decoder bucket size defaults to twice the encoder's
        enc_kwargs.setdefault('bucket_size', 64)
        dec_kwargs.setdefault('bucket_size', enc_kwargs['bucket_size'] * 2)

        enc = ReformerLM(**enc_kwargs)
        dec = ReformerLM(**dec_kwargs)

        self.enc = TrainingWrapper(enc, ignore_index = ignore_index, pad_value = pad_value)
        self.dec = TrainingWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        """Encode ``seq_in`` and sample ``seq_len`` decoder tokens after ``seq_out_start``.

        ``enc_*`` / ``dec_*`` kwargs are routed to the encoder / decoder. Unprefixed
        kwargs are passed to the decoder's ``generate``.
        """
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        enc_keys = self.enc(seq_in, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, keys = enc_keys, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, return_loss = False, **kwargs):
        """Encode ``seq_in`` and run the decoder on ``seq_out`` (loss if ``return_loss``).

        Only ``enc_*`` / ``dec_*`` kwargs are used here; unprefixed kwargs are ignored.
        """
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        enc_keys = self.enc(seq_in, **enc_kwargs)
        return self.dec(seq_out, return_loss = return_loss, keys = enc_keys, **dec_kwargs)
.\lucidrains\reformer-pytorch\reformer_pytorch\reformer_pytorch.py
# 导入数学库
import math
# 导入 PyTorch 库
import torch
import torch.nn as nn
# 从 torch.nn 模块导入 Identity 类
from torch.nn import Identity
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
# 从 torch.autograd 模块导入 Function 类
from torch.autograd import Function
# 从 functools 模块导入 partial、reduce、wraps 函数
from functools import partial, reduce, wraps
# 从 itertools 模块导入 chain 函数
from itertools import chain
# 从 operator 模块导入 mul 函数
from operator import mul
# 导入自定义模块
from local_attention import LocalAttention
from axial_positional_embedding import AxialPositionalEmbedding
from product_key_memory import PKM
from reformer_pytorch.reversible import ReversibleSequence
# 导入 einops 库
from einops import rearrange, repeat
# Constants

# Logit written where a token would attend to itself in shared-QK attention (see
# FullQKAttention). It is a finite value inside the float16 range (max ~65504), so it
# stays representable under half precision.
TOKEN_SELF_ATTN_VALUE = -5e4
# 辅助函数
# 判断变量是否存在
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (falsy values like 0 still exist)."""
    if val is None:
        return False
    return True
# 对两个张量进行排序,并返回排序后的值和对应的张量
def sort_key_val(t1, t2, dim=-1):
    """Sort ``t1`` along ``dim`` and apply the same permutation to ``t2``.

    ``t2`` is broadcast to ``t1``'s shape first. Returns ``(sorted_t1, permuted_t2)``.
    """
    sorted_keys, order = t1.sort(dim=dim)
    return sorted_keys, t2.expand_as(t1).gather(dim, order)
# 在指定维度上对张量进行批量索引选择
def batched_index_select(values, indices):
    """Per-batch row gather: ``out[b, i, :] = values[b, indices[b, i], :]``."""
    width = values.shape[-1]
    index = indices.unsqueeze(2).expand(-1, -1, width)
    return values.gather(1, index)
# 对输入进行分块处理
def process_inputs_chunk(fn, chunks=1, dim=0):
    """Wrap ``fn`` so it runs on ``chunks`` slices of every tensor argument.

    Every positional and keyword argument must be a tensor; each is split with
    ``chunk(chunks, dim=dim)``. ``fn`` is called once per slice and must return a tuple
    of tensors. Matching outputs are concatenated back along ``dim``.
    """
    def inner_fn(*args, **kwargs):
        kw_names = list(kwargs.keys())
        n_positional = len(args)
        pieces = [arg.chunk(chunks, dim=dim) for arg in (*args, *kwargs.values())]

        results = []
        for slices in zip(*pieces):
            pos_part = slices[:n_positional]
            kw_part = dict(zip(kw_names, slices[n_positional:]))
            results.append(fn(*pos_part, **kw_part))

        return tuple(torch.cat(group, dim=dim) for group in zip(*results))
    return inner_fn
# 对张量进行分块求和
def chunked_sum(tensor, chunks=1):
    """Sum ``tensor`` over its last dim, processing the flattened rows in ``chunks`` pieces.

    The result has shape ``tensor.shape[:-1]``.
    """
    *lead_shape, width = tensor.shape
    rows = tensor.reshape(-1, width)
    partial_sums = [piece.sum(dim=-1) for piece in rows.chunk(chunks, dim=0)]
    return torch.cat(partial_sums, dim=0).reshape(lead_shape)
# 返回默认值
def default(val, default_val):
    """Return ``val``, or ``default_val`` when ``val`` is ``None``."""
    if val is not None:
        return val
    return default_val
# 将输入转换为元组
def cast_tuple(x):
    """Wrap ``x`` in a 1-tuple unless it already is a tuple."""
    if isinstance(x, tuple):
        return x
    return (x,)
# 返回张量的最大负值
def max_neg_value(tensor):
    """Most negative finite value of ``tensor``'s floating-point dtype."""
    info = torch.finfo(tensor.dtype)
    return -info.max
# 缓存函��的计算结果
def cache_fn(f):
    """Memoise ``f``: call it once and return that same result on every later call.

    The cache has no key, so arguments passed after the first call are ignored. In
    this file it memoises the layer factories when ``weight_tie`` is set, so every
    layer shares one instance.
    """
    _unset = object()
    cache = _unset

    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        # a private sentinel (instead of None) lets a None result be cached as well
        if cache is _unset:
            cache = f(*args, **kwargs)
        return cache
    return cached_fn
# 缓存方法的计算结果
def cache_method_decorator(cache_attr, cache_namespace, reexecute=False):
    """Let a method store its result in, or fetch it from, a dict attribute of ``self``.

    The decorated method gains three keyword-only controls:
        key_namespace: extra key component (``str()``-ed; ``None`` becomes ``''``).
        fetch: return the value stored under the key instead of computing it. With
            ``reexecute=True`` the method is still called and its result discarded.
        set_cache: (default True) store the returned value under the key.

    The cache is ``getattr(self, cache_attr)``, keyed by ``'<cache_namespace>:<namespace>'``.
    Each store replaces the attribute with a new dict instead of mutating the old one.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            namespace_str = str(default(key_namespace, ''))
            cache = getattr(self, cache_attr)
            cache_key = f'{cache_namespace}:{namespace_str}'

            if not fetch:
                val = fn(self, *args, **kwargs)
            else:
                val = cache[cache_key]
                if reexecute:
                    # run the method anyway; its return value is discarded
                    fn(self, *args, **kwargs)

            if set_cache:
                setattr(self, cache_attr, {**cache, cache_key: val})
            return val
        return wrapper
    return inner_fn
# 在指定维度上扩展张量的维度
def expand_dim(dim, k, t):
    """Insert a new axis at ``dim`` and broadcast it to size ``k`` (returns an expanded view)."""
    t = t.unsqueeze(dim)
    sizes = [-1 for _ in t.shape]
    sizes[dim] = k
    return t.expand(*sizes)
# 合并张量的维度
def merge_dims(ind_from, ind_to, tensor):
    """Flatten dims ``ind_from`` through ``ind_to`` (inclusive) of ``tensor`` into one dim."""
    shape = list(tensor.shape)
    merged_size = reduce(mul, shape[ind_from:ind_to + 1])
    new_shape = shape[:ind_from] + [merged_size] + shape[ind_to + 1:]
    return tensor.reshape(*new_shape)
# 在指定维度上将张量拆分为两部分
def split_at_index(dim, index, t):
    """Split ``t`` at ``index`` along axis ``dim``; ``dim`` counts leading axes to skip.

    Returns ``(t[..., :index, ...], t[..., index:, ...])`` with the slice at axis ``dim``.
    """
    lead = [slice(None)] * dim
    left = t[tuple(lead + [slice(None, index)])]
    right = t[tuple(lead + [slice(index, None)])]
    return left, right
# 辅助类
# 始终返回固定值的模块
class Always(nn.Module):
    """Module that ignores every input and returns the value it was built with."""

    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *_args, **_kwargs):
        return self.val
# 矩阵乘法模块
class MatrixMultiply(nn.Module):
    """Right-multiply the input by a stored tensor: ``x @ tensor``.

    Args:
        tensor: the matrix to multiply by.
        transpose: multiply by ``tensor.t()`` instead.
        normalize: L2-normalise ``tensor`` along its last dim first (before transposing).
    """

    def __init__(self, tensor, transpose=False, normalize=False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose
        self.normalize = normalize

    def forward(self, x):
        weight = F.normalize(self.tensor, dim=-1) if self.normalize else self.tensor
        if self.transpose:
            weight = weight.t()
        return x @ weight
# 定义 ReZero 类,继承自 nn.Module
class ReZero(nn.Module):
    """ReZero branch: ``fn(x) * g`` with a learnable scalar ``g`` initialised to zero."""

    def __init__(self, fn):
        super().__init__()
        self.g = nn.Parameter(torch.zeros(1))
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch * self.g
# 定义 ScaleNorm 类,继承自 nn.Module
class ScaleNorm(nn.Module):
    """Divide by the L2 norm over the last dim (clamped to ``eps``), then scale by ``g``.

    ``dim`` is accepted so the class can be built like ``nn.LayerNorm(dim)``; it is unused.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x):
        norm = x.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return x / norm * self.g
# 定义 PreNorm 类,继承自 nn.Module
class PreNorm(nn.Module):
    """Normalise the input with ``norm_class(dim)``, then apply ``fn`` to it."""

    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
# 定义 Chunk 类,继承自 nn.Module
class Chunk(nn.Module):
    """Apply ``fn`` separately to ``chunks`` slices of the input along ``along_dim``.

    The per-slice outputs are concatenated along the same dim. With ``chunks == 1``,
    ``fn`` is applied to the whole input directly.
    """

    def __init__(self, chunks, fn, along_dim=-1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks != 1:
            outputs = [self.fn(piece, **kwargs) for piece in x.chunk(self.chunks, dim=self.dim)]
            return torch.cat(outputs, dim=self.dim)
        return self.fn(x, **kwargs)
# LSH attention 类,实现了论文中描述的 LSH 注意力机制
class LSHAttention(nn.Module):
    """Locality-sensitive-hashing attention configuration plus the bucket hashing step.

    NOTE(review): only ``__init__`` and ``hash_vectors`` are visible in this listing.
    ``LSHSelfAttention`` calls instances of this module, so a ``forward`` must exist in
    the full source; confirm it was not lost.

    Args:
        dropout: rate for ``self.dropout``; must be < 1 (ValueError otherwise).
        bucket_size: stored on the module.
        n_hashes: number of hashing rounds.
        causal: stored flag.
        allow_duplicate_attention, attend_across_buckets, rehash_each_round: stored
            flags; ``allow_duplicate_attention=False`` together with
            ``rehash_each_round=False`` is rejected by an assertion.
        drop_for_hash_rate: dropout applied to the vectors before hashing.
        random_rotations_per_head: draw a separate random rotation for each entry of
            the hashed tensor's leading dim instead of one shared rotation.
        return_attn: initial value of ``_return_attn`` (also toggled by ``Recorder``).
    """

    def __init__( self,
                  dropout=0.,
                  bucket_size=64,
                  n_hashes=8,
                  causal=False,
                  allow_duplicate_attention=True,
                  attend_across_buckets=True,
                  rehash_each_round=True,
                  drop_for_hash_rate=0.0,
                  random_rotations_per_head=False,
                  return_attn=False):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        # separate dropout applied to the vectors before they are hashed
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.causal = causal
        self.bucket_size = bucket_size

        self.n_hashes = n_hashes

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round
        self._random_rotations_per_head = random_rotations_per_head

        # whether the attention matrix should also be returned
        self._return_attn = return_attn

        # Storage used by cache_method_decorator on hash_vectors. Original author's note:
        # caching buckets for the reversible network was reported to make Reformer
        # work at depth.
        self._cache = {}

    @cache_method_decorator('_cache', 'buckets', reexecute=True)
    def hash_vectors(self, n_buckets, vecs):
        """Assign every vector to a bucket, once per hashing round.

        Random-rotation LSH (https://arxiv.org/pdf/1509.02897.pdf). Vectors are
        projected with a random matrix, concatenated with the negated projection, and
        the argmax (or, without per-round rehashing, the top ``n_hashes`` entries)
        picks the bucket.

        Args:
            n_buckets: buckets per round; must be even.
            vecs: rank-3 tensor ``(batch, seq, feat)`` (fixed by the 'btf' einsum).

        Returns:
            ``(batch, n_hashes * seq)`` bucket ids. Round ``r`` ids are offset by
            ``r * n_buckets`` so ids from different rounds never collide.

        Because of ``cache_method_decorator``, callers may pass ``key_namespace``,
        ``fetch`` and ``set_cache``. With ``reexecute=True``, a fetch still re-runs this
        body (drawing new random rotations) before returning the cached ids.
        """
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # A different random rotation is sampled for each round of hashing to decrease
        # the probability of hash misses.
        assert n_buckets % 2 == 0

        rot_size = n_buckets

        # (1 or batch, feat, rounds, n_buckets // 2)
        rotations_shape = (
            batch_size if self._random_rotations_per_head else 1,
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        # broadcast the rotation(s) over the whole batch
        random_rotations = torch.randn(rotations_shape, dtype=vecs.dtype, device=device).expand(batch_size, -1, -1, -1)

        dropped_vecs = self.dropout_for_hash(vecs)
        # (b, t, f) x (b, f, h, i) -> (b, h, t, i)
        rotated_vecs = torch.einsum('btf,bfhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            # [x, -x] gives n_buckets scores per round; argmax -> (b, n_hashes, t)
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, each item is mapped to the top self.n_hashes buckets
            rotated_vecs = torch.squeeze(rotated_vecs, 1)
            bucket_range = torch.arange(rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            # sort bucket ids by score; the last n_hashes are the highest-scoring
            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[... , -self.n_hashes:].transpose(1, 2)

        # buckets is now (batch, n_hashes, seq). Add per-round offsets so bucket numbers
        # from different hashing rounds don't overlap, then flatten the rounds.
        offsets = torch.arange(self.n_hashes, device=device)
        offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
        buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        return buckets
# 定义全连接的注意力机制类
class FullQKAttention(nn.Module):
    """Dense shared-QK attention.

    Queries are the first ``query_len`` rows of ``qk``. Keys are all rows of ``qk``,
    L2-normalised. A position's logit for itself is set to ``TOKEN_SELF_ATTN_VALUE``,
    which the padding / attention / causal masks may then override.

    ``forward`` returns ``(out, attn, torch.empty(0))``. The empty tensor fills the slot
    where ``LSHSelfAttention`` expects bucket ids.
    """

    def __init__(self, causal = False, dropout = 0.):
        super().__init__()
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

    def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None, **kwargs):
        _, seq_len, dim = qk.shape
        t = default(query_len, seq_len)

        q = qk[:, 0:t]
        keys = F.normalize(qk, 2, dim=-1).type_as(q)
        dot = torch.einsum('bie,bje->bij', q, keys) * (dim ** -0.5)

        # shared-QK: discourage attending to oneself (set before the masks below)
        diag = torch.arange(t)
        dot[:, diag, diag] = TOKEN_SELF_ATTN_VALUE
        masked_value = max_neg_value(dot)

        if input_mask is not None:
            # pairwise padding mask, right-padded with True out to seq_len
            pair_mask = input_mask[:, 0:t, None] * input_mask[:, None, :]
            pair_mask = F.pad(pair_mask, (0, seq_len - pair_mask.shape[-1]), value=True)
            dot.masked_fill_(~pair_mask, masked_value)

        if input_attn_mask is not None:
            attn_mask = F.pad(input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True)
            dot.masked_fill_(~attn_mask, masked_value)

        if self.causal:
            rows, cols = torch.triu_indices(t, t, 1)
            dot[:, rows, cols] = masked_value

        attn = self.dropout(dot.softmax(dim=-1))
        out = torch.einsum('bij,bje->bie', attn, v)
        return out, attn, torch.empty(0)
# 共享的 qk 注意力机制,使用全局或 LSH 注意力机制
class LSHSelfAttention(nn.Module):
    """Multi-head self-attention with a shared query/key projection.

    On each call it uses dense ``FullQKAttention`` when ``use_full_attn`` is set or the
    attended length (input + memory kv + keys) is at most ``full_attn_thres``.
    Otherwise it uses ``LSHAttention``. The first ``n_local_attn_heads`` heads use
    ``LocalAttention`` over the input positions only.

    Notable args:
        dim_head: per-head width (defaults to ``dim // heads``).
        attn_chunks: the merged (batch * heads) dim is split into this many chunks for attention.
        num_mem_kv: number of learned memory vectors appended after the input.
        one_value_head: values are projected to one head's width and repeated per head.
        full_attn_thres: defaults to ``bucket_size``.
        **kwargs: forwarded to ``LSHAttention``.
    """

    def __init__(self, dim, heads = 8, bucket_size = 64, n_hashes = 8, causal = False, dim_head = None, attn_chunks = 1, random_rotations_per_head = False, attend_across_buckets = True, allow_duplicate_attention = True, num_mem_kv = 0, one_value_head = False, use_full_attn = False, full_attn_thres = None, return_attn = False, post_attn_dropout = 0., dropout = 0., n_local_attn_heads = 0, **kwargs):
        super().__init__()
        assert dim_head or (dim % heads) == 0, 'dimensions must be divisible by number of heads'
        assert n_local_attn_heads < heads, 'local attention heads must be less than number of heads'

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads

        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.attn_chunks = default(attn_chunks, 1)

        # with one_value_head, v is projected to a single head's width and repeated in forward
        self.v_head_repeats = (heads if one_value_head else 1)
        v_dim = dim_heads // self.v_head_repeats

        self.toqk = nn.Linear(dim, dim_heads, bias = False)
        self.tov = nn.Linear(dim, v_dim, bias = False)
        self.to_out = nn.Linear(dim_heads, dim)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, n_hashes=n_hashes, causal=causal, random_rotations_per_head=random_rotations_per_head, attend_across_buckets = attend_across_buckets, allow_duplicate_attention = allow_duplicate_attention, return_attn = return_attn, dropout = dropout, **kwargs)
        self.full_attn = FullQKAttention(causal=causal, dropout=dropout)
        self.post_attn_dropout = nn.Dropout(post_attn_dropout)

        self.use_full_attn = use_full_attn
        self.full_attn_thres = default(full_attn_thres, bucket_size)

        self.num_mem_kv = num_mem_kv
        self.mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) if num_mem_kv > 0 else None

        self.n_local_attn_heads = n_local_attn_heads
        self.local_attn = LocalAttention(window_size=bucket_size * 2, causal=causal, dropout=dropout, shared_qk=True, look_forward=(1 if not causal else 0))

        # optional hook called as callback(attn, buckets); Recorder sets it
        self.callback = None

    def forward(self, x, keys = None, input_mask = None, input_attn_mask = None, context_mask = None, pos_emb = None, **kwargs):
        """Attend over ``[x, memory kv, keys]`` and return outputs for the positions of ``x``.

        Args:
            x: rank-3 tensor unpacked as ``(b, t, e)``.
            keys: optional ``(b, c, e)`` context appended after the memory.
            input_mask: boolean mask over the ``t`` input positions.
            input_attn_mask: attention mask, repeated for each LSH head.
            context_mask: boolean mask over the ``c`` key positions.
            pos_emb: forwarded to the chosen attention function.
            **kwargs: forwarded to the chosen attention function.

        Returns:
            ``(b, t, dim)`` after ``to_out`` and post-attention dropout.
        """
        device, dtype = x.device, x.dtype
        b, t, e, h, dh, m, l_h = *x.shape, self.heads, self.dim_head, self.num_mem_kv, self.n_local_attn_heads

        # learned memory (or an empty placeholder) broadcast over the batch
        mem_kv = default(self.mem_kv, torch.empty(b, 0, e, dtype=dtype, device=device))
        mem = mem_kv.expand(b, m, -1)

        keys = default(keys, torch.empty(b, 0, e, dtype=dtype, device=device))
        c = keys.shape[1]

        kv_len = t + m + c
        use_full_attn = self.use_full_attn or kv_len <= self.full_attn_thres

        x = torch.cat((x, mem, keys), dim=1)
        qk = self.toqk(x)
        v = self.tov(x)
        v = v.repeat(1, 1, self.v_head_repeats)

        # (b, kv_len, h * d) -> (b, h, kv_len, d)
        def merge_heads(v):
            return v.view(b, kv_len, h, -1).transpose(1, 2)

        # (..., t, d) with b * h leading elements -> (b, t, h, d)
        def split_heads(v):
            return v.view(b, h, t, -1).transpose(1, 2).contiguous()

        merge_batch_and_heads = partial(merge_dims, 0, 1)

        qk, v = map(merge_heads, (qk, v))

        has_local = l_h > 0
        lsh_h = h - l_h

        # the first l_h heads go to local attention, the remaining lsh_h to LSH/full attention
        split_index_fn = partial(split_at_index, 1, l_h)
        (lqk, qk), (lv, v) = map(split_index_fn, (qk, v))
        lqk, qk, lv, v = map(merge_batch_and_heads, (lqk, qk, lv, v))

        masks = {}
        if input_mask is not None or context_mask is not None:
            # (b, kv_len) mask: memory positions are always attendable; a missing
            # input/context mask defaults to all True
            default_mask = torch.tensor([True], device=device)
            i_mask = default(input_mask, default_mask.expand(b, t))
            m_mask = default_mask.expand(b, m)
            c_mask = default(context_mask, default_mask.expand(b, c))
            mask = torch.cat((i_mask, m_mask, c_mask), dim=1)
            mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask))
            masks['input_mask'] = mask

        if input_attn_mask is not None:
            input_attn_mask = merge_batch_and_heads(expand_dim(1, lsh_h, input_attn_mask))
            masks['input_attn_mask'] = input_attn_mask

        attn_fn = self.lsh_attn if not use_full_attn else self.full_attn
        partial_attn_fn = partial(attn_fn, query_len = t, pos_emb = pos_emb, **kwargs)
        # split the merged (batch * heads) dim into attn_chunks pieces
        attn_fn_in_chunks = process_inputs_chunk(partial_attn_fn, chunks = self.attn_chunks)

        out, attn, buckets = attn_fn_in_chunks(qk, v, **masks)

        if self.callback is not None:
            self.callback(attn.reshape(b, lsh_h, t, -1), buckets.reshape(b, lsh_h, -1))

        if has_local:
            # local attention only sees the t input positions (memory / keys dropped)
            lqk, lv = lqk[:, :t], lv[:, :t]
            local_out = self.local_attn(lqk, lqk, lv, input_mask=input_mask)
            local_out = local_out.reshape(b, l_h, t, -1)
            out = out.reshape(b, lsh_h, t, -1)
            out = torch.cat((local_out, out), dim=1)

        out = split_heads(out).view(b, t, -1)
        out = self.to_out(out)
        return self.post_attn_dropout(out)
# 定义 GELU 激活函数类,继承自 nn.Module
class GELU_(nn.Module):
    """Tanh approximation of GELU; the fallback when ``nn`` provides no ``GELU``."""

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


# Prefer torch's built-in GELU when it exists; otherwise use the approximation above.
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
# 定义前馈神经网络类 FeedForward,继承自 nn.Module
class FeedForward(nn.Module):
    """Position-wise MLP: ``w2(dropout(act(w1(x))))``, hidden width ``dim * mult``.

    With ``glu=True``, ``w1`` is twice as wide and its output is split into halves
    ``(gate_in, value)``; the hidden activation is ``act(gate_in) * value``.
    ``activation`` is a class that is instantiated once (defaults to ``GELU``).
    """

    def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False):
        super().__init__()
        activation = default(activation, GELU)
        self.glu = glu
        hidden = dim * mult
        # creation order (w1, act, dropout, w2) kept so parameter init is unchanged
        self.w1 = nn.Linear(dim, hidden * 2 if glu else hidden)
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x, **kwargs):
        projected = self.w1(x)
        if self.glu:
            gate_in, value = projected.chunk(2, dim=-1)
            hidden = self.act(gate_in) * value
        else:
            hidden = self.act(projected)
        return self.w2(self.dropout(hidden))
# 绝对位置嵌入类,继承自 nn.Module
class AbsolutePositionalEmbedding(nn.Module):
    """Learned position table; returns a ``(x.shape[1], dim)`` tensor of embeddings."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        return self.emb(positions)
# 固定位置嵌入类,继承自 nn.Module
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal position embedding.

    For ``n = x.shape[seq_dim]`` positions it returns a ``(1, n, 2 * ceil(dim / 2))``
    tensor. The first half holds ``sin(pos * inv_freq)`` and the second half
    ``cos(pos * inv_freq)``. The result is cast to ``x``'s dtype.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x, seq_dim=1):
        positions = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table.unsqueeze(0).type_as(x)
# 旋转位置嵌入辅助函数,用于旋转每两个元素
def rotate_every_two(x):
    """Rotate consecutive feature pairs by 90 degrees: ``(a, b) -> (-b, a)``.

    The last dim is read as pairs ``(x[2i], x[2i+1])`` and must have even size.
    """
    pairs = rearrange(x, '... (d j) -> ... d j', j=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, '... d j -> ... (d j)')
# 应用旋转位置嵌入函数,接受查询键 qk 和正弦位置 sinu_pos
def apply_rotary_pos_emb(qk, sinu_pos):
    """Apply rotary position embedding to the first ``n`` positions (dim 1) of ``qk``.

    ``sinu_pos`` is ``(1, n, 2 * d)`` laid out as ``[sin | cos]`` (the layout produced
    by ``FixedPositionalEmbedding``). Each sin/cos value is repeated twice to line up
    with the feature pairs rotated by ``rotate_every_two``. Positions of ``qk`` beyond
    ``n`` pass through unchanged.
    """
    table = rearrange(sinu_pos.type(qk.dtype), '() n (j d) -> n j d', j=2)
    sin, cos = table.unbind(dim=-2)
    sin = repeat(sin, 'n d -> n (d j)', j=2)
    cos = repeat(cos, 'n d -> n (d j)', j=2)

    n = sin.shape[0]
    head, tail = qk[:, :n], qk[:, n:]
    rotated = (head * cos) + (rotate_every_two(head) * sin)
    return torch.cat((rotated, tail), dim=1)
# Reformer 语言模型类,继承自 nn.Module
class Reformer(nn.Module):
    """Stack of ``depth`` reversible blocks, each a pair ``(f, g)``.

    ``f`` wraps an ``LSHSelfAttention``. ``g`` wraps a chunked ``FeedForward``, or a
    ``PKM`` memory on layers listed in ``pkm_layers`` (1-indexed). Both are wrapped in
    ``ReZero`` when ``use_rezero``, otherwise in ``PreNorm`` with ``ScaleNorm`` /
    ``nn.LayerNorm``. With ``weight_tie``, the attention / feed-forward / PKM factories
    are memoised, so all layers share one instance of each.

    ``forward`` duplicates the input along the last dim into two streams, runs the
    ``ReversibleSequence``, and returns the mean of the two output halves.
    """

    def __init__(self, dim, depth, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        self.dim = dim
        self.depth = depth

        # bucket_size, num_mem_kv and full_attn_thres are read by Autopadder
        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv

        self.full_attn_thres = full_attn_thres

        # factories for one layer's sub-modules
        get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dim_head = dim_head, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads)
        get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult, glu = ff_glu), along_dim = -2)
        get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

        # weight tying: memoise the factories so every layer reuses the first instance
        if weight_tie:
            get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

        blocks = []

        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

        for ind in range(depth):
            # pkm_layers uses 1-based layer numbers
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)
            parallel_net = None

            attn = get_attn()

            if use_pkm:
                parallel_net = get_pkm()
            else:
                parallel_net = get_ff()

            f = residual_fn_wrapper(attn)
            g = residual_fn_wrapper(parallel_net)

            blocks.append(nn.ModuleList([f, g]))

        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout, reverse_thres = reverse_thres, send_signal = True)

    def forward(self, x, **kwargs):
        # two identical streams for the reversible layers (last dim doubled)
        x = torch.cat([x, x], dim = -1)
        x = self.layers(x, **kwargs)
        # average the two output streams back to the original width
        return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
class ReformerLM(nn.Module):
    """Reformer language model.

    Pipeline: token embedding (+ optional additive position embedding) -> optional
    projection emb_dim -> dim -> Reformer stack -> LayerNorm -> output head.

    If none of `axial_position_emb`, `absolute_position_emb`, `fixed_position_emb` is set,
    a FixedPositionalEmbedding of size `dim_head` is passed to the Reformer layers as
    `pos_emb` instead of being added to the token embeddings.

    With `return_embeddings=True` the output head is the identity, so the model returns the
    normalized hidden states of width `dim`; otherwise it returns logits over `num_tokens`.
    With `weight_tie_embedding=True` the final projection reuses the token embedding weight.
    """
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        # embedding width defaults to the model width
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # project embeddings to the model width only when the two widths differ
        self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

        # defaults; Always(...) is defined elsewhere (presumably returns its constant) — TODO confirm
        self.pos_emb = Always(0)
        self.layer_pos_emb = Always(None)

        if axial_position_emb:
            # default axial shape factors the sequence into (num_buckets, bucket_size)
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
        elif absolute_position_emb:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
        elif fixed_position_emb:
            self.pos_emb = FixedPositionalEmbedding(emb_dim)
        else:
            # no additive embedding: a per-head-width embedding is handed to the layers instead
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)

        # post_attn_dropout is forwarded (it used to be hard-coded to 0., silently ignoring the argument)
        self.reformer = Reformer(dim, depth, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = post_attn_dropout, layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
        self.norm = nn.LayerNorm(dim)

        if return_embeddings:
            # embedding mode: no output projection
            self.out = Identity()
            return

        # output head: back to emb_dim if needed, then to vocabulary logits
        # (tied to the token embedding matrix when weight_tie_embedding is set)
        self.out = nn.Sequential(
            nn.Linear(dim, emb_dim) if emb_dim != dim else Identity(),
            nn.Linear(emb_dim, num_tokens) if not weight_tie_embedding else MatrixMultiply(self.token_emb.weight, transpose=True, normalize=True)
        )

    def forward(self, x, **kwargs):
        """Embed token ids `x`, run the Reformer stack, and apply the output head.

        Extra keyword arguments are forwarded to the Reformer stack.
        """
        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        # per-layer position embedding (used when no additive embedding was configured)
        layer_pos_emb = self.layer_pos_emb(x)
        x = self.to_model_dim(x)
        x = self.reformer(x, pos_emb = layer_pos_emb, **kwargs)
        x = self.norm(x)
        return self.out(x)
.\lucidrains\reformer-pytorch\reformer_pytorch\reversible.py
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
# Module wrapper that can record the RNG state of one forward call and replay it later.
# Reference: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap `net` so the randomness of a forward call can be reproduced.

    A call with `record_rng=True` snapshots the CPU RNG state and, when CUDA has been
    initialized, the CUDA RNG states for the devices returned by `get_device_states(*args)`.
    A later call with `set_rng=True` restores that snapshot inside `torch.random.fork_rng`,
    so stochastic layers (e.g. dropout) replay the same randomness while the surrounding
    RNG stream is left untouched.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the CPU RNG state and, if CUDA is initialized, the per-device CUDA states."""
        self.cpu_state = torch.get_rng_state()
        # public API instead of the private `torch.cuda._initialized` flag
        if torch.cuda.is_initialized():
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        """Run `net(*args, **kwargs)`, optionally recording or replaying the RNG state.

        Raises:
            RuntimeError: if `set_rng=True` is requested before any state was recorded.
        """
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        if self.cpu_state is None:
            raise RuntimeError('Deterministic: set_rng=True requires a prior call with record_rng=True')

        rng_devices = self.gpu_devices if self.cuda_in_fwd else []
        # fork_rng restores the caller's RNG state when the block exits
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
# Reversible residual block (RevNet-style coupling).
# Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
class ReversibleBlock(nn.Module):
    """Coupling block y1 = x1 + f(x2), y2 = x2 + g(y1), split along dim 2.

    `backward_pass` reconstructs (x1, x2) from (y1, y2) instead of storing activations,
    recomputing f and g with the RNG state recorded during `forward`.

    Args:
        f, g: sub-networks, each wrapped in Deterministic for RNG replay.
        depth: layer index passed to f/g as `_depth` when `send_signal` is set.
        send_signal: if True, f/g also receive `_reverse` (direction) and `_depth`.
    """
    def __init__(self, f, g, depth=None, send_signal = False):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.depth = depth
        self.send_signal = send_signal

    def _prepare_args(self, f_args, g_args, reverse):
        """Return private copies of the f/g kwargs, tagged with direction/depth if send_signal is set.

        Copying instead of writing into the caller's dicts (or a shared mutable default)
        keeps the `_reverse`/`_depth` keys from leaking into unrelated later calls.
        """
        f_args = {} if f_args is None else dict(f_args)
        g_args = {} if g_args is None else dict(g_args)
        if self.send_signal:
            f_args['_reverse'] = g_args['_reverse'] = reverse
            f_args['_depth'] = g_args['_depth'] = self.depth
        return f_args, g_args

    def forward(self, x, f_args = None, g_args = None):
        """Apply the coupling without building an autograd graph (backward_pass recomputes it)."""
        f_args, g_args = self._prepare_args(f_args, g_args, reverse = False)

        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            # record RNG state while training so backward_pass can replay it
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = None, g_args = None):
        """Reconstruct this block's input from its output and back-propagate through it.

        Args:
            y: block output (y1 and y2 concatenated on dim 2).
            dy: gradient w.r.t. `y`.

        Returns:
            (x, dx): the reconstructed input and the gradient w.r.t. it.
        """
        f_args, g_args = self._prepare_args(f_args, g_args, reverse = True)

        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        # recompute g(y1) with the recorded RNG state and push dy2 through it
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert y2 = x2 + g(y1)
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # recompute f(x2) and push dx1 through it
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # invert y1 = x1 + f(x2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx
# Ordinary (activation-storing) counterpart of ReversibleBlock
class IrreversibleBlock(nn.Module):
    """Same coupling as ReversibleBlock (y1 = x1 + f(x2), y2 = x2 + g(y1)) with normal autograd."""
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x, f_args, g_args):
        left, right = x.chunk(2, dim=2)
        out_left = left + self.f(right, **f_args)
        out_right = right + self.g(out_left, **g_args)
        return torch.cat((out_left, out_right), dim=2)
# Custom autograd Function that runs a stack of reversible blocks without keeping
# their intermediate activations.
class _ReversibleFunction(Function):
    """Autograd function over a list of ReversibleBlock-like modules.

    forward: applies every block in order and keeps only the final output (detached) on ctx.
    backward: walks the blocks in reverse; each block's `backward_pass` reconstructs its
    input from its output and returns the gradient w.r.t. that input.
    """
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # kwargs is the {'f_args': ..., 'g_args': ...} mapping every block is called with
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        # only the final output is stored; inputs are reconstructed in backward
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x
    @staticmethod
    # Backward pass: receives the context and the gradient w.r.t. the output
    def backward(ctx, dy):
        # output of the last block, saved in forward
        y = ctx.y
        # same f/g kwargs as used in forward
        kwargs = ctx.kwargs
        # walk the blocks from last to first
        for block in ctx.blocks[::-1]:
            # each block returns its reconstructed input and the gradient w.r.t. it
            y, dy = block.backward_pass(y, dy, **kwargs)
        # gradients for (x, blocks, kwargs); only x receives one
        return dy, None, None
# Sequence of (f, g) blocks, run reversibly for long inputs and as plain residual blocks otherwise
class ReversibleSequence(nn.Module):
    """Run a stack of (f, g) pairs either reversibly or as ordinary residual blocks.

    Args:
        blocks: iterable of (f, g) module pairs.
        layer_dropout: probability of skipping each block during training.
        reverse_thres: the reversible path is used only when x.shape[1] exceeds this value;
            otherwise IrreversibleBlocks are run with normal autograd.
        send_signal: forwarded to every ReversibleBlock.
    """
    def __init__(self, blocks, layer_dropout = 0., reverse_thres = 0, send_signal = False):
        super().__init__()
        self.layer_dropout = layer_dropout
        self.reverse_thres = reverse_thres
        # materialize once so a one-shot iterable is not exhausted by the first comprehension
        blocks = list(blocks)
        # both variants wrap the same f/g modules, so they share parameters
        self.blocks = nn.ModuleList([ReversibleBlock(f, g, depth, send_signal) for depth, (f, g) in enumerate(blocks)])
        self.irrev_blocks = nn.ModuleList([IrreversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, arg_route = (True, False), **kwargs):
        """Apply the block stack to `x`.

        Args:
            arg_route: (to_f, to_g) flags saying whether **kwargs are passed to f and/or g.
        """
        reverse = x.shape[1] > self.reverse_thres
        blocks = self.blocks if reverse else self.irrev_blocks

        if self.training and self.layer_dropout > 0:
            # drop from the list that will actually run: sampling from self.blocks would make
            # the short-sequence path run ReversibleBlocks directly, which compute under no_grad
            to_drop = torch.empty(len(blocks)).uniform_(0, 1) < self.layer_dropout
            kept = [block for block, drop in zip(blocks, to_drop) if not drop]
            # always keep at least one block
            blocks = kept if len(kept) > 0 else blocks[:1]

        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
        block_kwargs = {'f_args': f_args, 'g_args': g_args}

        if not reverse:
            for block in blocks:
                x = block(x, **block_kwargs)
            return x

        return _ReversibleFunction.apply(x, blocks, block_kwargs)
.\lucidrains\reformer-pytorch\reformer_pytorch\__init__.py
# 从 reformer_pytorch 模块中导入 LSHAttention, LSHSelfAttention, Reformer, ReformerLM 类
from reformer_pytorch.reformer_pytorch import LSHAttention, LSHSelfAttention, Reformer, ReformerLM
# 从 reformer_pytorch 模块中导入 ReformerEncDec 类
from reformer_pytorch.reformer_enc_dec import ReformerEncDec
# 从 reformer_pytorch 模块中导入 Recorder 类
from reformer_pytorch.recorder import Recorder
# 从 reformer_pytorch 模块中导入 Autopadder 类
from reformer_pytorch.autopadder import Autopadder
.\lucidrains\reformer-pytorch\setup.py
# 导入设置安装包和查找包的模块
from setuptools import setup, find_packages
# Distribution metadata for reformer_pytorch; list-valued fields are named up front.
_install_requires = [
    'axial-positional-embedding>=0.1.0',
    'einops',
    'local-attention',
    'product-key-memory',
    'torch'
]

_classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'reformer_pytorch',
    # ship the library only, not the example / pretraining folders
    packages = find_packages(exclude=['examples', 'pretraining']),
    version = '1.4.4',
    license = 'MIT',
    description = 'Reformer, the Efficient Transformer, Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/reformer-pytorch',
    keywords = ['transformers', 'attention', 'artificial intelligence'],
    install_requires = _install_requires,
    classifiers = _classifiers,
)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

ReLA (Rectified Linear Attention) Transformer
Implementation of a Transformer using ReLA (Rectified Linear Attention). It will also contain an attempt to fold the feedforward into the ReLA layer as memory key / values, as proposed in the All-Attention paper (a suggestion made by Charles Foster).
Install
$ pip install rela-transformer
Usage
import torch
from rela_transformer import ReLATransformer
model = ReLATransformer(
num_tokens = 20000,
dim = 512,
depth = 8,
max_seq_len = 1024,
dim_head = 64,
heads = 8
)
x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()
logits = model(x, mask = mask) # (1, 1024, 20000)
Enwik8
$ python train.py
Citations
@misc{zhang2021sparse,
title = {Sparse Attention with Linear Units},
author = {Biao Zhang and Ivan Titov and Rico Sennrich},
year = {2021},
eprint = {2104.07012},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
.\lucidrains\rela-transformer\rela_transformer\autoregressive_wrapper.py
# 导入必要的库
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# Helper: does a value exist (i.e. is it not None)?
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
# Helper: fall back to a default when a value is missing
def default(value, default):
    """Return `value`, or `default` if `value` is None."""
    if value is None:
        return default
    return value
# Helper: natural log with a small additive epsilon
def log(t, eps=1e-9):
    """Return log(t + eps); the epsilon keeps log(0) finite."""
    shifted = t + eps
    return shifted.log()
# Top-k logit filter
def top_k(logits, thres = 0.9):
    """Keep the largest (1 - thres) fraction of logits along the last dim; set the rest to -inf.

    At least one logit is always kept. Without that guard, a small vocabulary or a `thres`
    close to 1 gives k == 0 and an all -inf row, so sampling degenerates to index 0.

    Args:
        logits: tensor of logits; filtering is applied along the last dimension.
        thres: fraction of logits to discard.

    Returns:
        Tensor shaped like `logits` with non-top-k entries set to -inf.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the last dim so inputs of any rank are supported
    probs.scatter_(-1, ind, val)
    return probs
# Autoregressive wrapper: next-token training loss and sampling
class AutoregressiveWrapper(nn.Module):
    """Wrap a token-level model for next-token training and sampling.

    `net` must expose `max_seq_len`; its output is indexed below as
    (batch, seq, num_tokens) logits.

    Args:
        net: the wrapped model.
        ignore_index: target id excluded from the loss; defaults to `pad_value`.
        pad_value: padding token id.
    """
    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to `seq_len` new tokens after `start_tokens`.

        Args:
            start_tokens: prime of shape (batch, n) or (n,).
            seq_len: maximum number of tokens to generate.
            eos_token: stop early once every sequence has just produced this id.
            temperature: divides the filtered logits before sampling.
            filter_logits_fn: logit filter called as `filter_logits_fn(logits, thres=filter_thres)`.
            **kwargs: forwarded to `net`.

        Returns:
            Only the generated tokens, batched like `start_tokens` ((batch, m) or (m,)).
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        # accept an unbatched prime by adding a batch axis (removed again below)
        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        t = start_tokens.shape[1]
        self.net.eval()
        out = start_tokens

        try:
            for _ in range(seq_len):
                # condition on at most the last max_seq_len tokens
                x = out[:, -self.max_seq_len:]
                logits = self.net(x, **kwargs)[:, -1, :]
                filtered_logits = filter_logits_fn(logits, thres = filter_thres)
                # Gumbel-max trick: argmax(logits / T + Gumbel noise) samples from softmax(logits / T)
                gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
                sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)
                out = torch.cat((out, sample[:, None]), dim=-1)
                if eos_token is not None and (sample == eos_token).all():
                    break
        finally:
            # restore the caller's train/eval mode even if sampling raises
            self.net.train(was_training)

        out = out[:, t:]
        if num_dims == 1:
            out = out.squeeze(0)
        return out

    def forward(self, x, *args, **kwargs):
        """Return the next-token cross-entropy of `net` (inputs x[:, :-1], targets x[:, 1:])."""
        inp, labels = x[:, :-1], x[:, 1:]
        out = self.net(inp, *args, **kwargs)
        # cross_entropy expects the class dimension second
        loss = F.cross_entropy(out.transpose(1, 2), labels, ignore_index = self.ignore_index)
        return loss
.\lucidrains\rela-transformer\rela_transformer\rela_transformer.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat 函数
from einops import rearrange, repeat
# Helper: does a value exist (i.e. is it not None)?
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
# Gated RMS normalization
class GatedRMSNorm(nn.Module):
    """RMS-normalize the last dimension, rescale by `g`, and gate elementwise with sigmoid(x * w).

    Args:
        dim: size of the last dimension.
        eps: lower clamp on the RMS value, avoiding division by zero.
    """
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.w = nn.Parameter(torch.ones(dim))
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # ||x||_2 * dim^-0.5 is the root-mean-square over the last dimension
        rms = x.norm(dim = -1, keepdim = True) * self.scale
        gate = (x * self.w).sigmoid()
        normed = x / rms.clamp(min = self.eps) * self.g
        return normed * gate
# Pre-normalized two-layer MLP: LayerNorm -> Linear -> GELU -> Linear
def FeedForward(dim, mult = 4):
    """Return an nn.Sequential MLP that expands `dim` by `mult` and projects back to `dim`."""
    hidden = dim * mult
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)
# ReLA: rectified linear attention
class ReLA(nn.Module):
    """Multi-head attention with ReLU (optionally squared) scores instead of softmax.

    The input is normalized with GatedRMSNorm; per-head outputs are normalized with a second
    GatedRMSNorm before the output projection. `num_memory_kv` learned key/value slots are
    prepended to the keys/values of every sequence.
    """
    def __init__(
        self,
        *,
        dim,
        causal = True,
        dim_head = 64,
        heads = 8,
        num_memory_kv = 0,
        relu_squared = False
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        # optionally square the ReLU scores
        self.relu_squared = relu_squared

        self.norm = GatedRMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # learned memory key/values shared across the batch
        self.mem_k = nn.Parameter(torch.randn(num_memory_kv, inner_dim))
        self.mem_v = nn.Parameter(torch.randn(num_memory_kv, inner_dim))

        self.norm_values = GatedRMSNorm(dim_head)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x, mask = None):
        """Attend over `x`.

        Args:
            x: (batch, seq, dim) input.
            mask: optional boolean key mask of shape (batch, seq); False positions are ignored.
        """
        b, device = x.shape[0], x.device
        x = self.norm(x)
        h = self.heads
        num_mem = self.mem_k.shape[0]

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # prepend the memory slots to keys and values
        mem_k, mem_v = map(lambda t: repeat(t, 'n d -> b n d', b = b), (self.mem_k, self.mem_v))
        k = torch.cat((mem_k, k), dim = 1)
        v = torch.cat((mem_v, v), dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = F.relu(sim)
        if self.relu_squared:
            attn = attn ** 2

        if exists(mask):
            # keys now begin with num_mem memory slots that are always attendable, so extend
            # the (batch, seq) mask on the left; otherwise it cannot broadcast against attn
            mask = F.pad(mask, (num_mem, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            i, j = attn.shape[-2:]
            # offset j - i + 1 keeps the leading memory columns visible to every query
            causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            attn = attn.masked_fill(causal_mask, 0.)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = self.norm_values(out)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Full ReLA transformer language model
class ReLATransformer(nn.Module):
    """Token + learned absolute position embeddings, `depth` residual (ReLA, FeedForward)
    layers, then LayerNorm and a projection to `num_tokens` logits.

    Set `no_ff=True` to drop the feedforward sublayers.
    """
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        num_memory_kv = 0,
        no_ff = False,
        ff_mult = 4,
        relu_squared = False
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        def build_layer():
            # one (attention, feedforward-or-None) pair, attention constructed first
            attention = ReLA(dim = dim, relu_squared = relu_squared, heads = heads, dim_head = dim_head, num_memory_kv = num_memory_kv, causal = causal)
            feedforward = None if no_ff else FeedForward(dim = dim, mult = ff_mult)
            return nn.ModuleList([attention, feedforward])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, mask = None):
        """Map token ids `x` (batch, seq) with seq <= max_seq_len to logits; `mask` goes to every ReLA layer."""
        positions = torch.arange(x.shape[1], device = x.device)
        hidden = self.token_emb(x) + self.pos_emb(positions).unsqueeze(0)

        for attention, feedforward in self.layers:
            hidden = attention(hidden, mask = mask) + hidden
            if feedforward is not None:
                hidden = feedforward(hidden) + hidden

        return self.to_logits(hidden)
.\lucidrains\rela-transformer\rela_transformer\__init__.py
# 从 rela_transformer.rela_transformer 模块中导入 ReLATransformer 类
from rela_transformer.rela_transformer import ReLATransformer
.\lucidrains\rela-transformer\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Distribution metadata for rela-transformer; list-valued fields are named up front.
_keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
]

_install_requires = [
    'einops>=0.3',
    'torch>=1.6'
]

_classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'rela-transformer',
    packages = find_packages(exclude=[]),
    version = '0.0.7',
    license = 'MIT',
    description = 'ReLA Transformer',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/rela-transformer',
    keywords = _keywords,
    install_requires = _install_requires,
    classifiers = _classifiers,
)
.\lucidrains\rela-transformer\train.py
# 导入所需的模块
from rela_transformer import ReLATransformer
from rela_transformer.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Training hyper-parameters
NUM_BATCHES = 100_000            # optimizer steps
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4    # micro-batches per optimizer step
LEARNING_RATE = 3e-4
VALIDATE_EVERY = 100             # steps between validation losses
GENERATE_EVERY = 500             # steps between sample generations
GENERATE_LENGTH = 512            # tokens generated per sample
SEQ_LEN = 512                    # training context length
# Helpers
def decode_token(token):
    """Map a token id to a character; ids of 32 or below are rendered as a space."""
    codepoint = token if token > 32 else 32
    return chr(codepoint)
# Decode a sequence of token ids into a string
def decode_tokens(tokens):
    """Concatenate `decode_token` over every element of `tokens`."""
    return ''.join(decode_token(token) for token in tokens)
# Instantiate the model: a causal ReLA transformer over raw bytes (256-token vocabulary)
model = ReLATransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    causal = True
)

# wrap for next-token training and sampling, then move to the GPU
model = AutoregressiveWrapper(model)
model.cuda()

# Prepare enwik8: first 95M bytes, split into 90M train / 5M validation
with gzip.open('./data/enwik8.gz') as file:
    # binary-mode np.fromstring is deprecated; frombuffer + copy yields a writable array,
    # which torch.from_numpy expects
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# Dataset of random fixed-length windows over a 1-D token tensor
class TextSamplerDataset(Dataset):
    """Each access returns a random contiguous window of `seq_len + 1` tokens (as long).

    The index passed to __getitem__ is ignored; every access draws a fresh random start.

    Args:
        data: 1-D tensor of token ids.
        seq_len: window length minus one (the extra token is the final target).
        device: where returned windows are placed (default 'cuda', as before).
    """
    def __init__(self, data, seq_len, device = 'cuda'):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # randint's upper bound is exclusive; the last valid start is len - seq_len - 1
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.to(self.device)

    def __len__(self):
        return self.data.size(0) // self.seq_len
# Infinite iterator over a DataLoader (it was used below but never defined)
def cycle(loader):
    """Yield batches from `loader` forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data

# DataLoaders for the training and validation splits
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer (note: rebinds `optim`, shadowing the `torch.optim` module alias imported above)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several micro-batches before one optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:SEQ_LEN]
        prime = decode_tokens(inp)
        # print the prime followed by a separator line
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, :], GENERATE_LENGTH)
        output_str = decode_tokens(sample.squeeze(0))
        print(output_str)

Remixer - Pytorch
Implementation of the Remixer Block from the Remixer paper, in Pytorch. It claims that replacing the feedforwards in transformers with sequence-wide mixing, followed by multiplication and subtraction, leads to better natural language understanding results.
Install
$ pip install remixer-pytorch
Usage
import torch
from remixer_pytorch import RemixerBlock
block = RemixerBlock(
dim = 512,
seq_len = 1024
)
x = torch.randn(1, 1024, 512)
block(x) # (1, 1024, 512)
Citations
@inproceedings{anonymous,
title = {Remixers: A Mixer-Transformer Architecture with Compositional Operators for Natural Language Understanding },
author = {Anonymous},
year = {2021},
url = {https://openreview.net/forum?id=9FHQHJnRtfL}
}
.\lucidrains\remixer-pytorch\remixer_pytorch\remixer_pytorch.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange 函数
from einops import rearrange
# Remixer block
class RemixerBlock(nn.Module):
    """Gated input projection, learned sequence-wide mixing, then a learned blend of the
    multiplicative (x * mixed) and subtractive (x - mixed) compositions.

    Args:
        dim: feature dimension of the input (batch, seq, dim).
        seq_len: maximum sequence length; the learned mixer is (seq_len, seq_len).
        causal: if True, position m only mixes positions n <= m.
        bias: whether the input/output projections use a bias.
    """
    def __init__(
        self,
        dim,
        seq_len,
        causal = False,
        bias = False
    ):
        super().__init__()
        self.causal = causal
        self.proj_in = nn.Linear(dim, 2 * dim, bias = bias)
        self.mixer = nn.Parameter(torch.randn(seq_len, seq_len))
        # blend weight between the two compositions, squashed by sigmoid in forward
        self.alpha = nn.Parameter(torch.tensor(0.))
        self.proj_out = nn.Linear(dim, dim, bias = bias)

    def forward(self, x):
        """Apply the block to `x` of shape (batch, seq, dim) with seq <= seq_len.

        Raises:
            ValueError: if the sequence is longer than the mixer.
        """
        device = x.device

        x, gate = self.proj_in(x).chunk(2, dim = -1)
        x = F.gelu(gate) * x

        seq = x.shape[1]
        max_len = self.mixer.shape[0]
        if seq > max_len:
            raise ValueError(f'sequence length {seq} exceeds the maximum {max_len} this block was built for')

        # slice the mixer to the actual length in both modes (previously only when causal,
        # so non-causal inputs shorter than seq_len failed in einsum)
        mixer = self.mixer[:seq, :seq]

        if self.causal:
            # large negative value rather than -inf keeps softmax finite
            mask_value = -torch.finfo(x.dtype).max
            mask = torch.ones((seq, seq), device = device, dtype = torch.bool).triu(1)
            mixer = mixer.masked_fill(mask, mask_value)

        mixer = mixer.softmax(dim = -1)
        # mixed[m] = sum_n mixer[m, n] * x[n]
        mixed = einsum('b n d, m n -> b m d', x, mixer)

        alpha = self.alpha.sigmoid()
        out = (x * mixed) * alpha + (x - mixed) * (1 - alpha)
        return self.proj_out(out)
.\lucidrains\remixer-pytorch\remixer_pytorch\__init__.py
# 从 remixer_pytorch.remixer_pytorch 模块中导入 RemixerBlock 类
from remixer_pytorch.remixer_pytorch import RemixerBlock
.\lucidrains\remixer-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Distribution metadata for remixer-pytorch; list-valued fields are named up front.
_keywords = [
    'artificial intelligence',
    'transformer',
    'feedforward',
    'mlp-mixer'
]

_install_requires = [
    'einops>=0.3',
    'torch>=1.6'
]

_classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'remixer-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.3',
    license = 'MIT',
    description = 'Remixer - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/remixer-pytorch',
    keywords = _keywords,
    install_requires = _install_requires,
    classifiers = _classifiers,
)

ResMLP - Pytorch
Implementation of ResMLP, an all-MLP solution to image classification from Facebook AI, in PyTorch.
Install
$ pip install res-mlp-pytorch
Usage
import torch
from res_mlp_pytorch import ResMLP
model = ResMLP(
image_size = 256,
patch_size = 16,
dim = 512,
depth = 12,
num_classes = 1000
)
img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
Rectangular image
import torch
from res_mlp_pytorch import ResMLP
model = ResMLP(
image_size = (128, 256), # (128 x 256)
patch_size = 16,
dim = 512,
depth = 12,
num_classes = 1000
)
img = torch.randn(1, 3, 128, 256)
pred = model(img) # (1, 1000)
Citations
@misc{touvron2021resmlp,
title = {ResMLP: Feedforward networks for image classification with data-efficient training},
author = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
year = {2021},
eprint = {2105.03404},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\res-mlp-pytorch\res_mlp_pytorch\res_mlp_pytorch.py
import torch
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce
# 导入必要的库
# Helper: normalize a scalar-or-tuple argument to a tuple
def pair(val):
    """Return `val` unchanged if it is a tuple, otherwise the 2-tuple (val, val)."""
    if isinstance(val, tuple):
        return val
    return (val, val)
# Learnable per-channel affine transform
class Affine(nn.Module):
    """Compute x * g + b over the last dimension; g starts at ones and b at zeros (identity)."""
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, 1, dim))
        self.b = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        scaled = x * self.g
        return scaled + self.b
# Residual wrapper: pre-affine, post layer-scale
class PreAffinePostLayerScale(nn.Module): # https://arxiv.org/abs/2103.17239
    """Compute x + fn(affine(x)) * scale with a learnable per-channel `scale`.

    The initial scale value depends on `depth`: 0.1 for depth <= 18, 1e-5 for
    18 < depth <= 24, and 1e-6 beyond.
    """
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.affine = Affine(dim)
        self.fn = fn

    def forward(self, x):
        branch = self.fn(self.affine(x)) * self.scale
        return branch + x
# ResMLP: all-MLP image classifier built as a single nn.Sequential
def ResMLP(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, channels = 3):
    """Build a ResMLP classifier.

    Images of shape (batch, channels, H, W) are cut into non-overlapping
    patch_size x patch_size patches and linearly embedded to `dim`. They then pass through
    `depth` residual blocks: a cross-patch 1x1 Conv1d followed by a per-patch MLP, each
    wrapped in PreAffinePostLayerScale. Finally an Affine is applied, the result is
    mean-pooled over patches and projected to `num_classes` logits.

    Args:
        image_size: int or (height, width); both must be divisible by patch_size.
        patch_size: side length of the square patches.
        dim: embedding width.
        depth: number of residual blocks.
        num_classes: number of output logits.
        expansion_factor: hidden-width multiplier of the per-patch MLP.
        channels: number of input image channels (default 3, the previously hard-coded value).
    """
    image_height, image_width = pair(image_size)
    assert (image_height % patch_size) == 0 and (image_width % patch_size) == 0, 'image height and width must be divisible by patch size'
    num_patches = (image_height // patch_size) * (image_width // patch_size)

    # NOTE(review): the LayerScale init is chosen from the layer index (i + 1), not the total
    # `depth`; CaiT / ResMLP describe the init in terms of network depth. Confirm the intent.
    wrapper = lambda i, fn: PreAffinePostLayerScale(dim, i + 1, fn)

    return nn.Sequential(
        # (b, c, H, W) -> (b, num_patches, patch_size * patch_size * c)
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            # cross-patch mixing: patches are the Conv1d channels
            wrapper(i, nn.Conv1d(num_patches, num_patches, 1)),
            # per-patch channel MLP
            wrapper(i, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Linear(dim * expansion_factor, dim)
            ))
        ) for i in range(depth)],
        Affine(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
.\lucidrains\res-mlp-pytorch\res_mlp_pytorch\__init__.py
# 从 res_mlp_pytorch.res_mlp_pytorch 模块中导入 ResMLP 类
from res_mlp_pytorch.res_mlp_pytorch import ResMLP
.\lucidrains\res-mlp-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Distribution metadata for res-mlp-pytorch; list-valued fields are named up front.
_keywords = [
    'artificial intelligence',
    'deep learning',
    'image recognition'
]

_install_requires = [
    'einops>=0.3',
    'torch>=1.6'
]

_classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'res-mlp-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.6',
    license = 'MIT',
    description = 'ResMLP - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/res-mlp-pytorch',
    keywords = _keywords,
    install_requires = _install_requires,
    classifiers = _classifiers,
)
Retrieval-Augmented Denoising Diffusion Probabilistic Models (wip)
Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in Pytorch
This will make use of the Clip Retrieval library by @rom1504.
Citations
@article{Blattmann2022RetrievalAugmentedDM,
title = {Retrieval-Augmented Diffusion Models},
author = {A. Blattmann and Robin Rombach and K Oktay and Bj{\"o}rn Ommer},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.11824}
}
.\lucidrains\retrieval-augmented-ddpm\retrieval_augmented_ddpm\retrieval_augmented_ddpm.py
# Rectangle area helper
def calculate_area(length, width):
    """Return the area of a rectangle, i.e. `length * width`."""
    return length * width
.\lucidrains\retrieval-augmented-ddpm\retrieval_augmented_ddpm\__init__.py
# Rectangle area helper
def calculate_area(length, width):
    """Return the area of a rectangle, i.e. `length * width`."""
    return length * width
.\lucidrains\retrieval-augmented-ddpm\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# Distribution metadata for retrieval-augmented-ddpm; list-valued fields are named up front.
_keywords = [
    'artificial intelligence',
    'deep learning',
    'denoising diffusion',
    'retrieval'
]

_install_requires = [
    'clip-retrieval',
    'einops>=0.4',
    'torch>=1.6',
]

_classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'retrieval-augmented-ddpm',
    packages = find_packages(exclude=[]),
    version = '0.0.1',
    license = 'MIT',
    description = 'Retrieval-Augmented Denoising Diffusion Probabilistic Models',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/retrieval-augmented-ddpm',
    keywords = _keywords,
    install_requires = _install_requires,
    classifiers = _classifiers,
)

RETRO - Pytorch
Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as Faiss library instead of Scann.
This library leverages autofaiss for building the index and calculating the k-nearest neighbors for all chunks.
Jay Alammar's explanatory blog post
The selling point of this retriever approach is reaching GPT-3 performance with 10x fewer parameters. More research in this area is definitely warranted.
I have also included the features necessary to scale the retrieval transformer to 1000 layers, if the claims of the DeepNet paper are to be believed.
Update: Someone on Reddit has gifted me a Gold Award. Not sure what it is, but thank you! 🙏
Update: DeepNorm has been validated at scale in a 130B-parameter model from Tsinghua. It is now recommended that you train with `use_deepnet` set to `True`.
Install
$ pip install retro-pytorch
Usage
import torch
from retro_pytorch import RETRO
retro = RETRO(
chunk_size = 64, # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
max_seq_len = 2048, # max sequence length
enc_dim = 896, # encoder model dim
enc_depth = 2, # encoder depth
dec_dim = 796, # decoder model dim
dec_depth = 12, # decoder depth
dec_cross_attn_layers = (3, 6, 9, 12), # decoder cross attention layers (with causal chunk cross attention)
heads = 8, # attention heads
dim_head = 64, # dimension per head
dec_attn_dropout = 0.25, # decoder attention dropout
dec_ff_dropout = 0.25, # decoder feedforward dropout
use_deepnet = True # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers
)
seq = torch.randint(0, 20000, (2, 2048 + 1)) # plus one since it is split into input and labels for training
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
loss = retro(seq, retrieved, return_loss = True)
loss.backward()
# do above for many steps
RETRO Training Wrapper
The aim of the TrainingWrapper is to process a folder of text documents into the necessary memmapped numpy arrays to begin training RETRO.
import torch
from retro_pytorch import RETRO, TrainingWrapper
# instantiate RETRO, fit it into the TrainingWrapper with correct settings
retro = RETRO(
max_seq_len = 2048, # max sequence length
enc_dim = 896, # encoder model dimension
enc_depth = 3, # encoder depth
dec_dim = 768, # decoder model dimensions
dec_depth = 12, # decoder depth
dec_cross_attn_layers = (1, 3, 6, 9), # decoder cross attention layers (with causal chunk cross attention)
heads = 8, # attention heads
dim_head = 64, # dimension per head
dec_attn_dropout = 0.25, # decoder attention dropout
dec_ff_dropout = 0.25 # decoder feedforward dropout
).cuda()
wrapper = TrainingWrapper(
retro = retro, # path to retro instance
knn = 2, # knn (2 in paper was sufficient)
chunk_size = 64, # chunk size (64 in paper)
documents_path = './text_folder', # path to folder of text
glob = '**/*.txt', # text glob
chunks_memmap_path = './train.chunks.dat', # path to chunks
seqs_memmap_path = './train.seq.dat', # path to sequence data
doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids per chunk (used for filtering neighbors belonging to same document)
max_chunks = 1_000_000, # maximum cap to chunks
max_seqs = 100_000, # maximum seqs
knn_extra_neighbors = 100, # num extra neighbors to fetch
max_index_memory_usage = '100m',
current_memory_available = '1G'
)
# get the dataloader and optimizer (AdamW with all the correct settings)
train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)
# now do your training
# ex. one gradient step
seq, retrieved = map(lambda t: t.cuda(), next(train_dl))
# seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens
loss = retro(
seq,
retrieved,
return_loss = True
)
# one gradient step
loss.backward()
optim.step()
optim.zero_grad()
# do above for many steps, then ...
# topk sampling with retrieval at chunk boundaries
sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>
# or you can generate with a prompt, knn retrieval for initial chunks all taken care of
prompt = torch.randint(0, 1000, (1, 128)) # start with two chunks worth of sequence
sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>
If you wish to force a reprocess of the training data, simply run your script with a REPROCESS=1 environment flag, like so
$ REPROCESS=1 python train.py
RETRO Datasets
The RETRODataset class accepts paths to a number of memmapped numpy arrays containing the chunks, the index of the first chunk in the sequence to be trained on (in RETRO decoder), and the pre-calculated indices of the k-nearest neighbors per chunk.
You can use this to easily assemble the data for RETRO training, if you do not wish to use the TrainingWrapper from above.
Furthermore, all the functions needed to create the necessary memmapped data are in the sections that follow.
import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset
# mock data constants
import numpy as np
NUM_CHUNKS = 1000
CHUNK_SIZE = 64
NUM_SEQS = 100
NUM_NEIGHBORS = 2
def save_memmap(path, tensor):
f = np.memmap(path, dtype = tensor.dtype, mode = 'w+', shape = tensor.shape)
f[:] = tensor
del f
# generate mock chunk data
save_memmap(
'./train.chunks.dat',
np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1)))
)
# generate nearest neighbors for each chunk
save_memmap(
'./train.chunks.knn.dat',
np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS)))
)
# generate seq data
save_memmap(
'./train.seq.dat',
np.int32(np.random.randint(0, 128, size = (NUM_SEQS,)))
)
# instantiate dataset class
# which constructs the sequence and neighbors from memmapped chunk and neighbor information
train_ds = RETRODataset(
num_sequences = NUM_SEQS,
num_chunks = NUM_CHUNKS,
num_neighbors = NUM_NEIGHBORS,
chunk_size = CHUNK_SIZE,
seq_len = 2048,
chunk_memmap_path = './train.chunks.dat',
chunk_nn_memmap_path = './train.chunks.knn.dat',
seq_memmap_path = './train.seq.dat'
)
train_dl = iter(DataLoader(train_ds, batch_size = 2))
# one forwards and backwards
retro = RETRO(
max_seq_len = 2048, # max sequence length
enc_dim = 896, # encoder model dimension
enc_depth = 3, # encoder depth
dec_dim = 768, # decoder model dimensions
dec_depth = 12, # decoder depth
dec_cross_attn_layers = (1, 3, 6, 9), # decoder cross attention layers (with causal chunk cross attention)
heads = 8, # attention heads
dim_head = 64, # dimension per head
dec_attn_dropout = 0.25, # decoder attention dropout
dec_ff_dropout = 0.25 # decoder feedforward dropout
).cuda()
seq, retrieved = map(lambda t: t.cuda(), next(train_dl))
# seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens
loss = retro(
seq,
retrieved,
return_loss = True
)
loss.backward()
Retrieval related tools
This repository uses the default tokenizer (WordPiece) for the cased version of BERT. Embeddings are fetched from vanilla BERT and can be either the masked mean-pooled representation or the CLS token representation.
ex. masked mean pooled representation
from retro_pytorch.retrieval import bert_embed, tokenize
ids = tokenize([
'hello world',
'foo bar'
])
embeds = bert_embed(ids) # (2, 768) - 768 is hidden dimension of BERT
ex. CLS token representation
from retro_pytorch.retrieval import bert_embed, tokenize
ids = tokenize([
'hello world',
'foo bar'
])
embeds = bert_embed(ids, return_cls_repr = True) # (2, 768)
Create your chunks and chunk start indices (for calculating sequence ranges for autoregressive training) using text_folder_to_chunks_
from retro_pytorch.retrieval import text_folder_to_chunks_
stats = text_folder_to_chunks_(
folder = './text_folder',
glob = '**/*.txt',
chunks_memmap_path = './train.chunks.dat',
seqs_memmap_path = './train.seq.dat',
doc_ids_memmap_path = './train.doc_ids.dat', # document ids are needed for filtering out neighbors belonging to same document appropriately during computation of nearest neighbors
chunk_size = 64,
seq_len = 2048,
max_chunks = 1_000_000,
max_seqs = 100_000
)
# {'chunks': <number of chunks>, 'docs': <number of documents>, 'seqs': <number of sequences>}
Fetching Nearest Neighbors
You can turn your memmapped chunks numpy array into embeddings and a faiss index with one command
from retro_pytorch.retrieval import chunks_to_index_and_embed
index, embeddings = chunks_to_index_and_embed(
num_chunks = 1000,
chunk_size = 64,
chunk_memmap_path = './train.chunks.dat'
)
query_vector = embeddings[:1] # use first embedding as query
_, indices = index.search(query_vector, k = 2) # fetch 2 neighbors, first indices should be self
neighbor_embeddings = embeddings[indices] # (1, 2, 768)
You can also directly calculate the nearest neighbor file necessary for training with the chunks_to_precalculated_knn_ command
from retro_pytorch.retrieval import chunks_to_precalculated_knn_
chunks_to_precalculated_knn_(
num_chunks = 1000,
chunk_size = 64,
chunk_memmap_path = './train.chunks.dat', # path to main chunks dataset
doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids created by text_folder_to_chunks_, used for filtering out neighbors that belong to the same document
num_nearest_neighbors = 2, # number of nearest neighbors you'd like to use
num_extra_neighbors = 10 # fetch 10 extra neighbors, in the case that fetched neighbors are frequently from same document (filtered out)
)
# nearest neighbor info saved to ./train.chunks.knn.dat
Citations
@misc{borgeaud2022improving,
title = {Improving language models by retrieving from trillions of tokens},
author = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. Rae and Erich Elsen and Laurent Sifre},
year = {2022},
eprint = {2112.04426},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@article{Wang2022DeepNetST,
title = {DeepNet: Scaling Transformers to 1, 000 Layers},
author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
journal = {ArXiv},
year = {2022},
volume = {abs/2203.00555}
}
@misc{zhang2021sparse,
title = {Sparse Attention with Linear Units},
author = {Biao Zhang and Ivan Titov and Rico Sennrich},
year = {2021},
eprint = {2104.07012},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
I consider always the adult life to be the continuous retrieval of childhood. - Umberto Eco
.\lucidrains\RETRO-pytorch\retro_pytorch\data.py
# 导入所需的库
from functools import partial
import numpy as np
import torch
from torch.utils.data import Dataset
# 导入自定义的模块
from retro_pytorch.retrieval import EOS_ID
from retro_pytorch.utils import memmap
# 定义函数 knn_to_retrieved_chunks,用于将 KNN 结果转换为检索到的块
def knn_to_retrieved_chunks(
    knns,
    chunks_memmap,
    *,
    add_continuations,
    num_chunks,
    pad_id = 0,
    eos_id = EOS_ID,
):
    """Gather the tokens of every retrieved neighbor chunk, optionally followed by its continuation.

    Args:
        knns: integer array of neighbor chunk indices; -1 means "no neighbor found".
        chunks_memmap: array of chunk rows indexed along axis 0. The final column of each row is
            dropped here; in this module's chunk layout it repeats the next chunk's first token.
        add_continuations: if True, append the chunk stored right after each neighbor.
        num_chunks: number of chunks; continuation indices are clipped to ``num_chunks - 1``.
        pad_id: value written for every position of a missing (-1) neighbor.
        eos_id: token id whose presence marks a document's last chunk.

    Returns:
        Array of shape ``knns.shape + (L,)`` where L is the row width minus one, doubled when
        continuations are appended.
    """
    missing = knns == -1
    safe_knns = np.maximum(knns, 0)

    neighbor_rows = chunks_memmap[safe_knns]
    # a neighbor containing [EOS] ends its document, so the next stored chunk belongs to another one
    ends_document = np.any(neighbor_rows == eos_id, axis = -1, keepdims = True)

    retrieved = neighbor_rows[..., :-1]

    if add_continuations:
        # chunks are stored contiguously, so a neighbor's continuation is simply the next row
        next_indices = np.clip(safe_knns + 1, 0, num_chunks - 1)
        next_rows = chunks_memmap[next_indices][..., :-1]
        next_rows *= ~ends_document
        retrieved = np.concatenate((retrieved, next_rows), axis = -1)

    # positions belonging to a missing neighbor are replaced by the pad id
    keep = ~missing[..., None]
    return np.where(keep, retrieved, pad_id)
# 定义类 RETRODataset,继承自 Dataset 类
class RETRODataset(Dataset):
    """Map-style dataset yielding (sequence tokens, retrieved neighbor tokens) for RETRO training.

    Reads three memmapped int32 arrays:
      * chunks, shape (num_chunks + seq_len // chunk_size, chunk_size + 1)
      * precomputed neighbor indices, shape (num_chunks + seq_len // chunk_size, num_neighbors)
      * the starting chunk index of every training sequence, shape (num_sequences,)
    """

    def __init__(
        self,
        *,
        num_chunks,
        chunk_size,
        seq_len,
        num_sequences,
        num_neighbors,
        chunk_memmap_path,
        chunk_nn_memmap_path,
        seq_memmap_path,
        eos_id = EOS_ID,
        pad_id = 0.,
        add_continuations = True
    ):
        super().__init__()
        self.num_chunks = num_chunks
        self.num_sequences = num_sequences
        self.seq_num_chunks = seq_len // chunk_size
        self.eos_id = eos_id
        # used both for masking tokens after <eos> and for missing neighbors
        self.pad_id = pad_id

        # NOTE(review): the memmaps have seq_num_chunks extra rows, presumably as headroom for a
        # sequence that starts near the end of the chunk array — confirm against the writer
        num_chunks_with_padding = num_chunks + self.seq_num_chunks

        chunks_shape = (num_chunks_with_padding, chunk_size + 1)
        knn_shape = (num_chunks_with_padding, num_neighbors)

        self.add_continuations = add_continuations
        self.get_chunks = partial(memmap, chunk_memmap_path, dtype = np.int32, shape = chunks_shape)
        self.get_knns = partial(memmap, chunk_nn_memmap_path, dtype = np.int32, shape = knn_shape)
        self.get_seqs = partial(memmap, seq_memmap_path, dtype = np.int32, shape = (num_sequences,))

    def __len__(self):
        """Number of training sequences."""
        return self.num_sequences

    def __getitem__(self, ind):
        """Return ``(seq_tokens, retrieved)`` for sequence ``ind`` as int64 tensors.

        ``seq_tokens`` has ``seq_num_chunks * chunk_size + 1`` tokens (one extra for the
        input/label shift). ``retrieved`` holds the neighbor (and continuation) tokens of every
        chunk, as built by ``knn_to_retrieved_chunks``.
        """
        with self.get_chunks() as chunks_memmap, self.get_knns() as knns_memmap, self.get_seqs() as seqs_memmap:

            begin_chunk_index = seqs_memmap[ind]
            chunk_range = slice(begin_chunk_index, (begin_chunk_index + self.seq_num_chunks))

            chunks = chunks_memmap[chunk_range]

            # drop each chunk's trailing token (it repeats the next chunk's first token), except for
            # the final chunk, whose trailing token supplies the last label
            seq_tokens = np.concatenate((chunks[:, :-1].flatten(), chunks[-1, -1:]))

            # pad out every token after the first <eos>: multiple documents in one sequence would
            # break RETRO's chunked cross attention
            seq_mask = np.cumsum(seq_tokens == self.eos_id, axis = 0)
            seq_mask = np.pad(seq_mask, (1, 0))[:-1] == 0.
            seq_tokens = np.where(seq_mask, seq_tokens, self.pad_id)

            # derive the retrieved tokens for every chunk of the sequence
            knns = knns_memmap[chunk_range]

            retrieved = knn_to_retrieved_chunks(
                knns,
                chunks_memmap,
                add_continuations = self.add_continuations,
                eos_id = self.eos_id,
                num_chunks = self.num_chunks,
                pad_id = self.pad_id
            )

        seq_tokens_torch = torch.from_numpy(seq_tokens).long()
        retrieved_torch = torch.from_numpy(retrieved).long()
        return seq_tokens_torch, retrieved_torch
.\lucidrains\RETRO-pytorch\retro_pytorch\optimizer.py
# 从 torch.optim 模块中导入 AdamW 优化器
from torch.optim import AdamW
# 将参数分为可进行权重衰减和不可进行权重衰减的参数
def separate_weight_decayable_params(params):
    """Split parameters into those that should and should not receive weight decay.

    Parameters with fewer than two dimensions (biases, normalization gains) are exempt.

    Args:
        params: any iterable of tensors, including a one-shot iterator such as
            ``model.parameters()``; it is consumed exactly once.

    Returns:
        ``(wd_params, no_wd_params)`` as two disjoint sets.
    """
    # materialize first: iterating a generator twice would silently yield no decayable params
    params = set(params)
    no_wd_params = {param for param in params if param.ndim < 2}
    wd_params = params - no_wd_params
    return wd_params, no_wd_params
# 根据参数和超参数创建 AdamW 优化器
def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False):
    """Build an AdamW optimizer that applies weight decay only to parameters with ``ndim >= 2``.

    Args:
        params: iterable of parameters (e.g. ``model.parameters()``); duplicates are ignored.
        lr: learning rate.
        wd: weight decay of the decayable group; the other group uses 0.
        filter_by_requires_grad: if True, drop parameters whose ``requires_grad`` is False.

    Returns:
        ``torch.optim.AdamW`` with two param groups: ``[decayable, non-decayable]``.
    """
    if filter_by_requires_grad:
        params = filter(lambda t: t.requires_grad, params)

    # de-duplicate while keeping the caller's order. A set would order tensors by id(), so the
    # group order (and the optimizer state_dict layout) would differ between runs and a saved
    # optimizer state could be restored onto the wrong parameters.
    params = list(dict.fromkeys(params))

    # same criterion as separate_weight_decayable_params, but order-preserving
    wd_params = [param for param in params if param.ndim >= 2]
    no_wd_params = [param for param in params if param.ndim < 2]

    param_groups = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]

    return AdamW(param_groups, lr = lr, weight_decay = wd)
.\lucidrains\RETRO-pytorch\retro_pytorch\retrieval.py
# 导入所需的模块
from pathlib import Path
from math import ceil
import torch
import torch.nn.functional as F
import logging
import numpy as np
from einops import rearrange
import faiss
from autofaiss import build_index
from retro_pytorch.utils import memmap, reset_folder_
# 常量定义
SOS_ID = 101
EOS_ID = 102
BERT_MODEL_DIM = 768
BERT_VOCAB_SIZE = 28996
TMP_PATH = Path('./.tmp')
INDEX_FOLDER_PATH = TMP_PATH / '.index'
EMBEDDING_TMP_SUBFOLDER = 'embeddings'
# 辅助函数
def exists(val):
    """Return True when ``val`` is anything other than None."""
    return not (val is None)
def range_chunked(max_value, *, batch_size):
    """Yield consecutive ``slice`` objects covering ``[0, max_value)`` in steps of ``batch_size``.

    The last slice is truncated at ``max_value``; nothing is yielded when ``max_value <= 0``.

    Raises:
        ValueError: if ``batch_size`` is not positive (such a step never advances, so the loop
            would never terminate).
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    start = 0
    while start < max_value:
        stop = min(start + batch_size, max_value)
        yield slice(start, stop)
        start = stop
# 索引辅助函数
def faiss_read_index(path):
    """Load the faiss index stored at ``path`` using the MMAP and READ_ONLY IO flags."""
    read_flags = faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
    return faiss.read_index(str(path), read_flags)
# 单例全局变量
MODEL = None
TOKENIZER = None
def get_tokenizer():
    """Return the module-level ``bert-base-cased`` tokenizer, loading it via torch.hub on first call."""
    global TOKENIZER
    if TOKENIZER is None:
        TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER
def get_bert():
    """Return the module-level ``bert-base-cased`` model, loading it via torch.hub on first call.

    On first load the model is moved to the GPU when CUDA is available.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')
    if torch.cuda.is_available():
        MODEL = MODEL.cuda()
    return MODEL
# 分词
def tokenize(texts, add_special_tokens = True):
    """Tokenize a string, or a list/tuple of strings, with the BERT tokenizer.

    Returns:
        torch tensor of token ids with one row per text, padded to the longest text.
    """
    batch = texts if isinstance(texts, (list, tuple)) else [texts]

    encoded = get_tokenizer().batch_encode_plus(
        batch,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )
    return encoded.input_ids
# 文本转换为块和序列索引
def doc_text_to_chunks_and_seq_indices(
    *,
    doc_text,
    chunk_size = 64,
    seq_len = 2048,
    pad_id = 0
):
    """Tokenize one document and split it into fixed-size chunks plus sequence start indices.

    Args:
        doc_text: the document text.
        chunk_size: tokens per chunk.
        seq_len: tokens per training sequence; must be a multiple of ``chunk_size``.
        pad_id: token id used to pad the document to a whole number of chunks.

    Returns:
        chunks_with_extra_token: (num_chunks, chunk_size + 1) tensor. The extra last column of
            chunk i is the first token of chunk i + 1 (a pad token for the final chunk).
        seq: 1-D tensor of the chunk indices where training sequences start, every
            ``seq_len // chunk_size`` chunks beginning at 0.
    """
    assert (seq_len % chunk_size) == 0, 'sequence length must be divisible by chunk size'

    ids = tokenize(doc_text)
    ids = rearrange(ids, '1 ... -> ...')

    text_len = ids.shape[-1]

    # pad so that (length - 1) is a multiple of chunk_size; padding is always >= 1, so the last
    # real token always sits inside a chunk body rather than only in the extra column
    padding = chunk_size - ((text_len - 1) % chunk_size)
    ids = F.pad(ids, (0, padding), value = pad_id)

    # split off the final token; it becomes the extra column of the last chunk
    ids, last_token = ids[:-1], ids[-1:]
    ids = rearrange(ids, '(n c) -> n c', c = chunk_size)

    # the first token of chunks[1:] is the extra last token of chunks[:-1]
    last_token_per_chunk = ids[1:, 0]
    all_last_tokens = torch.cat((last_token_per_chunk, last_token), dim = 0)
    all_last_tokens = rearrange(all_last_tokens, 'n -> n 1')

    # append the extra token to every chunk -> (num_chunks, chunk_size + 1)
    chunks_with_extra_token = torch.cat((ids, all_last_tokens), dim = -1)

    # sequence start indices: one every num_chunks_per_seq chunks, starting at chunk 0
    total_chunks = ids.shape[0]
    num_chunks_per_seq = seq_len // chunk_size
    seq = torch.arange(0, total_chunks, num_chunks_per_seq)

    return chunks_with_extra_token, seq
def text_folder_to_chunks_(
    *,
    folder,
    chunks_memmap_path,
    seqs_memmap_path,
    doc_ids_memmap_path,
    chunk_size = 64,
    seq_len = 2048,
    glob = '**/*.txt',
    max_chunks = 1_000_000,
    max_seqs = 100_000
):
    """Tokenize every text file under ``folder`` and write chunks, sequence starts and doc ids to memmaps.

    Files matching ``glob`` are processed in sorted path order. Three int32 memmaps are opened in
    ``'w+'`` mode: chunks (max_chunks, chunk_size + 1), sequence start chunk indices (max_seqs,),
    and the document index of every chunk (max_chunks,).

    Returns:
        dict with the number of ``chunks``, ``docs`` and ``seqs`` written.

    Raises:
        ValueError: if the documents need more than ``max_chunks`` chunks or ``max_seqs`` sequences.
    """
    paths = sorted([*Path(folder).glob(glob)])

    total_chunks = 0
    total_docs = 0
    total_seqs = 0

    chunks_shape = (max_chunks, chunk_size + 1)
    seqs_shape = (max_seqs,)
    doc_ids_shape = (max_chunks,)

    with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32, mode = 'w+') as chunks_memmap\
        , memmap(seqs_memmap_path, shape = seqs_shape, dtype = np.int32, mode = 'w+') as seqs_memmap\
        , memmap(doc_ids_memmap_path, shape = doc_ids_shape, dtype = np.int32, mode = 'w+') as doc_ids_memmap:

        for path in paths:
            print(f'processing {path}')

            chunks, seq = doc_text_to_chunks_and_seq_indices(
                doc_text = path.read_text(),
                chunk_size = chunk_size,
                seq_len = seq_len
            )

            doc_chunk_len = chunks.shape[0]
            doc_seq_len = seq.shape[0]

            # fail clearly before writing, instead of numpy's broadcast error once a memmap is full
            if total_chunks + doc_chunk_len > max_chunks:
                raise ValueError(f'adding {path} would exceed max_chunks ({max_chunks}); increase max_chunks')
            if total_seqs + doc_seq_len > max_seqs:
                raise ValueError(f'adding {path} would exceed max_seqs ({max_seqs}); increase max_seqs')

            chunks_memmap[total_chunks:(total_chunks + doc_chunk_len)] = chunks.numpy()
            # sequence starts are document-relative, so offset them by the chunks already written
            seqs_memmap[total_seqs:(total_seqs + doc_seq_len)] = seq.numpy() + total_chunks
            # tag every chunk of this document with the current document index
            doc_ids_memmap[total_chunks:(total_chunks + doc_chunk_len)] = np.full((doc_chunk_len,), total_docs)

            total_chunks += doc_chunk_len
            total_seqs += doc_seq_len
            total_docs += 1

    return dict(
        chunks = total_chunks,
        docs = total_docs,
        seqs = total_seqs
    )
# 嵌入函数
@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr = False,
    eps = 1e-8,
    pad_id = 0.
):
    """Embed token ids with BERT's last hidden layer (no gradients).

    Args:
        token_ids: (batch, seq) tensor of token ids; position 0 is presumably [CLS]/SOS for all
            callers — TODO confirm.
        return_cls_repr: if True, return the hidden state at position 0 instead of a mean.
        eps: added to the token count to avoid division by zero.
        pad_id: id treated as padding; excluded from attention and from the mean.

    Returns:
        (batch, hidden_dim) tensor: the position-0 hidden state, or the mean of the last hidden
        states over the non-pad positions 1..seq-1.
    """
    model = get_bert()
    mask = token_ids != pad_id

    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    outputs = model(
        input_ids = token_ids,
        attention_mask = mask,
        output_hidden_states = True
    )

    hidden_state = outputs.hidden_states[-1]

    # [cls] as the representation
    if return_cls_repr:
        return hidden_state[:, 0]

    # the mask is always built above, so the mean is always masked; position 0 ([cls]) is excluded
    mask = mask[:, 1:]
    mask = rearrange(mask, 'b n -> b n 1')

    numer = (hidden_state[:, 1:] * mask).sum(dim = 1)
    denom = mask.sum(dim = 1)
    masked_mean = numer / (denom + eps)
    return masked_mean
# chunks to embeddings

def chunks_to_embeddings_(
    *,
    num_chunks,
    chunks_memmap_path,
    embeddings_memmap_path,
    chunk_size = 64,
    embed_dim = BERT_MODEL_DIM,
    batch_size = 16,
    use_cls_repr = False,
    pad_id = 0.
):
    """Embed every chunk with BERT and write the vectors to a float32 memmap.

    Each stored chunk row has chunk_size + 1 tokens. SOS_ID is prepended and the final extra token
    dropped, so BERT sees ``[SOS_ID] + the chunk's own chunk_size tokens``.

    Args:
        num_chunks: number of chunk rows to embed.
        chunks_memmap_path: int32 memmap of shape (num_chunks, chunk_size + 1).
        embeddings_memmap_path: output float32 memmap of shape (num_chunks, embed_dim), opened 'w+'.
        chunk_size: tokens per chunk.
        embed_dim: embedding width written to the output memmap.
        batch_size: chunks embedded per BERT forward pass.
        use_cls_repr: use the [CLS] representation instead of the masked mean.
        pad_id: token id treated as padding by ``bert_embed``.
    """
    chunks_shape = (num_chunks, chunk_size + 1)
    embed_shape = (num_chunks, embed_dim)

    with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32) as chunks\
        , memmap(embeddings_memmap_path, shape = embed_shape, dtype = np.float32, mode = 'w+') as embeddings:

        for dim_slice in range_chunked(num_chunks, batch_size = batch_size):
            # explicit int64 so the concatenation with the [cls] column needs no dtype promotion
            batch_chunk = torch.from_numpy(chunks[dim_slice]).long()

            cls_tokens = torch.full((batch_chunk.shape[0], 1), SOS_ID, dtype = torch.long)
            batch_chunk = torch.cat((cls_tokens, batch_chunk), dim = 1)

            # omit the last token: it is the next chunk's first token, kept only for autoregressive training
            batch_chunk = batch_chunk[:, :-1]

            batch_embed = bert_embed(
                batch_chunk,
                return_cls_repr = use_cls_repr,
                pad_id = pad_id
            )

            embeddings[dim_slice] = batch_embed.detach().cpu().numpy()
            print(f'embedded {dim_slice.stop} / {num_chunks}')
def memmap_file_to_chunks_(
    memmap_path,
    *,
    folder,
    shape,
    dtype,
    max_rows_per_file = 500
):
    """Copy a 2-D memmap into consecutive ``.npy`` files under ``TMP_PATH / folder``.

    ``reset_folder_`` is called on the target folder first (presumably clearing old files).
    Files are named ``00000.npy``, ``00001.npy``, ... and hold at most ``max_rows_per_file`` rows.
    """
    rows, _ = shape

    with memmap(memmap_path, shape = shape, dtype = dtype, mode = 'r') as f:
        root_path = TMP_PATH / folder
        reset_folder_(root_path)

        for ind, dim_slice in enumerate(range_chunked(rows, batch_size = max_rows_per_file)):
            filename = root_path / f'{ind:05d}.npy'
            np.save(str(filename), f[dim_slice])
            print(f'saved {str(filename)}')
def index_embeddings(
    embeddings_folder,
    *,
    index_file = 'knn.index',
    index_infos_file = 'index_infos.json',
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
):
    """Build an L2 faiss index with autofaiss over the embedding files in ``TMP_PATH / embeddings_folder``.

    ``INDEX_FOLDER_PATH`` is reset, the index is written there as ``index_file`` (with its infos
    in ``index_infos_file``), and then re-read via ``faiss_read_index``. A GPU is used when CUDA
    is available.

    Returns:
        The loaded faiss index.
    """
    source_path = TMP_PATH / embeddings_folder
    index_path = INDEX_FOLDER_PATH / index_file
    infos_path = INDEX_FOLDER_PATH / index_infos_file

    reset_folder_(INDEX_FOLDER_PATH)

    build_index(
        embeddings = str(source_path),
        index_path = str(index_path),
        index_infos_path = str(infos_path),
        metric_type = "l2",
        max_index_memory_usage = max_index_memory_usage,
        current_memory_available = current_memory_available,
        make_direct_map = True,
        should_be_memory_mappable = False,
        use_gpu = torch.cuda.is_available(),
    )

    return faiss_read_index(index_path)
def chunks_to_index_and_embed(
    *,
    num_chunks,
    chunk_size,
    chunk_memmap_path,
    use_cls_repr = False,
    max_rows_per_file = 500,
    chunks_to_embeddings_batch_size = 16,
    embed_dim = BERT_MODEL_DIM,
    index_file = 'knn.index',
    **index_kwargs
):
    """Embed every chunk with BERT and build a faiss index over the embeddings.

    Steps:
      1. write the embeddings to ``f'{chunk_memmap_path}.embedded'``;
      2. copy that memmap into ``.npy`` files under ``TMP_PATH / EMBEDDING_TMP_SUBFOLDER``;
      3. build and load the index from those files (``index_kwargs`` go to ``index_embeddings``).

    Returns:
        ``(index, embeddings)``: the faiss index and a read-only float32 memmap of shape
        (num_chunks, embed_dim).
    """
    embedding_path = f'{chunk_memmap_path}.embedded'
    embed_shape = (num_chunks, embed_dim)

    # step 1: chunks -> embeddings memmap
    chunks_to_embeddings_(
        num_chunks = num_chunks,
        chunk_size = chunk_size,
        chunks_memmap_path = chunk_memmap_path,
        embeddings_memmap_path = embedding_path,
        use_cls_repr = use_cls_repr,
        batch_size = chunks_to_embeddings_batch_size,
        embed_dim = embed_dim
    )

    # step 2: embeddings memmap -> per-file shards consumed by the index builder
    memmap_file_to_chunks_(
        embedding_path,
        folder = EMBEDDING_TMP_SUBFOLDER,
        shape = embed_shape,
        dtype = np.float32,
        max_rows_per_file = max_rows_per_file
    )

    # step 3: shards -> faiss index
    index = index_embeddings(
        embeddings_folder = EMBEDDING_TMP_SUBFOLDER,
        index_file = index_file,
        **index_kwargs
    )

    return index, np.memmap(embedding_path, shape = embed_shape, dtype = np.float32, mode = 'r')
# precalculate the k-nearest neighbors of every chunk and save them to a memmap

def chunks_to_precalculated_knn_(
    *,
    num_nearest_neighbors,
    num_chunks,
    chunk_size,
    chunk_memmap_path,
    doc_ids_memmap_path,
    use_cls_repr = False,
    max_rows_per_file = 500,
    chunks_to_embeddings_batch_size = 16,
    embed_dim = BERT_MODEL_DIM,
    num_extra_neighbors = 10,
    force_reprocess = False,
    index_file = 'knn.index',
    **index_kwargs
):
    """Compute and save the ``num_nearest_neighbors`` nearest chunks of every chunk.

    Neighbors come from a faiss index over BERT chunk embeddings (``chunks_to_index_and_embed``).
    The first search hit (assumed to be the query itself) and every neighbor from the query's own
    document (per ``doc_ids_memmap_path``) are discarded. ``num_extra_neighbors`` extra candidates
    are fetched so enough neighbors survive the filtering. Slots without a valid neighbor hold -1.

    The result is an int32 memmap of shape (num_chunks, num_nearest_neighbors), written next to the
    chunks as ``<stem>.knn<suffix>``. If it and the index both exist and ``force_reprocess`` is
    False, both are reused as-is.

    Returns:
        ``(knn_path, index)``.
    """
    chunk_path = Path(chunk_memmap_path)
    knn_path = chunk_path.parents[0] / f'{chunk_path.stem}.knn{chunk_path.suffix}'
    index_path = INDEX_FOLDER_PATH / index_file

    if index_path.exists() and knn_path.exists() and not force_reprocess:
        print(f'preprocessed knn found at {str(knn_path)}, faiss index reconstituted from {str(index_path)}')
        index = faiss_read_index(index_path)
        return knn_path, index

    # forward the embedding options; previously they were accepted here but silently ignored
    index, embeddings = chunks_to_index_and_embed(
        num_chunks = num_chunks,
        chunk_size = chunk_size,
        chunk_memmap_path = chunk_memmap_path,
        use_cls_repr = use_cls_repr,
        max_rows_per_file = max_rows_per_file,
        chunks_to_embeddings_batch_size = chunks_to_embeddings_batch_size,
        embed_dim = embed_dim,
        index_file = index_file,
        **index_kwargs
    )

    total_neighbors_to_fetch = num_extra_neighbors + num_nearest_neighbors + 1

    with memmap(knn_path, shape = (num_chunks, num_nearest_neighbors), dtype = np.int32, mode = 'w+') as knns\
        , memmap(doc_ids_memmap_path, shape = (num_chunks,), dtype = np.int32, mode = 'r') as doc_ids:

        for dim_slice in range_chunked(num_chunks, batch_size = max_rows_per_file):
            query_vector = embeddings[dim_slice]

            distances, indices = index.search(query_vector, k = total_neighbors_to_fetch)

            # remove self as neighbor (assumed to be the closest hit)
            distances = distances[:, 1:]
            indices = indices[:, 1:]

            # faiss labels missing results -1; look those up at a dummy row and mark them invalid,
            # together with neighbors that come from the query's own document
            missing = indices == -1
            query_doc_ids = doc_ids[dim_slice]
            neighbor_doc_ids = doc_ids[np.maximum(indices, 0)]
            neighbor_from_same_doc = query_doc_ids[..., None] == neighbor_doc_ids
            invalid = missing | neighbor_from_same_doc

            indices = np.where(invalid, -1, indices)
            # +inf (not a finite sentinel) so invalid entries always sort after every real neighbor
            distances = np.where(invalid, np.inf, distances)

            # re-sort by the updated distances
            order = np.argsort(distances, axis = 1, kind = 'stable')
            indices = np.take_along_axis(indices, order, axis = 1)

            knns[dim_slice] = indices[:, :num_nearest_neighbors]

            print(f'knns calculated for {dim_slice.stop} / {num_chunks}')

    print(f'knn saved to {knn_path}')
    return knn_path, index
.\lucidrains\RETRO-pytorch\retro_pytorch\retro_pytorch.py
# 导入必要的库
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 导入自定义的库
from retro_pytorch.retrieval import BERT_VOCAB_SIZE
from einops import rearrange, repeat
# 常量定义
MIN_DIM_HEAD = 32
# 辅助函数
# 判断变量是否存在
def exists(val):
    """Return True unless ``val`` is None."""
    return False if val is None else True
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
# 判断一个数是否可以被另一个数整除
def divisible_by(val, divisor):
    """Return True if ``val`` is an exact multiple of ``divisor``.

    Integer inputs (anything supporting ``__index__``) use exact integer arithmetic. Other numbers
    fall back to the float-division check used previously.

    Raises:
        ZeroDivisionError: if ``divisor`` is the integer 0.
    """
    import operator

    try:
        val_int, divisor_int = operator.index(val), operator.index(divisor)
    except TypeError:
        # non-integers (e.g. floats): keep the float-division behavior
        return (val / divisor).is_integer()

    # exact for arbitrarily large ints; float division rounds once values exceed 2 ** 53
    return val_int % divisor_int == 0
# 将变量转换为元组
def cast_tuple(val, num = 1):
    """Return ``val`` unchanged if it is a tuple, else a tuple holding ``num`` copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * num
# 初始化深度网络参数
def deepnorm_init(transformer, beta, module_name_match_list = ('.ff.', '.to_v', '.to_out')):
    """Re-initialize every ``nn.Linear`` inside ``transformer`` in place, DeepNet style.

    Weights get Xavier-normal initialization. Layers whose qualified module name contains any
    substring in ``module_name_match_list`` use gain ``beta``; all others use gain 1. Biases are
    zeroed.

    Args:
        transformer: module whose submodules are re-initialized.
        beta: gain for the matched layers.
        module_name_match_list: sequence of name substrings that receive gain ``beta``. The
            default is a tuple, avoiding a shared mutable default.
    """
    for name, module in transformer.named_modules():
        # exact type check: subclasses of nn.Linear are skipped
        if type(module) != nn.Linear:
            continue

        needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
        gain = beta if needs_beta_gain else 1

        nn.init.xavier_normal_(module.weight.data, gain = gain)

        if module.bias is not None:
            nn.init.constant_(module.bias.data, 0)
# 归一化
# RMS归一化类
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain and an optional sigmoid gate.

    ``x.norm(dim=-1) * dim ** -0.5`` is the RMS over the last dimension; it is clamped from below
    by ``eps`` before dividing. With ``gated=True`` the result is also multiplied by
    ``sigmoid(x * weight)``.
    """

    def __init__(
        self,
        dim,
        *,
        eps = 1e-8,
        gated = False
    ):
        super().__init__()
        self.eps = eps
        self.scale = dim ** -0.5
        self.gamma = nn.Parameter(torch.ones(dim))
        # registered only when gated; otherwise a plain None attribute
        self.weight = nn.Parameter(torch.ones(dim)) if gated else None

    def forward(self, x):
        rms = x.norm(keepdim = True, dim = -1) * self.scale
        normed = x / rms.clamp(min = self.eps) * self.gamma

        if self.weight is None:
            return normed

        gate = (x * self.weight).sigmoid()
        return normed * gate
# 前向和后向归一化残差包装模块
# 前向归一化类
class PreNorm(nn.Module):
    """Pre-normalization residual wrapper: returns ``fn(norm(x), ...) + x``."""

    def __init__(self, dim, fn, norm_klass = RMSNorm):
        super().__init__()
        self.fn = fn
        self.norm = norm_klass(dim)

    def forward(self, x, *args, **kwargs):
        normed = self.norm(x)
        branch = self.fn(normed, *args, **kwargs)
        return branch + x
# 后向归一化类
class PostNorm(nn.Module):
    """Post-normalization residual wrapper: returns ``norm(fn(x, ...) + x * scale_residual)``.

    ``scale_residual`` up-weights the skip connection (e.g. DeepNet's residual scaling).
    """

    def __init__(self, dim, fn, scale_residual = 1, norm_klass = RMSNorm):
        super().__init__()
        self.fn = fn
        self.scale_residual = scale_residual
        self.norm = norm_klass(dim)

    def forward(self, x, *args, **kwargs):
        scaled_skip = x * self.scale_residual
        branch = self.fn(x, *args, **kwargs)
        return self.norm(branch + scaled_skip)
# 位置嵌入
# 旋转嵌入类
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoFormer): per-position rotation angles."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, max_seq_len, *, device, offset = 0):
        """Return angles for positions ``offset .. offset + max_seq_len - 1``.

        Shape is (1, 1, max_seq_len, 2 * len(inv_freq)); both halves of the last dimension hold
        the same frequencies.
        """
        positions = (torch.arange(max_seq_len, device = device) + offset).type_as(self.inv_freq)
        angles = einsum('i , j -> i j', positions, self.inv_freq)
        return torch.cat((angles, angles), dim = -1)[None, None]
# 旋转半个位置
def rotate_half(x):
    """Split the last dimension into halves (x1, x2) and return ``cat(-x2, x1)`` along it."""
    halves = rearrange(x, '... (j d) -> j ... d', j = 2)
    first, second = halves[0], halves[1]
    return torch.cat((second.neg(), first), dim = -1)
# 应用旋转位置嵌入
def apply_rotary_pos_emb(t, freqs):
    """Rotate the first ``freqs.shape[-1]`` features of ``t`` by the rotary angles in ``freqs``.

    Features beyond the rotary dimension pass through unchanged (partial rotary embedding).
    """
    rot_dim = freqs.shape[-1]
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim = -1)
# 前馈网络
# 前馈网络类
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear, hidden width ``int(mult * dim)``."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = int(mult * dim)
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        # kept as `self.ff`: deepnorm_init matches parameter names containing '.ff.'
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        return self.ff(x)
# 注意力机制
# 注意力类
class Attention(nn.Module):
    """Multi-head self/cross attention with optional rotary embeddings, causal mask and null key/value.

    Args:
        dim: feature dimension of the queries (and of the output).
        context_dim: feature dimension of keys/values; defaults to ``dim``.
        dim_head: dimension per head.
        heads: number of heads.
        causal: if True, each query may only attend to keys at the same or earlier position
            (positions aligned at the end when there are more keys than queries).
        dropout: dropout applied to the attention weights.
        null_kv: if True, prepend a learned null key/value that is never masked.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        null_kv = False
    ):
        super().__init__()
        context_dim = default(context_dim, dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # learned null key/value, always attendable, so attention stays well-defined even when
        # every context position is masked out
        self.null_k = nn.Parameter(torch.randn(inner_dim)) if null_kv else None
        self.null_v = nn.Parameter(torch.randn(inner_dim)) if null_kv else None

    def forward(self, x, mask = None, context = None, pos_emb = None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: (batch, seq, dim) queries.
            mask: optional (batch, kv_len) boolean mask; True marks keys that may be attended.
            context: optional (batch, kv_len, context_dim) keys/values source.
            pos_emb: optional rotary angles; one tensor (shared) or a (q_pos_emb, k_pos_emb) pair.

        Returns:
            (batch, seq, dim) tensor.
        """
        b, device, h, scale = x.shape[0], x.device, self.heads, self.scale

        kv_input = default(context, x)

        q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * scale

        # relative positions via rotary embeddings
        if exists(pos_emb):
            q_pos_emb, k_pos_emb = cast_tuple(pos_emb, num = 2)
            q = apply_rotary_pos_emb(q, q_pos_emb)
            k = apply_rotary_pos_emb(k, k_pos_emb)

        # prepend the null key / value (after rotary, so it carries no position)
        if exists(self.null_k):
            nk, nv = self.null_k, self.null_v
            nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h = h), (nk, nv))
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # the null key is always attendable
            if exists(self.null_k):
                mask = F.pad(mask, (1, 0), value = True)

            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(i, j, device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and project back to dim
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class ChunkedCrossAttention(nn.Module):
    """RETRO chunked cross attention: each chunk-sized window of ``x`` attends to the neighbors
    retrieved for the chunk that precedes it, so no token sees retrieval of its own future.

    ``kwargs`` are forwarded to ``Attention`` (always built with ``null_kv = True``).
    """

    def __init__(
        self,
        chunk_size,
        **kwargs
    ):
        super().__init__()
        self.chunk_size = chunk_size
        self.cross_attn = Attention(null_kv = True, **kwargs)

    def forward(self, x, *, context_mask = None, context, pos_emb = None):
        """Cross-attend from ``x`` to retrieved neighbors.

        Args:
            x: (batch, seq, dim) decoder hidden states.
            context_mask: optional mask shaped (batch, num_chunks, num_retrieved, neighbor_len).
            context: encoded neighbors shaped (batch, num_chunks, num_retrieved, neighbor_len, dim).
            pos_emb: (q_pos_emb, k_pos_emb) pair of 4-D rotary angles; despite the None default it
                is required, since it is unpacked unconditionally below.

        Returns:
            Tensor with the same length as ``x``. The first ``chunk_size - 1`` positions are zeros,
            as are all positions when ``seq < chunk_size``.
        """
        # derive variables
        chunk_size = self.chunk_size

        b, n, num_chunks, num_retrieved = x.shape[0], x.shape[-2], *context.shape[-4:-2]

        # if sequence length less than chunk size, do an early return
        if n < self.chunk_size:
            return torch.zeros_like(x)

        # causal padding: shift x left by chunk_size - 1, so that window k of the shifted sequence
        # covers the last token of chunk k through the second-to-last token of chunk k + 1; these
        # tokens may all look at chunk k's neighbors without seeing retrieval of their own chunk
        causal_padding = chunk_size - 1

        x = F.pad(x, (0, 0, -causal_padding, causal_padding), value = 0.)

        # remove sequence which is ahead of the neighbors retrieved (during inference)
        seq_index = (n // chunk_size) * chunk_size
        x, x_remainder = x[:, :seq_index], x[:, seq_index:]

        seq_remain_len = x_remainder.shape[-2]

        # take care of rotary positional embedding
        # make sure queries positions are properly shifted to the future
        q_pos_emb, k_pos_emb = pos_emb
        q_pos_emb = F.pad(q_pos_emb, (0, 0, -causal_padding, causal_padding), value = 0.)

        # keys are the num_retrieved neighbors concatenated, so repeat their positions per neighbor
        k_pos_emb = repeat(k_pos_emb, 'b h n d -> b h (r n) d', r = num_retrieved)
        pos_emb = (q_pos_emb, k_pos_emb)

        # reshape so we have chunk to chunk attention, without breaking causality
        # NOTE(review): assumes context's num_chunks equals n // chunk_size — confirm with callers
        x = rearrange(x, 'b (k n) d -> (b k) n d', k = num_chunks)
        context = rearrange(context, 'b k r n d -> (b k) (r n) d')

        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b k r n -> (b k) (r n)')

        # cross attention
        out = self.cross_attn(x, context = context, mask = context_mask, pos_emb = pos_emb)

        # reshape back to original sequence
        out = rearrange(out, '(b k) n d -> b (k n) d', b = b)

        # pad back to original, with 0s at the beginning (which will be added to the residual and be fine)
        # this undoes the left shift and restores the trailing seq_remain_len positions as zeros
        out = F.pad(out, (0, 0, causal_padding, -causal_padding + seq_remain_len), value = 0.)
        return out
# encoder and decoder classes
class Encoder(nn.Module):
    """Transformer encoder whose layers optionally cross-attend to a context sequence.

    Each layer is [self-attention, cross-attention or None, feedforward], every sublayer wrapped in
    ``PreNorm`` (default) or ``PostNorm`` when ``post_norm`` is True. Cross-attention is present in
    every layer when ``cross_attn_layers`` is None, otherwise only in the listed 1-based layers.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        context_dim = None,
        causal = False,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        final_norm = True,
        cross_attn_layers = None,
        post_norm = False,
        output_dim = None,
        norm_klass = RMSNorm,
        scale_residual = 1.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # partial rotary embeddings, which is better than full rotary
        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/
        self.rotary_pos_emb = RotaryEmbedding(min(dim_head, MIN_DIM_HEAD))

        if post_norm:
            wrapper = partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)
        else:
            wrapper = partial(PreNorm, dim, norm_klass = norm_klass)

        for layer_num in range(1, depth + 1):
            wants_cross_attn = cross_attn_layers is None or layer_num in cross_attn_layers

            self_attn_block = wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal))
            cross_attn_block = wrapper(Attention(dim = dim, context_dim = context_dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if wants_cross_attn else None
            ff_block = wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))

            self.layers.append(nn.ModuleList([self_attn_block, cross_attn_block, ff_block]))

        # post-norm layers already end in a norm, so no extra final norm is added for them
        self.norm_out = norm_klass(dim) if final_norm and not post_norm else nn.Identity()
        self.project_out = nn.Linear(dim, output_dim) if exists(output_dim) else nn.Identity()

    def forward(self, x, *, mask = None, chunked_seq):
        """Encode ``x``, cross-attending to ``chunked_seq`` where cross-attention layers exist.

        Rotary angles for the queries are built from ``x.shape[-2]`` and those for the keys from
        ``chunked_seq.shape[-2]``. Presumably ``x`` holds retrieved-chunk tokens and
        ``chunked_seq`` the matching chunk of the input sequence — confirm against the caller.
        """
        device = x.device
        query_len = x.shape[-2]
        context_len = chunked_seq.shape[-2]

        q_pos_emb = self.rotary_pos_emb(query_len, device = device)
        k_pos_emb = self.rotary_pos_emb(context_len, device = device)

        for self_attn_block, cross_attn_block, ff_block in self.layers:
            x = self_attn_block(x, mask = mask, pos_emb = q_pos_emb)

            if cross_attn_block is not None:
                x = cross_attn_block(x, context = chunked_seq, pos_emb = (q_pos_emb, k_pos_emb))

            x = ff_block(x)

        return self.project_out(self.norm_out(x))
class Decoder(nn.Module):
    """Causal decoder with chunked cross-attention into encoded neighbours.

    Every layer is: causal self-attention -> optional ``ChunkedCrossAttention``
    -> feed-forward, each wrapped in ``PreNorm`` (default) or ``PostNorm``.
    ``cross_attn_layers`` holds 1-indexed layer numbers with cross-attention;
    ``None`` means every layer has it.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        final_norm = True,
        cross_attn_layers = None,
        chunk_size = 64,
        post_norm = False,
        norm_klass = RMSNorm,
        scale_residual = 1.
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        # partial rotary embeddings, which is better than full rotary
        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/
        self.rotary_pos_emb = RotaryEmbedding(min(dim_head, MIN_DIM_HEAD))

        # choose the residual / normalization wrapper applied around every sub-block
        if post_norm:
            wrap = partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)
        else:
            wrap = partial(PreNorm, dim, norm_klass = norm_klass)

        self.chunk_size = chunk_size

        for layer_index in range(depth):
            layer_number = layer_index + 1  # cross_attn_layers is 1-indexed
            wants_cross_attn = (not exists(cross_attn_layers)) or (layer_number in cross_attn_layers)

            # modules are created in the order: self-attn, chunked cross-attn, feed-forward
            causal_attn_block = wrap(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = True))
            chunked_cross_block = wrap(ChunkedCrossAttention(chunk_size = chunk_size, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if wants_cross_attn else None
            ff_block = wrap(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))

            self.layers.append(nn.ModuleList([causal_attn_block, chunked_cross_block, ff_block]))

        self.norm_out = norm_klass(dim) if (final_norm and not post_norm) else nn.Identity()

    def forward(self, x, *, encoder = None, encoder_retrieved_mask = None, context_mask = None, retrieved = None):
        """Decode ``x``, optionally cross-attending into retrieved neighbours.

        x: decoder hidden states, sequence length read from ``x.shape[-2]``
        encoder: module used to encode ``retrieved``; required when ``retrieved`` is given
        encoder_retrieved_mask: mask passed to ``encoder`` for the flattened neighbours
        context_mask: mask passed to the chunked cross-attention blocks
        retrieved: neighbour embeddings, rearranged with pattern 'b k r n d'
            (b batch, k chunks, r neighbours, n neighbour length, d dim)
        """
        device = x.device
        seq_len = x.shape[-2]
        self_attn_rotary = self.rotary_pos_emb(seq_len, device = device)

        # length of the prefix of x covered by whole chunks
        seq_index = (seq_len // self.chunk_size) * self.chunk_size

        # rotary positions for the cross-attention over retrieved chunks
        if exists(retrieved):
            num_chunks, num_neighbors, neighbor_len = retrieved.shape[-4:-1]
            cross_attn_rotary = (
                # offset by chunk_size - 1; the original author's note says this is needed
                # because the sequence is shifted inside ChunkedCrossAttention
                self.rotary_pos_emb(self.chunk_size, device = device, offset = self.chunk_size - 1),
                self.rotary_pos_emb(neighbor_len, device = device),
            )

        # neighbours are encoded lazily, once, at the first layer that cross-attends
        already_encoded = False

        for causal_attn_block, chunked_cross_block, ff_block in self.layers:
            x = causal_attn_block(x, pos_emb = self_attn_rotary)

            if exists(chunked_cross_block) and exists(retrieved):
                if not already_encoded:
                    flat_neighbors = rearrange(retrieved, 'b k r n d -> (b k r) n d')
                    # each neighbour is conditioned on the decoder hidden state of its own chunk
                    chunk_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)
                    encoded_neighbors = encoder(flat_neighbors, mask = encoder_retrieved_mask, chunked_seq = chunk_context)
                    retrieved = rearrange(encoded_neighbors, '(b k r) n d -> b k r n d', k = num_chunks, r = num_neighbors)
                    already_encoded = True

                x = chunked_cross_block(
                    x,
                    context = retrieved,
                    context_mask = context_mask,
                    pos_emb = cross_attn_rotary
                )

            x = ff_block(x)

        return self.norm_out(x)
# main class
class RETRO(nn.Module):
    """Retrieval-Enhanced Transformer (RETRO) language model.

    Tokens are embedded at ``enc_dim`` (token embedding + learned absolute
    position embedding), projected to ``dec_dim`` and run through a causal
    ``Decoder``. When retrieved neighbour chunks are passed to ``forward``,
    they share the token embedding, are encoded by ``Encoder`` inside the
    decoder, and are attended to with chunked cross-attention.
    """

    def __init__(
        self,
        *,
        num_tokens = BERT_VOCAB_SIZE,           # vocabulary size
        max_seq_len = 2048,                     # size of the absolute positional embedding table
        enc_dim = 896,                          # encoder (and embedding) dimension
        enc_depth = 2,                          # number of encoder layers
        enc_cross_attn_layers = None,           # 1-indexed encoder layers with cross attention (None = all)
        dec_depth = 12,                         # number of decoder layers
        dec_cross_attn_layers = (1, 3, 6, 9),   # 1-indexed decoder layers with chunked cross attention
        heads = 8,                              # attention heads, used by both encoder and decoder
        dec_dim = 768,                          # decoder dimension
        dim_head = 64,                          # dimension per attention head
        enc_attn_dropout = 0.,                  # encoder attention dropout
        enc_ff_dropout = 0.,                    # encoder feed-forward dropout
        dec_attn_dropout = 0.,                  # decoder attention dropout
        dec_ff_dropout = 0.,                    # decoder feed-forward dropout
        chunk_size = 64,                        # decoder chunk length for chunked cross attention
        pad_id = 0,                             # padding token id: masked in retrieval, ignored in the loss
        enc_scale_residual = None,              # encoder residual scale (used by PostNorm / deepnet)
        dec_scale_residual = None,              # decoder residual scale (used by PostNorm / deepnet)
        norm_klass = None,                      # normalization class, defaults to RMSNorm
        gated_rmsnorm = False,                  # use RMSNorm(gated = True) as the normalization
        use_deepnet = False                     # DeepNet: post-norm, LayerNorm, scaled residuals and init
    ):
        super().__init__()
        assert dim_head >= MIN_DIM_HEAD, f'dimension per head must be greater than {MIN_DIM_HEAD}'

        self.seq_len = max_seq_len
        self.pad_id = pad_id

        self.token_emb = nn.Embedding(num_tokens, enc_dim)
        self.pos_emb = nn.Embedding(max_seq_len, enc_dim)

        self.chunk_size = chunk_size

        # embeddings live at enc_dim; project them to the decoder width if it differs
        self.to_decoder_model_dim = nn.Linear(enc_dim, dec_dim) if enc_dim != dec_dim else nn.Identity()

        # for deepnet, residual scales
        # follow equation in Figure 2. in https://arxiv.org/abs/2203.00555

        norm_klass = default(norm_klass, RMSNorm)

        if use_deepnet:
            enc_scale_residual = default(enc_scale_residual, 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625)
            dec_scale_residual = default(dec_scale_residual, (3 * dec_depth) ** 0.25)
            norm_klass = nn.LayerNorm

        # allow for gated rmsnorm (applied after the deepnet branch, so it takes precedence)

        if gated_rmsnorm:
            norm_klass = partial(RMSNorm, gated = True)

        # define encoder and decoders
        # `heads` is forwarded explicitly so the constructor argument actually takes effect

        self.encoder = Encoder(
            dim = enc_dim,
            context_dim = dec_dim,
            heads = heads,
            dim_head = dim_head,
            depth = enc_depth,
            attn_dropout = enc_attn_dropout,
            ff_dropout = enc_ff_dropout,
            cross_attn_layers = enc_cross_attn_layers,
            post_norm = use_deepnet,
            norm_klass = norm_klass,
            scale_residual = enc_scale_residual,
            output_dim = dec_dim
        )

        self.decoder = Decoder(
            dim = dec_dim,
            depth = dec_depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = dec_attn_dropout,
            ff_dropout = dec_ff_dropout,
            cross_attn_layers = dec_cross_attn_layers,
            chunk_size = chunk_size,
            post_norm = use_deepnet,
            norm_klass = norm_klass,
            scale_residual = dec_scale_residual
        )

        self.to_logits = nn.Linear(dec_dim, num_tokens)

        # deepnet has special init of weight matrices

        if use_deepnet:
            deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625)
            deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25)

    def forward_without_retrieval(
        self,
        seq
    ):
        """Return logits for token ids ``seq`` (b, n) using the decoder alone.

        Sequences longer than ``max_seq_len`` are truncated to ``max_seq_len``.
        """
        # embed sequence

        embed = self.token_emb(seq)
        embed = embed[:, :self.seq_len]

        # get absolute positional embedding

        pos_emb = self.pos_emb(torch.arange(embed.shape[1], device = embed.device))
        pos_emb = rearrange(pos_emb, 'n d -> 1 n d')
        embed = embed + pos_emb

        embed = self.to_decoder_model_dim(embed)
        embed = self.decoder(embed)

        # project to logits

        return self.to_logits(embed)

    def forward(
        self,
        seq,
        retrieved = None,
        return_loss = False
    ):
        """
        b - batch
        n - sequence length / chunk length
        k - number of chunks
        d - feature dimension
        r - num retrieved neighbors

        seq: token ids (b, n). With return_loss=True it should hold one extra
            token, because the last position only serves as a label.
        retrieved: neighbour token ids (b, k, r, n') or (b, k, n') for a single
            neighbour per chunk, with n' >= chunk_size and k == n // chunk_size.
        Returns logits (b, n, num_tokens), or the cross entropy loss when
        return_loss=True.
        """
        # without retrieval, fall back to the plain decoder
        # NOTE(review): return_loss is ignored on this path (logits are returned)
        if not exists(retrieved):
            return self.forward_without_retrieval(seq)

        assert not (return_loss and not self.training), 'must be training if returning loss'

        # handle some user inputs: a 3d retrieval tensor means one neighbour per chunk

        if retrieved.ndim == 3:
            retrieved = rearrange(retrieved, 'b k n -> b k 1 n') # 1 neighbor retrieved

        # assume padding token id (usually 0.) is to be masked out
        # the mask is built after the reshape above so it has the same shape as `retrieved`;
        # building it earlier made the shape assertion below fail for 3d inputs

        mask = retrieved != self.pad_id

        # if training, derive labels (next-token prediction)

        if return_loss:
            seq, labels = seq[:, :-1], seq[:, 1:]

        # variables

        n, num_chunks, num_neighbors, chunk_size, retrieved_shape, device = seq.shape[-1], *retrieved.shape[-3:], retrieved.shape, seq.device

        assert chunk_size >= self.chunk_size, 'chunk size of retrieval input must be greater or equal to the designated chunk_size on RETRO initialization'

        num_seq_chunks = n // self.chunk_size
        assert num_chunks == num_seq_chunks, f'sequence requires {num_seq_chunks} retrieved chunks, but only {num_chunks} passed in'

        # sequence index at which k-nearest neighbors have not been fetched yet after

        seq_index = num_seq_chunks * self.chunk_size

        # embed both sequence and retrieved chunks (shared token embedding)

        embed = self.token_emb(seq)
        retrieved = self.token_emb(retrieved)

        # get absolute positional embedding

        pos_emb = self.pos_emb(torch.arange(n, device = device))
        pos_emb = rearrange(pos_emb, 'n d -> 1 n d')
        embed = embed + pos_emb

        # handle masks for encoder and decoder, if needed

        encoder_retrieved_mask = decoder_retrieved_mask = None

        if exists(mask):
            assert mask.shape == retrieved_shape, 'retrieval mask must be of the same shape as the retrieval tokens'
            encoder_retrieved_mask = rearrange(mask, 'b k r n -> (b k r) n')
            decoder_retrieved_mask = mask

        # project both sequence embedding and retrieved embedding to decoder dimension if necessary

        embed = self.to_decoder_model_dim(embed)

        # decode

        embed = self.decoder(
            embed,
            encoder = self.encoder,
            context_mask = decoder_retrieved_mask,
            encoder_retrieved_mask = encoder_retrieved_mask,
            retrieved = retrieved
        )

        # project to logits

        logits = self.to_logits(embed)

        if not return_loss:
            return logits

        # cross entropy loss, ignoring padding positions

        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = self.pad_id)
        return loss


浙公网安备 33010602011771号