Lucidrains-系列项目源码解析-十七-
Lucidrains 系列项目源码解析(十七)

Flexible Diffusion Modeling of Long Videos - Pytorch (wip)
Implementation of the video diffusion model and training scheme presented in the paper, Flexible Diffusion Modeling of Long Videos, in Pytorch. While the Unet architecture does not look that novel (quite similar to Space-time factored unets, where they do attention across time) they achieved up to 25 minutes of coherent video with their specific frame sampling conditioning scheme during training.
I will also attempt to push this approach even further by introducing a super-resoluting module on top identical to what was used in Imagen
Citations
@inproceedings{Harvey2022FlexibleDM,
title = {Flexible Diffusion Modeling of Long Videos},
author = {William Harvey and Saeid Naderiparizi and Vaden Masrani and Christian Weilbach and Frank Wood},
year = {2022}
}

FUSS - Nim (wip)
Implementation of FUSS (Fitness Uniform Selection), a selection method proposed by Marcus Hutter himself for maintaining diversity in evolutionary algorithms, in Nim
Basically will be a rewrite of FUSS in C
Citations
@article{Hutter_2006,
doi = {10.1109/tevc.2005.863127},
url = {https://doi.org/10.1109%2Ftevc.2005.863127},
year = 2006,
month = {oct},
publisher = {Institute of Electrical and Electronics Engineers ({IEEE})},
volume = {10},
number = {5},
pages = {568--589},
author = {M. Hutter and S. Legg},
title = {Fitness uniform optimization},
journal = {{IEEE} Transactions on Evolutionary Computation}
}
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\g-mlp-gpt\g_mlp_gpt\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# 定义一个装饰器函数,用于在模型评估时切换为eval模式
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# 定义一个函数用于对logits进行top k过滤
def top_k(logits, thres = 0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, ignore_index = -100, pad_value = 0):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.seq_len
# 生成函数,用于生成序列
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
device = start_tokens.device
num_dims = len(start_tokens.shape)
if num_dims == 1:
start_tokens = start_tokens[None, :]
b, t = start_tokens.shape
out = start_tokens
for _ in range(seq_len):
x = out[:, -self.max_seq_len:]
logits = self.net(x, **kwargs)[:, -1, :]
filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if eos_token is not None and (sample == eos_token).all():
break
out = out[:, t:]
if num_dims == 1:
out = out.squeeze(0)
return out
# 前向传播函数,用于计算损失
def forward(self, x, **kwargs):
xi, xo = x[:, :-1], x[:, 1:]
out = self.net(xi, **kwargs)
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
return loss
.\lucidrains\g-mlp-gpt\g_mlp_gpt\g_mlp_gpt.py
# 从 math 模块中导入 ceil 函数,用于向上取整
# 从 functools 模块中导入 partial 函数,用于创建偏函数
# 从 random 模块中导入 randrange 函数,用于生成指定范围内的随机整数
# 导入 torch 模块
# 从 torch.nn.functional 模块中导入 F 别名
# 从 torch 模块中导入 nn、einsum 函数
from math import ceil
from functools import partial
from random import randrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 从 einops 模块中导入 rearrange、repeat 函数
from einops import rearrange, repeat
# 从 g_mlp_gpt.reversible 模块中导入 ReversibleSequence、SequentialSequence 类
# functions
# 定义函数 exists,用于判断值是否存在
def exists(val):
return val is not None
# 定义函数 cast_tuple,用于将值转换为元组
def cast_tuple(val, num):
return ((val,) * num) if not isinstance(val, tuple) else val
# 定义函数 pad_to_multiple,用于将张量填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
seqlen = tensor.shape[dim]
m = seqlen / multiple
if m.is_integer():
return tensor
remainder = ceil(m) * multiple - seqlen
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(tensor, (*pad_offset, 0, remainder), value = value)
# 定义函数 dropout_layers,用于对层进行随机丢弃
def dropout_layers(layers, prob_survival):
if prob_survival == 1:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival
# 确保至少有一层保留
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
# helper classes
# 定义类 Residual,实现残差连接
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
# 定义类 PreNorm,实现预层归一化
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
# 定义类 GEGLU,实现门控线性单元
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
# 定义类 FeedForward,实现前馈神经网络
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4):
super().__init__()
inner_dim = int(dim * mult * 2 / 3)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim * 2),
GEGLU(),
nn.Linear(inner_dim, dim)
)
def forward(self, x):
return self.net(x)
# 定义类 Attention,实现注意力机制
class Attention(nn.Module):
def __init__(self, dim_in, dim_out, dim_inner):
super().__init__()
self.scale = dim_inner ** -0.5
self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim_out)
def forward(self, x):
device = x.device
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
mask = torch.ones(sim.shape[-2:], device = device).triu(1).bool()
sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)
attn = sim.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
return self.to_out(out)
# 定义类 LocalAttention,实现局部注意力机制
class LocalAttention(nn.Module):
def __init__(self, dim_in, dim_inner, dim_out, window = 128):
super().__init__()
self.scale = dim_inner ** -0.5
self.window = window
self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim_out)
# 定义前向传播函数,接受输入 x
def forward(self, x):
# 获取输入 x 的形状信息,包括 batch size、序列长度、设备信息和窗口大小
b, n, *_, device, w = *x.shape, x.device, self.window
# 将输入 x 进行填充,使其长度能够被窗口大小整除
x = pad_to_multiple(x, w, dim = -2, value = 0.)
# 将填充后的 x 分别转换为查询、键、值,并按照最后一个维度分割成三部分
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# 定义窗口函数,将输入按照窗口大小重新排列
window_fn = lambda t: rearrange(t, 'b (w n) d -> b w n d', n = w)
q, k, v = map(window_fn, (q, k, v))
# 对键和值进行填充,使其能够进行滑动窗口操作
k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 0)), (k, v))
k, v = map(lambda t: torch.cat((k[:, :-1], k[:, 1:]), dim = 2), (k, v))
# 计算查询和键之间的相似度,并乘以缩放因子
sim = einsum('b w i d, b w j d -> b w i j', q, k) * self.scale
buckets, i, j = sim.shape[-3:]
# 创建掩码,用于屏蔽无效的位置信息
mask_value = -torch.finfo(sim.dtype).max
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
mask = repeat(mask, 'i j -> () u i j', u = buckets)
# 将掩码应用到相似度矩阵中
sim.masked_fill_(mask, mask_value)
# 对相似度矩阵进行 softmax 操作,得到注意力权重
attn = sim.softmax(dim = -1)
# 根据注意力权重计算输出
out = einsum('b w i j, b w j d -> b w i d', attn, v)
# 将输出重新排列成原始形状
out = rearrange(out, 'b w n d -> b (w n) d')
# 将输出传递给输出层,并返回结果
out = self.to_out(out[:, :n])
return out
# 定义一个类 CausalSGU,继承自 nn.Module
class CausalSGU(nn.Module):
# 初始化函数,接受多个参数
def __init__(
self,
dim,
dim_seq,
init_eps = 1e-3,
heads = 4,
act = nn.Identity()
):
# 调用父类的初始化函数
super().__init__()
# 计算输出维度
dim_out = dim // 2
# 初始化 LayerNorm 模块
self.norm = nn.LayerNorm(dim_out)
# 设置头数和权重、偏置参数
self.heads = heads
self.weight = nn.Parameter(torch.zeros(heads, dim_seq, dim_seq))
self.bias = nn.Parameter(torch.zeros(heads, dim_seq))
# 初始化权重和偏置参数
init_eps /= dim_seq
nn.init.uniform_(self.weight, -init_eps, init_eps)
nn.init.constant_(self.bias, 1.)
# 设置激活函数
self.act = act
# 创建一个缓冲区,用于存储掩码
self.register_buffer('mask', ~torch.ones(dim_seq, dim_seq).triu_(1).bool())
# 前向传播函数,接受输入 x 和 gate_res
def forward(self, x, gate_res = None):
# 获取设备信息、输入序列长度和头数
device, n, h = x.device, x.shape[1], self.heads
# 将输入 x 分成两部分:res 和 gate
res, gate = x.chunk(2, dim = -1)
# 对 gate 进行 LayerNorm 处理
gate = self.norm(gate)
# 获取权重和偏置参数
weight, bias = self.weight, self.bias
weight, bias = weight[:, :n, :n], bias[:, :n]
# 对权重参数应用掩码
weight = weight * self.mask[None, :n, :n].int().float()
# 重排 gate 的维度
gate = rearrange(gate, 'b n (h d) -> b h n d', h = h)
# 执行矩阵乘法操作
gate = einsum('b h n d, h m n -> b h m d', gate, weight)
# 添加偏置参数
gate = gate + rearrange(bias, 'h n -> () h n ()')
# 重排 gate 的维度
gate = rearrange(gate, 'b h n d -> b n (h d)')
# 如果存在 gate_res,则将其加到 gate 上
if exists(gate_res):
gate = gate + gate_res
# 返回激活函数作用后的结果乘以 res
return self.act(gate) * res
# 定义一个类 CausalLocalSGU,继承自 nn.Module
class CausalLocalSGU(nn.Module):
# 初始化函数,接受多个参数
def __init__(
self,
dim,
dim_seq,
init_eps = 1e-3,
heads = 4,
window = 128,
act = nn.Identity()
):
# 调用父类的初始化函数
super().__init__()
# 计算输出维度
dim_out = dim // 2
# 初始化 LayerNorm 模块
self.norm = nn.LayerNorm(dim_out)
# 设置头数、窗口大小和权重、偏置参数
self.heads = heads
self.window = window
self.weight = nn.Parameter(torch.zeros(heads, window, window * 2))
self.bias = nn.Parameter(torch.zeros(heads, window))
# 初始化权重和偏置参数
init_eps /= window
nn.init.uniform_(self.weight, -init_eps, init_eps)
nn.init.constant_(self.bias, 1.)
# 设置激活函数
self.act = act
# 创建一个缓冲区,用于存储掩码
self.register_buffer('mask', ~torch.ones(window, window * 2).triu_(window + 1).bool())
# 前向传播函数,接受输入 x 和 gate_res
def forward(self, x, gate_res = None):
# 获取设备信息、输入序列长度、头数和窗口大小
device, n, h, w = x.device, x.shape[1], self.heads, self.window
# 将输入 x 分成两部分:res 和 gate
res, gate = x.chunk(2, dim = -1)
# 对 gate 进行填充和重排
gate = pad_to_multiple(gate, w, dim = -2)
gate = rearrange(gate, 'b (w n) d -> b w n d', n = w)
# 对 gate 进行 LayerNorm 处理
gate = self.norm(gate)
# 对 gate 进行填充和拼接
gate = F.pad(gate, (0, 0, 0, 0, 1, 0), value = 0.)
gate = torch.cat((gate[:, :-1], gate[:, 1:]), dim = 2)
# 获取权重和偏置参数
weight, bias = self.weight, self.bias
# 对权重参数应用掩码
weight = weight * self.mask[None, ...].int().float()
# 重排 gate 的维度
gate = rearrange(gate, 'b w n (h d) -> b w h n d', h = h)
# 执行矩阵乘法操作
gate = einsum('b w h n d, h m n -> b w h m d', gate, weight)
# 添加偏置参数
gate = gate + rearrange(bias, 'h n -> () () h n ()')
# 重排 gate 的维度
gate = rearrange(gate, 'b w h n d -> b w n (h d)')
# 重排 gate 的维度
gate = rearrange(gate, 'b w n d -> b (w n) d')
gate = gate[:, :n]
# 如果存在 gate_res,则将其加到 gate 上
if exists(gate_res):
gate = gate + gate_res
# 返回激活函数作用后的结果乘以 res
return self.act(gate) * res
# 定义一个类 AxiallyFold,继承自 nn.Module
class AxiallyFold(nn.Module):
# 初始化函数,接受维度、步长和函数参数
def __init__(self, dim, every, fn):
# 调用父类的初始化函数
super().__init__()
# 设置函数和步长
self.fn = fn
self.every = every
# 如果步长大于 1,则创建一个卷积层
self.conv = nn.Conv1d(dim, dim, kernel_size = every, groups = dim) if every > 1 else None
# 前向传播函数,接受输入 x
def forward(self, x):
# 获取步长
every = self.every
# 如果步长小于等于 1,则直接应用函数
if every <= 1:
return self.fn(x)
# 获取序列长度
n = x.shape[1]
# 对输入 x 进行填充和重排
x = pad_to_multiple(x, self.every, dim = -2)
x = rearrange(x, 'b (n e) d -> (b e) n d', e = every)
x = self.fn(x)
# 重排结果并进行填充
x = rearrange(x, '(b e) n d -> b d (n e)', e = every)
x = F.pad(x, (every - 1, 0), value = 0)
# 对结果应用卷积操作
out = self.conv(x)
out = rearrange(out, 'b d n -> b n d')
return out[:, :n]
# 定义一个类 gMLPBlock,继承自 nn.Module
class gMLPBlock(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
*,
dim, # 输入维度
seq_len, # 序列长度
dim_ff, # FeedForward 层维度
heads = 4, # 多头注意力机制的头数,默认为4
causal = False, # 是否使用因果关系,默认为False
window = None, # 窗口大小,默认为None
attn_dim = None, # 注意力机制维度,默认为None
act = nn.Identity() # 激活函数,默认为恒等函数
):
super().__init__()
is_windowed = exists(window) and window < seq_len
# 根据是否存在窗口大小选择不同的 SGU 类型
SGU_klass = partial(CausalLocalSGU, window = window) if is_windowed else CausalSGU
# 根据是否存在窗口大小选择不同的 Attention 类型
Attention_klass = partial(LocalAttention, window = window) if is_windowed else Attention
# 如果存在注意力机制维度,则创建注意力层
self.attn = Attention_klass(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None
# 输入投影层,包含线性层和 GELU 激活函数
self.proj_in = nn.Sequential(
nn.Linear(dim, dim_ff),
nn.GELU()
)
# SGU 层,根据选择的 SGU 类型进行初始化
self.sgu = SGU_klass(dim_ff, seq_len, causal, heads = heads, act = act)
# 输出投影层,线性层
self.proj_out = nn.Linear(dim_ff // 2, dim)
# 前向传播函数
def forward(self, x):
# 如果存在注意力层,则进行注意力计算
gate_res = self.attn(x) if exists(self.attn) else None
# 输入投影
x = self.proj_in(x)
# SGU 层计算
x = self.sgu(x, gate_res = gate_res)
# 输出投影
x = self.proj_out(x)
return x
# 主要类
class gMLPGPT(nn.Module):
def __init__(
self,
*,
num_tokens, # 标记的数量
dim, # 向量维度
depth, # 模型深度
seq_len, # 序列长度
heads = 1, # 多头注意力机制的头数,默认为1
ff_mult = 4, # FeedForward 层的倍数,默认为4
prob_survival = 1., # 存活概率,默认为1
reversible = False, # 是否可逆,默认为False
window = None, # 窗口大小,默认为None
attn_dim = None, # 注意力维度,默认为None
act = nn.Identity() # 激活函数,默认为恒等函数
):
super().__init__()
dim_ff = dim * ff_mult
self.seq_len = seq_len
self.prob_survival = prob_survival
self.to_embed = nn.Embedding(num_tokens, dim) # 创建嵌入层
window = cast_tuple(window, depth) # 将窗口大小转换为元组
window = tuple(map(lambda t: t if isinstance(t, tuple) else (t, 1), window)) # 将窗口大小转换为元组
attn_dims = cast_tuple(attn_dim, depth) # 将注意力维度转换为元组
assert len(window) == depth, f'num window sizes {len(window)} must be equal to depth {depth}' # 断言窗口大小数量必须等于深度
layers = nn.ModuleList([]) # 创建模块列表
for ind, (w, ax), attn_dim in zip(range(depth), window, attn_dims):
attn_dim = attn_dim if exists(window) else None
get_gmlp = lambda: PreNorm(dim, AxiallyFold(dim, ax, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, heads = heads, window = w, act = act, attn_dim = attn_dim)) # 获取 gMLP 模块
layer_blocks = nn.ModuleList([
get_gmlp()
])
if reversible:
layer_blocks.append(FeedForward(dim, mult = ff_mult)) # 如果是可逆模型,添加 FeedForward 层
layers.append(layer_blocks) # 添加模块列表到层列表
execute_klass = SequentialSequence if not reversible else ReversibleSequence # 根据是否可逆选择执行类
self.net = execute_klass(layers) # 创建执行网络
self.to_logits = nn.Sequential(
nn.LayerNorm(dim), # 层归一化
nn.Linear(dim, num_tokens) # 线性层
)
def forward(self, x):
layer_dropout = 1. - self.prob_survival # 计算层的丢弃率
x = self.to_embed(x) # 嵌入输入序列
out = self.net(x, layer_dropout = layer_dropout) # 通过网络传播输入
return self.to_logits(out) # 返回输出��果
.\lucidrains\g-mlp-gpt\g_mlp_gpt\reversible.py
# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
# 用于将参数路由到可逆层函数中的函数
def route_args(router, args, depth):
# 初始化路由后的参数列表
routed_args = [(dict(), dict()) for _ in range(depth)]
# 获取参数中与路由器匹配的键
matched_keys = [key for key in args.keys() if key in router]
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
# 根据概率丢弃层的函数
def layer_drop(layers, prob):
to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob
blocks = [block for block, drop in zip(layers, to_drop) if not drop]
blocks = layers[:1] if len(blocks) == 0 else blocks
return blocks
# 保存和设置随机数种子的类
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# 可逆块类,受启发于 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# 一旦多 GPU 工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args = {}, g_args = {}):
x1, x2 = torch.chunk(x, 2, dim=2)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=2)
def backward_pass(self, y, dy, f_args = {}, g_args = {}):
y1, y2 = torch.chunk(y, 2, dim=2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=2)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1, retain_graph=True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=2)
dx = torch.cat([dx1, dx2], dim=2)
return x, dx
# 可逆函数类
class _ReversibleFunction(Function):
@staticmethod
# 前向传播函数,接收上下文对象 ctx,输入数据 x,模块列表 blocks 和参数列表 args
def forward(ctx, x, blocks, args):
# 将参数列表 args 存储到上下文对象 ctx 中
ctx.args = args
# 遍历模块列表 blocks 和参数列表 args,对输入数据 x 进行处理
for block, kwarg in zip(blocks, args):
x = block(x, **kwarg)
# 将处理后的数据 x 分离出来,并存储到上下文对象 ctx 中
ctx.y = x.detach()
# 将模块列表 blocks 存储到上下文对象 ctx 中
ctx.blocks = blocks
# 返回处理后的数据 x
return x
# 反向传播函数,接收上下文对象 ctx 和梯度 dy
@staticmethod
def backward(ctx, dy):
# 获取上下文对象 ctx 中存储的处理后的数据 y 和参数列表 args
y = ctx.y
args = ctx.args
# 反向遍历模块列表 blocks 和参数列表 args,对梯度 dy 进行处理
for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
# 调用模块的反向传播函数,更新梯度 dy 和数据 y
y, dy = block.backward_pass(y, dy, **kwargs)
# 返回更新后的梯度 dy
return dy, None, None
# 定义一个继承自 nn.Module 的类 SequentialSequence
class SequentialSequence(nn.Module):
# 初始化函数,接受层列表、参数路由字典和层丢弃率作为参数
def __init__(self, layers, args_route = {}, layer_dropout = 0.):
super().__init__()
# 断言每个参数路由映射的深度与顺序层的数量相同
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
self.layers = layers
self.args_route = args_route
self.layer_dropout = layer_dropout
# 前向传播函数,接受输入 x 和关键字参数 kwargs
def forward(self, x, **kwargs):
# 根据参数路由和关键字参数获取参数
args = route_args(self.args_route, kwargs, len(self.layers))
# 将层和参数组成元组列表
layers_and_args = list(zip(self.layers, args))
# 如果处于训练状态且层丢弃率大于0
if self.training and self.layer_dropout > 0:
# 对层和参数进行层丢弃
layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
# 遍历层和参数列表,对输入 x 进行操作
for (f,), (f_args, _) in layers_and_args:
x = x + f(x, **f_args)
# 返回处理后的 x
return x
# 定义一个继承自 nn.Module 的类 ReversibleSequence
class ReversibleSequence(nn.Module):
# 初始化函数,接受块列表、参数路由字典和层丢弃率作为参数
def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
super().__init__()
self.args_route = args_route
self.layer_dropout = layer_dropout
# 创建包含可逆块的模块列表
self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
# 前向传播函数,接受输入 x、层丢弃率和关键字参数 kwargs
def forward(self, x, layer_dropout = 0., **kwargs):
# 在最后一个维度上连接输入 x 的副本
x = torch.cat([x, x], dim=-1)
# 获取块列表和参数
blocks = self.blocks
args = route_args(self.args_route, kwargs, len(blocks))
args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
# 将块和参数组成元组列表
layers_and_args = list(zip(blocks, args))
# 如果处于训练状态且层丢弃率大于0
if self.training and layer_dropout > 0:
# 对块和参数进行层丢弃
layers_and_args = layer_drop(layers_and_args, layer_dropout)
# 分别获取块和参数
blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1))
# 调用自定义的可逆函数进行处理
out = _ReversibleFunction.apply(x, blocks, args)
# 在最后一个维度上分割输出并求和
return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
.\lucidrains\g-mlp-gpt\g_mlp_gpt\__init__.py
# 从 g_mlp_gpt.g_mlp_gpt 模块中导入 gMLPGPT 类
from g_mlp_gpt.g_mlp_gpt import gMLPGPT
GPT - gMLP
This repository will attempt to crack long context autoregressive language modeling (GPT) using variations of gMLPs. Specifically, it will contain a variant that does gMLP for local sliding windows. The hope is to be able to stretch a single GPU to be able to train context lengths of 4096 and above efficiently and well.
You can also add the "tiny" attention (as described in the paper) with the attn_dim keyword argument, which corresponds to the dimension of the single head (64 is recommended). You can pass in a tuple to customize different dimension per layer.
Install
$ pip install g-mlp-gpt
Usage
import torch
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
num_tokens = 20000,
dim = 512,
depth = 4,
seq_len = 1024,
window = (128, 256, 512, 1024) # window sizes for each depth
)
x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)
16k context length
import torch
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
num_tokens = 20000,
dim = 512,
seq_len = 16384,
reversible = True, # reversible networks
act = nn.Tanh(), # tanh activation for spatial gating
depth = 12,
window = (
128,
128,
256,
512,
1024,
1024,
(2048, 2), # window size of 2048, axial of 2
(2048, 2),
(4096, 4),
(4096, 4),
(8192, 8), # window size of 8192, axial of 8
(8192, 8)
)
).cuda()
x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)
Citations
@misc{liu2021pay,
title = {Pay Attention to MLPs},
author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year = {2021},
eprint = {2105.08050},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\g-mlp-gpt\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的信息
setup(
# 包的名称
name = 'g-mlp-gpt',
# 查找所有包
packages = find_packages(),
# 版本号
version = '0.0.15',
# 许可证
license='MIT',
# 描述
description = 'gMLP - GPT',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 项目链接
url = 'https://github.com/lucidrains/g-mlp-gpt',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'multi-layered-preceptrons'
],
# 安装依赖
install_requires=[
'einops>=0.3',
'torch>=1.6'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\g-mlp-gpt\train.py
# 导入必要的库
from g_mlp_gpt import gMLPGPT
from g_mlp_gpt.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 768
SEQ_LEN = 768
# 定义辅助函数
# 从 token 解码为字符
def decode_token(token):
return str(chr(max(32, token)))
# 从 tokens 解码为字符串
def decode_tokens(tokens):
return ''.join(list(map(decode_token, tokens)))
# 实例化类似 GPT 的解码器模型
model = gMLPGPT(
num_tokens = 256,
dim = 512,
seq_len = SEQ_LEN,
depth = 8,
window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN),
attn_dim = 16
)
model = AutoregressiveWrapper(model)
model.cuda()
# 准备 enwik8 数据
with gzip.open('./data/enwik8.gz') as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练集和验证集的 DataLoader
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f'training loss: {loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f'%s \n\n %s', (prime, '*' * 100))
sample = model.generate(inp, GENERATE_LENGTH)
output_str = decode_tokens(sample)
print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# 定义一个装饰器函数,用于在模型评估时切换为eval模式
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# 定义一个函数用于对logits进行top k过滤
def top_k(logits, thres = 0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, ignore_index = -100, pad_value = 0):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.seq_len
# 生成函数,用于生成序列
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
device = start_tokens.device
num_dims = len(start_tokens.shape)
if num_dims == 1:
start_tokens = start_tokens[None, :]
b, t = start_tokens.shape
out = start_tokens
for _ in range(seq_len):
x = out[:, -self.max_seq_len:]
logits = self.net(x, **kwargs)[:, -1, :]
filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if eos_token is not None and (sample == eos_token).all():
break
out = out[:, t:]
if num_dims == 1:
out = out.squeeze(0)
return out
# 前向传播函数,用于计算损失
def forward(self, x, **kwargs):
xi, xo = x[:, :-1], x[:, 1:]
out = self.net(xi, **kwargs)
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
return loss
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\g_mlp_pytorch.py
# 从 random 模块中导入 randrange 函数
# 从 torch 模块中导入相关函数和类
# 从 einops 模块中导入 rearrange, repeat 函数以及 Rearrange, Reduce 类
from random import randrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# functions
# 判断值是否存在的函数
def exists(val):
return val is not None
# 将输入值转换为元组的函数
def pair(val):
return (val, val) if not isinstance(val, tuple) else val
# 对层进行 dropout 处理的函数
def dropout_layers(layers, prob_survival):
if prob_survival == 1:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival
# 确保至少有一层保留
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
# 对张量进行平移的函数
def shift(t, amount, mask = None):
if amount == 0:
return t
return F.pad(t, (0, 0, amount, -amount), value = 0.)
# helper classes
# 残差连接的类
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
# 对输入进行预平移的类
class PreShiftTokens(nn.Module):
def __init__(self, shifts, fn):
super().__init__()
self.fn = fn
self.shifts = tuple(shifts)
def forward(self, x, **kwargs):
if self.shifts == (0,):
return self.fn(x, **kwargs)
shifts = self.shifts
segments = len(shifts)
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim = -1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args), zip(segments_to_shift, shifts)))
x = torch.cat((*segments_to_shift, *rest), dim = -1)
return self.fn(x, **kwargs)
# 对输入进行预归一化的类
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
# 注意力机制类
class Attention(nn.Module):
def __init__(self, dim_in, dim_out, dim_inner, causal = False):
super().__init__()
self.scale = dim_inner ** -0.5
self.causal = causal
self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim_out)
def forward(self, x):
device = x.device
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if self.causal:
mask = torch.ones(sim.shape[-2:], device = device).triu(1).bool()
sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)
attn = sim.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
return self.to_out(out)
# 空间门控单元类
class SpatialGatingUnit(nn.Module):
def __init__(
self,
dim,
dim_seq,
causal = False,
act = nn.Identity(),
heads = 1,
init_eps = 1e-3,
circulant_matrix = False
):
super().__init__()
dim_out = dim // 2
self.heads = heads
self.causal = causal
self.norm = nn.LayerNorm(dim_out)
self.act = act
# 参数
if circulant_matrix:
self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))
self.circulant_matrix = circulant_matrix
shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
weight = torch.zeros(shape)
self.weight = nn.Parameter(weight)
init_eps /= dim_seq
nn.init.uniform_(self.weight, -init_eps, init_eps)
self.bias = nn.Parameter(torch.ones(heads, dim_seq))
# 定义前向传播函数,接受输入 x 和门控信息 gate_res
def forward(self, x, gate_res = None):
# 获取输入 x 的设备信息、特征维度 n 和注意力头数 h
device, n, h = x.device, x.shape[1], self.heads
# 将输入 x 切分为结果 res 和门控信息 gate
res, gate = x.chunk(2, dim = -1)
# 对门控信息 gate 进行归一化处理
gate = self.norm(gate)
# 获取权重和偏置参数
weight, bias = self.weight, self.bias
# 如果使用循环矩阵
if self.circulant_matrix:
# 构建循环矩阵
# 获取权重参数的最后一个维度大小
dim_seq = weight.shape[-1]
# 在权重参数的最后一个维度上进行填充
weight = F.pad(weight, (0, dim_seq), value = 0)
weight = repeat(weight, '... n -> ... (r n)', r = dim_seq)
weight = weight[:, :-dim_seq].reshape(h, dim_seq, 2 * dim_seq - 1)
weight = weight[:, :, (dim_seq - 1):]
# 赋予循环矩阵绝对位置感知
pos_x, pos_y = self.circulant_pos_x, self.circulant_pos_y
weight = weight * rearrange(pos_x, 'h i -> h i ()') * rearrange(pos_y, 'h j -> h () j')
# 如果是因果关系
if self.causal:
# 裁剪权重和偏置参数
weight, bias = weight[:, :n, :n], bias[:, :n]
# 创建掩码,使得只能看到当前位置及之前的信息
mask = torch.ones(weight.shape[-2:], device = device).triu_(1).bool()
mask = rearrange(mask, 'i j -> () i j')
weight = weight.masked_fill(mask, 0.)
# 重排门控信息 gate 的维度
gate = rearrange(gate, 'b n (h d) -> b h n d', h = h)
# 执行矩阵乘法操作
gate = einsum('b h n d, h m n -> b h m d', gate, weight)
# 加上偏置参数
gate = gate + rearrange(bias, 'h n -> () h n ()')
# 重排门控信息 gate 的维度
gate = rearrange(gate, 'b h n d -> b n (h d)')
# 如果存在门控信息 gate_res,则将其加到 gate 上
if exists(gate_res):
gate = gate + gate_res
# 对 gate 执行激活函数,并乘以结果 res
return self.act(gate) * res
# 定义 gMLPBlock 类,继承自 nn.Module 类
class gMLPBlock(nn.Module):
# 初始化函数
def __init__(
self,
*,
dim, # 输入维度
dim_ff, # Feed-Forward 层维度
seq_len, # 序列长度
heads = 1, # 多头注意力机制中的头数
attn_dim = None, # 注意力机制的维度
causal = False, # 是否使用因果关系
act = nn.Identity(), # 激活函数,默认为恒等映射
circulant_matrix = False # 是否使用循环矩阵
):
super().__init__()
# 输入投影层,包含线性变换和 GELU 激活函数
self.proj_in = nn.Sequential(
nn.Linear(dim, dim_ff),
nn.GELU()
)
# 如果存在注意力机制的维度,则创建注意力对象
self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None
# 空间门控单元
self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, circulant_matrix = circulant_matrix)
# 输出投影层
self.proj_out = nn.Linear(dim_ff // 2, dim)
# 前向传播函数
def forward(self, x):
# 如果存在注意力对象,则进行注意力计算
gate_res = self.attn(x) if exists(self.attn) else None
x = self.proj_in(x) # 输入投影
x = self.sgu(x, gate_res = gate_res) # 空间门控单元
x = self.proj_out(x) # 输出投影
return x
# 主要类
# 定义 gMLP 类,继承自 nn.Module 类
class gMLP(nn.Module):
# 初始化函数
def __init__(
self,
*,
num_tokens = None, # 标记数量
dim, # 输入维度
depth, # 深度
seq_len, # 序列长度
heads = 1, # 多头注意力机制中的头数
ff_mult = 4, # Feed-Forward 层维度倍数
attn_dim = None, # 注意力机制的维度
prob_survival = 1., # 生存概率
causal = False, # 是否使用因果关系
circulant_matrix = False, # 是否使用循环矩阵
shift_tokens = 0, # 标记偏移
act = nn.Identity() # 激活函数,默认为恒等映射
):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
dim_ff = dim * ff_mult
self.seq_len = seq_len
self.prob_survival = prob_survival
# Embedding 层
self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()
token_shifts = tuple(range(0 if causal else -shift_tokens, shift_tokens + 1))
# 层列表
self.layers = nn.ModuleList([Residual(PreNorm(dim, PreShiftTokens(token_shifts, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act, circulant_matrix = circulant_matrix))) for i in range(depth)])
# 输出层
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_tokens)
) if exists(num_tokens) else nn.Identity()
# 前向传播函数
def forward(self, x):
x = self.to_embed(x) # Embedding
layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
out = nn.Sequential(*layers)(x) # 层序列
return self.to_logits(out) # 输出层
# 定义 gMLPVision 类,继承自 nn.Module 类
class gMLPVision(nn.Module):
# 初始化函数
def __init__(
self,
*,
image_size, # 图像尺寸
patch_size, # 补丁尺寸
num_classes, # 类别数量
dim, # 输入维度
depth, # 深度
heads = 1, # 多头注意力机制中的头数
ff_mult = 4, # Feed-Forward 层维度倍数
channels = 3, # 通道数
attn_dim = None, # 注意力机制的维度
prob_survival = 1. # 生存概率
):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width)
dim_ff = dim * ff_mult
# 补丁嵌入层
self.to_patch_embed = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_height, p2 = patch_width),
nn.Linear(channels * patch_height * patch_width, dim)
)
self.prob_survival = prob_survival
# 层列表
self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = num_patches, attn_dim = attn_dim))) for i in range(depth)])
# 输出层
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
Reduce('b n d -> b d', 'mean'),
nn.Linear(dim, num_classes)
)
# 前向传播函数
def forward(self, x):
x = self.to_patch_embed(x) # 补丁嵌入
layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
x = nn.Sequential(*layers)(x) # 层序列
return self.to_logits(x) # 输出层
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\__init__.py
# 从 g_mlp_pytorch.g_mlp_pytorch 模块中导入 gMLP, gMLPVision, gMLPBlock, SpatialGatingUnit 类
from g_mlp_pytorch.g_mlp_pytorch import gMLP, gMLPVision, gMLPBlock, SpatialGatingUnit

gMLP - Pytorch
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Install
$ pip install g-mlp-pytorch
Usage
For masked language modelling
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 256,
circulant_matrix = True, # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn.Tanh() # activation for spatial gate (defaults to identity)
)
x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
For image classification
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 512,
depth = 6
)
img = torch.randn(1, 3, 256, 256)
logits = model(img) # (1, 1000)
You can also add a tiny amount of attention (one-headed) to boost performance, as mentioned in the paper as aMLP, with the addition of one extra keyword attn_dim. This applies to both gMLPVision and gMLP
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 512,
depth = 6,
attn_dim = 64
)
img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
Non-square images and patch sizes
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = (256, 128),
patch_size = (16, 8),
num_classes = 1000,
dim = 512,
depth = 6,
attn_dim = 64
)
img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
Experimental
A independent researcher proposes using a multi-headed approach for gMLPs in a blogpost on Zhihu. To do so, just set heads to be greater than 1
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 256,
causal = True,
circulant_matrix = True,
heads = 4 # 4 heads
)
x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
Citations
@misc{liu2021pay,
title = {Pay Attention to MLPs},
author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year = {2021},
eprint = {2105.08050},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = aug,
year = 2021,
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578%7D
}
.\lucidrains\g-mlp-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'g-mlp-pytorch', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.1.5', # 版本号
license='MIT', # 许可证
description = 'gMLP - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/g-mlp-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'deep learning',
'multi-layered-preceptrons'
],
install_requires=[ # 安装依赖
'einops>=0.3',
'torch>=1.6'
],
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\g-mlp-pytorch\train.py
# 导入所需的库
from g_mlp_pytorch import gMLP
from g_mlp_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 768
SEQ_LEN = 768
# 定义辅助函数
# 从 token 解码为字符
def decode_token(token):
return str(chr(max(32, token)))
# 从 tokens 解码为字符串
def decode_tokens(tokens):
return ''.join(list(map(decode_token, tokens)))
# 实例化类似 GPT 的解码器模型
model = gMLP(
num_tokens = 256,
dim = 512,
seq_len = SEQ_LEN,
depth = 8,
causal = True
)
model = AutoregressiveWrapper(model)
model.cuda()
# 准备 enwik8 数据
with gzip.open('./data/enwik8.gz') as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f'training loss: {loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f'%s \n\n %s', (prime, '*' * 100))
sample = model.generate(inp, GENERATE_LENGTH)
output_str = decode_tokens(sample)
print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\autoregressive_wrapper.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 库中导入 nn 模块
from torch import nn
# 定义一个辅助函数,用于检查值是否存在
def exists(val):
return val is not None
# 定义一个装饰器函数,用于在模型评估时切换模型状态
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# 定义一个函数,用于对 logits 进行 top k 过滤
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# 定义一个自回归封装器类
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value=0, max_seq_len=4096):
super().__init__()
self.max_seq_len = max_seq_len
self.pad_value = pad_value
self.net = net
# 生成函数,用于生成序列
@torch.no_grad()
@eval_decorator
def generate(
self,
start_tokens,
seq_len,
eos_token=None,
temperature=1.0,
filter_thres=0.9,
**kwargs
):
b, n, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(
out[:, -self.max_seq_len:],
**kwargs
)[:, -1]
filtered_logits = top_k(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
is_eos_token = out == eos_token
if is_eos_token.any(dim=-1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, self.pad_value)
break
return out[:, n:]
# 前向传播函数,用于模型训练
def forward(self, x, **kwargs):
inp, labels = x[:, :-1], x[:, 1:]
return self.net(inp, labels=labels, **kwargs)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\dsconv.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
from scipy.fftpack import next_fast_len
# functions
# 检查值是否存在
def exists(val):
return val is not None
# 在张量中添加指定数量的维度
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
# 使用傅立叶技巧进行 O(N log(N)) 的一维卷积
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
assert weight_dim >= dim
N = x.shape[dim]
M = weights.shape[weight_dim]
fast_len = next_fast_len(N + M - 1)
f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
out = out.roll(-1, dims = (dim,))
indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
out = out.index_select(dim, indices)
return out
# classes
# 高效的深度可分离卷积模块
class EfficientDsConv(nn.Module):
def __init__(
self,
*,
dim,
heads
):
super().__init__()
assert (dim % heads) == 0
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.to_weight = nn.Linear(dim, heads, bias = False)
# 参数 D
self.param_D = nn.Parameter(torch.randn(dim))
def forward(self, x):
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# 学习的加权残差
residual = u * self.param_D
# dsconv 核取决于序列长度
K = self.to_weight(x)
K = torch.flip(K, dims = (1,))
# 一维卷积傅立叶变换 O(nlog(n))
u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
out = rearrange(out, '... h d -> ... (h d)')
return out + residual
# 门控深度可分离卷积模块
class GatedDsConv(nn.Module):
""" Pseudocode 3.2 """
""" except state spaces replaced with regular learned convolution kernel """
def __init__(
self,
*,
dim,
heads = 8,
dim_dsconv = 512,
dim_expansion_factor = 4,
reverse_seq = False
):
super().__init__()
assert (dim_dsconv % heads) == 0
self.reverse_seq = reverse_seq
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dim_dsconv, bias = False), nn.GELU())
self.dsconv = EfficientDsConv(dim = dim_dsconv, heads = heads)
self.to_gate = nn.Linear(dim_dsconv, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.dsconv(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
# 门控深度可分离卷积 LM
class GatedDsConvLM(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
heads = 8,
dim_dsconv = 512,
max_seq_len = 2048,
dim_expansion_factor = 4,
):
# 初始化函数,继承父类的初始化方法
super().__init__()
# 创建一个嵌入层,用于将输入的 token 转换为指定维度的向量表示
self.token_emb = nn.Embedding(num_tokens, dim)
# 设置最大序列长度
self.max_seq_len = max_seq_len
# 创建一个空的神经网络层列表
self.layers = nn.ModuleList([])
# 根据深度循环创建 GatedDsConv 层,并添加到神经网络层列表中
for _ in range(depth):
self.layers.append(
GatedDsConv(
dim = dim,
heads = heads,
dim_dsconv = dim_dsconv,
dim_expansion_factor = dim_expansion_factor
)
)
# 创建一个线性层,用于将输出的向量转换为预测的 token
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
def forward(self, x, labels = None):
# 断言输入的序列长度不超过最大序列长度
assert x.shape[1] <= self.max_seq_len
# 将输入的 token 转换为向量表示
x = self.token_emb(x)
# 遍历神经网络层列表,依次对输入进行处理
for dsconv in self.layers:
x = dsconv(x)
# 将处理后的向量转换为预测的 token
logits = self.to_logits(x)
# 如果没有提供标签,则直接返回预测结果
if not exists(labels):
return logits
# 重新排列预测结果的维度,以便计算交叉熵损失
logits = rearrange(logits, 'b n c -> b c n')
# 计算交叉熵损失并返回
return F.cross_entropy(logits, labels)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\gss.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
# functions
# 检查值是否存在
def exists(val):
return val is not None
# classes
# 定义 DSS 类
class DSS(nn.Module):
def __init__(
self,
*,
dim,
kernel_N = 512,
dss_kernel_lambda_imag_exp = True
):
super().__init__()
self.norm = nn.LayerNorm(dim)
# Lambda
# 初始化 Lambda 的实部参数
self.Lambda_real = nn.Parameter(torch.randn(kernel_N))
# 初始化 Lambda 的虚部参数
self.Lambda_imag = nn.Parameter(torch.randn(kernel_N))
# C
# 初始化 C 的实部参数
self.C_real = nn.Parameter(torch.randn(dim, kernel_N))
# 初始化 C 的虚部参数
self.C_imag = nn.Parameter(torch.randn(dim, kernel_N))
# params D
# 初始化参数 D
self.param_D = nn.Parameter(torch.randn(dim))
# 是否对 Lambda 的虚部进行指数运算
self.dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
def forward(self, x):
"""
einstein notation:
b - batch
l - sequence length
d - dimension
"""
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# learned weighted residual
# 计算加权残差
residual = u * self.param_D
# derive simple dss kernel
# 计算简单的 DSS 核
Lambda_imag = self.Lambda_imag.exp() if self.dss_kernel_lambda_imag_exp else self.Lambda_imag
Lambda = -self.Lambda_real.exp() + 1j * Lambda_imag
C = self.C_real + 1j * self.C_imag
arange = torch.arange(seq_len, device = device)
S = (rearrange(Lambda, 'n -> n 1') * rearrange(arange, 'l -> 1 l')).exp()
C = C * (Lambda.exp() - 1) / Lambda
K = einsum('h n, n l -> l h', C, S).real
# conv1d fft O(nlog(n))
u_f = rfft(u, n = seq_len * 2, dim = -2)
K_f = rfft(K, n = seq_len * 2, dim = -2)
y = irfft(u_f * K_f, seq_len * 2, dim = -2)[..., :seq_len, :]
return y + residual
# 定义 GSS 类
class GSS(nn.Module):
""" Pseudocode 3.2 """
def __init__(
self,
*,
dim,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256,
reverse_seq = False,
dss_kernel_lambda_imag_exp = True
):
super().__init__()
self.reverse_seq = reverse_seq
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dss_kernel_H, bias = False), nn.GELU())
self.dss = DSS(dim = dss_kernel_H, kernel_N = dss_kernel_N, dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp)
self.to_gate = nn.Linear(dss_kernel_H, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.dss(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
# Gated State Spaces LM
# 定义 GatedStateSpacesLM 类
class GatedStateSpacesLM(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256,
dss_kernel_lambda_imag_exp = True
# 初始化函数,继承父类的初始化方法
):
# 调用父类的初始化方法
super().__init__()
# 创建一个嵌入层,用于将输入的标记转换为向量表示
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建一个空的模块列表,用于存储多个 GSS 模块
self.layers = nn.ModuleList([])
# 循环创建 depth 次 GSS 模块,并添加到模块列表中
for _ in range(depth):
self.layers.append(
GSS(
dim = dim,
dss_kernel_H = dss_kernel_H,
dss_kernel_N = dss_kernel_N,
dim_expansion_factor = dim_expansion_factor,
dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
)
)
# 创建一个线性层,用于将模型输出的向量转换为预测的标记
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
# 前向传播函数,接收输入 x 和标签 labels
def forward(self, x, labels = None):
# 将输入的标记转换为向量表示
x = self.token_emb(x)
# 遍历模块列表中的每个 GSS 模块,依次对输入进行处理
for gss in self.layers:
x = gss(x)
# 将处理后的向量转换为预测的标记
logits = self.to_logits(x)
# 如果没有提供标签,则直接返回预测结果
if not exists(labels):
return logits
# 重新排列 logits 的维度,以适应交叉熵损失函数的输入要求
logits = rearrange(logits, 'b n c -> b c n')
# 计算交叉熵损失并返回
return F.cross_entropy(logits, labels)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\mhesa.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
from scipy.fftpack import next_fast_len
# functions
# 检查值是否存在
def exists(val):
return val is not None
# 在张量中添加指定数量的维度
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
# 使用傅立叶技巧进行 O(N log(N)) 的一维卷积
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
assert weight_dim >= dim
N = x.shape[dim]
M = weights.shape[weight_dim]
fast_len = next_fast_len(N + M - 1)
f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
out = out.roll(-1, dims = (dim,))
indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
out = out.index_select(dim, indices)
return out
# classes
# MHESA 模块
class MHESA(nn.Module):
""" used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
def __init__(
self,
*,
dim,
heads,
reverse_seq = False
):
super().__init__()
assert (dim % heads) == 0
self.reverse_seq = reverse_seq
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.alphas = nn.Parameter(torch.randn(heads))
self.dampen_factors = nn.Parameter(torch.randn(heads))
# params D
self.param_D = nn.Parameter(torch.randn(dim))
def forward(self, x):
"""
einstein notation:
b - batch
h - heads
l - sequence length
d - dimension
"""
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# learned weighted residual
residual = u * self.param_D
# weights derived from alphas (learned exponential smoothing decay rate)
alphas = self.alphas.sigmoid()
dampen_factors = self.dampen_factors.sigmoid()
reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
# conv1d fft O(nlog(n))
u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
out = rearrange(out, '... h d -> ... (h d)')
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
# GatedMHESA 模块
class GatedMHESA(nn.Module):
""" Pseudocode 3.2 """
""" except state spaces replaced with multi-head exponential smoothing with learned alpha """
""" used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
def __init__(
self,
*,
dim,
heads = 8,
dim_mhesa = 512,
dim_expansion_factor = 4,
):
super().__init__()
assert (dim_mhesa % heads) == 0
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dim_mhesa, bias = False), nn.GELU())
self.mhesa = MHESA(dim = dim_mhesa, heads = heads)
self.to_gate = nn.Linear(dim_mhesa, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.mhesa(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
return out + residual
# Gated Dsconv LM
class GatedExponentialSmoothingLM(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
*,
num_tokens, # 标记的数量
dim, # 向量维度
depth, # 模型深度
heads = 8, # 多头注意力机制的头数
dim_mhesa = 512, # MHESA 模块的维度
dim_expansion_factor = 4, # 扩展因子
):
super().__init__()
# 创建标记嵌入层
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建多个 GatedMHESA 层
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
GatedMHESA(
dim = dim,
heads = heads,
dim_mhesa = dim_mhesa,
dim_expansion_factor = dim_expansion_factor
)
)
# 创建输出层
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
# 前向传播函数
def forward(self, x, labels = None):
# 对输入进行标记嵌入
x = self.token_emb(x)
# 遍历多个 GatedMHESA 层
for mhesa in self.layers:
x = mhesa(x)
# 将结果传入输出层
logits = self.to_logits(x)
# 如果没有标签,则直接返回结果
if not exists(labels):
return logits
# 重新排列 logits 的维度
logits = rearrange(logits, 'b n c -> b c n')
# 计算交叉熵损失
return F.cross_entropy(logits, labels)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\__init__.py
# 从 gated_state_spaces_pytorch.gss 模块中导入 GSS 和 GatedStateSpacesLM 类
from gated_state_spaces_pytorch.gss import GSS, GatedStateSpacesLM
# 从 gated_state_spaces_pytorch.dsconv 模块中导入 GatedDsConv 和 GatedDsConvLM 类
from gated_state_spaces_pytorch.dsconv import GatedDsConv, GatedDsConvLM
# 从 gated_state_spaces_pytorch.mhesa 模块中导入 GatedExponentialSmoothingLM 和 GatedMHESA 类
from gated_state_spaces_pytorch.mhesa import GatedExponentialSmoothingLM, GatedMHESA

Gated State Spaces - Pytorch
Implementation of Gated State Spaces, from the paper Long Range Language Modeling via Gated State Spaces, in Pytorch. In particular, it will contain the hybrid version containing local self attention with the long-range GSS.
It will also contain a few more settings to compare state spaces to a sequence-wise GLU depthwise conv, and even simpler, a parameterized exponential moving average along the sequence dimension. So we get to the bottom of whether state spaces are worth it, or whether it is really all about the O(L log(L)) FFT convolution trick. Results will be shared in the readme.
I will also pit the GSS module against the Path-X challenge and see how well it does.
Update: This paper has beat S4 on LRA using multi-headed EMA + single head attention.
Install
$ pip install gated-state-spaces-pytorch
Usage
import torch
from gated_state_spaces_pytorch import GSS
gss = GSS(
dim = 512, # dimension
dim_expansion_factor = 4, # hidden dimension (expansion factor x dim) = 2048
dss_kernel_N = 512,
dss_kernel_H = 256
)
x = torch.randn(1, 65536, 512)
out = gss(x) # (1, 65536, 512)
Gated state spaces language model
import torch
from gated_state_spaces_pytorch import GatedStateSpacesLM
gss_lm = GatedStateSpacesLM(
num_tokens = 20000,
depth = 12,
dim = 512,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256
)
ids = torch.randint(0, 20000, (1, 1024))
logits = gss_lm(ids) # (1, 1024, 20000)
Todo
Citations
@inproceedings{Mehta2022LongRL,
title = {Long Range Language Modeling via Gated State Spaces},
author = {Harsh Mehta and Ankit Gupta and Ashok Cutkosky and Behnam Neyshabur},
year = {2022}
}
@misc{woo2022etsformer,
title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
year = {2022},
eprint = {2202.01381},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\gated-state-spaces-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
# 包的名称
name = 'gated-state-spaces-pytorch',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.1.0',
# 许可证
license='MIT',
# 描述
description = 'Gated State Spaces - GSS - Pytorch',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/gated-state-spaces-pytorch',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'state spaces',
'long context'
],
# 安装依赖
install_requires=[
'einops>=0.4',
'scipy',
'torch>=1.6',
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\gated-state-spaces-pytorch\train.py
# 导入所需的库
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入自定义的模块
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 4096
# 定义辅助函数
# 生成数据加载器的无限循环
def cycle(loader):
while True:
for data in loader:
yield data
# 将 token 解码为字符
def decode_token(token):
return str(chr(max(32, token)))
# 将 tokens 解码为字符串
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# 实例化类似 GPT 的解码器模型
model = GatedStateSpacesLM(
num_tokens = 256,
dim = 512,
depth = 8
)
model = AutoregressiveWrapper(model)
model.cuda()
# 准备 enwik8 数据
with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义自定义数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f"validation loss: {loss.item()}")
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f"%s \n\n %s", (prime, "*" * 100))
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
.\lucidrains\gateloop-transformer\gateloop_transformer\associative_scan.py
# 从 S5-pytorch 代码库中获取的代码段
# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
# 将被调整以在小规模上测试 GateLoop https://arxiv.org/abs/2311.01927
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Tuple, Callable
# 辅助函数
def pad_at_dim(t, pad, dim = -1, value = 0.):
# 在指定维度上填充张量
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# Pytorch 实现的 jax.lax.associative_scan
# 专门用于轴为1的情况(用于自回归建模的令牌序列)
def associative_scan(
operator: Callable,
elems: Tuple[Tensor, Tensor]
):
num_elems = int(elems[0].shape[1])
if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
raise ValueError('Array inputs to associative_scan must have the same '
'first dimension. (saw: {})'
.format([elem.shape for elem in elems]))
def _scan(elems):
"""对 `elems` 执行扫描操作."""
num_elems = elems[0].shape[1]
if num_elems < 2:
return elems
# 组合相邻的元素对。
reduced_elems = operator(
[elem[:, :-1:2] for elem in elems],
[elem[:, 1::2] for elem in elems])
# 递归计算部分减少张量的扫描。
odd_elems = _scan(reduced_elems)
if num_elems % 2 == 0:
even_elems = operator(
[e[:, :-1] for e in odd_elems],
[e[:, 2::2] for e in elems])
else:
even_elems = operator(
odd_elems,
[e[:, 2::2] for e in elems])
# 扫描的第一个元素与原始 `elems` 的第一个元素相同。
even_elems = [
torch.cat([elem[:, :1], result], dim=1)
for (elem, result) in zip(elems, even_elems)]
return list(map(_interleave, even_elems, odd_elems))
return _scan(elems)
def _interleave(a, b):
a_axis_len, b_axis_len = a.shape[1], b.shape[1]
output_axis_len = a_axis_len + b_axis_len
if (a_axis_len == (b_axis_len + 1)):
b = pad_at_dim(b, (0, 1), dim = 1)
stacked = torch.stack([a, b], dim=2)
interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
return interleaved[:, :output_axis_len]
.\lucidrains\gateloop-transformer\gateloop_transformer\gateloop_transformer.py
from functools import partial # 导入 functools 模块中的 partial 函数
import torch # 导入 torch 库
from torch.nn import Module, ModuleList # 从 torch.nn 模块中导入 Module 和 ModuleList 类
from torch import nn, einsum, Tensor # 从 torch 模块中导入 nn、einsum 和 Tensor
from torch.utils.checkpoint import checkpoint # 从 torch.utils.checkpoint 模块导入 checkpoint 函数
import torch.nn.functional as F # 导入 torch.nn.functional 模块并重命名为 F
from einops import rearrange # 导入 einops 库中的 rearrange 函数
from einops.layers.torch import Rearrange # 从 einops.layers.torch 模块中导入 Rearrange 类
from rotary_embedding_torch import RotaryEmbedding # 导入 rotary_embedding_torch 库中的 RotaryEmbedding 类
from gateloop_transformer.associative_scan import associative_scan # 从 gateloop_transformer.associative_scan 模块中导入 associative_scan 函数
# helpers
def exists(v): # 定义 exists 函数,用于判断变量是否存在
return v is not None # 返回变量是否不为 None
def default(v, d): # 定义 default 函数,用于返回变量或默认值
return v if exists(v) else d # 如果变量存在则返回变量,否则返回默认值
def Sequential(*modules): # 定义 Sequential 函数,用于创建序列模块
modules = list(filter(exists, modules)) # 过滤掉不存在的模块
num_modules = len(modules) # 获取模块数量
if num_modules == 0: # 如果模块数量为 0
return nn.Identity() # 返回一个恒等映射的模块
elif num_modules == 1: # 如果模块数量为 1
return modules[0] # 返回该模块
return nn.Sequential(*modules) # 返回包含所有模块的序列模块
# rms norm
class RMSNorm(Module): # 定义 RMSNorm 类,用于实现 RMS 归一化
def __init__(self, dim): # 初始化方法
super().__init__() # 调用父类的初始化方法
self.scale = dim ** 0.5 # 计算缩放因子
self.gamma = nn.Parameter(torch.ones(dim)) # 创建可学习参数 gamma
def forward(self, x): # 前向传播方法
return F.normalize(x, dim=-1) * self.scale * self.gamma # 对输入进行归一化并乘以缩放因子和 gamma
# norm wrappers
class PreNorm(Module): # 定义 PreNorm 类,用于实现预归一化
def __init__(self, dim, fn: Module): # 初始化方法
super().__init__() # 调用父类的初始化方法
self.fn = fn # 保存传入的模块
self.norm = RMSNorm(dim) # 创建 RMSNorm 归一化模块
def forward(self, x, **kwargs): # 前向传播方法
return self.fn(self.norm(x), **kwargs) + x # 对输入进行归一化后,再应��传入的模块并加上原始输入
class PostNorm(Module): # 定义 PostNorm 类,用于实现后归一化
def __init__(self, dim, fn: Module): # 初始化方法
super().__init__() # 调用父类的初始化方法
self.fn = fn # 保存传入的模块
self.norm = nn.LayerNorm(dim) # 创建 LayerNorm 归一化模块
def forward(self, x, **kwargs): # 前向传播方法
return self.norm(self.fn(x, **kwargs) + x) # 应用传入的模块后,再对结果进行归一化并加上原始输入
# feedforward
def FeedForward(dim, mult=4): # 定义 FeedForward 函数,用于创建前馈神经网络
dim_inner = dim * mult # 计算内部维度
return nn.Sequential( # 返回一个序列模块
nn.Linear(dim, dim_inner), # 线性变换层
nn.GELU(), # GELU 激活函数
nn.Linear(dim_inner, dim) # 线性变换层
)
# attention
class CausalFullAttention(Module): # 定义 CausalFullAttention 类,用于实现自回归注意力机制
def __init__(
self,
dim,
*,
dim_head=64,
heads=8,
rotary_emb=False,
add_swish_gating=False,
data_dependent_rel_pos=False,
frac_gradient_data_dependent_rel_pos=0.5,
softmax_normalize=None
): # 初始化方法
super().__init__() # 调用父类的初始化方法
dim_inner = dim_head * heads # 计算内部维度
self.softmax_normalize = default(softmax_normalize, not data_dependent_rel_pos) # 设置 softmax 归一化参数
self.scale = dim_head ** -0.5 # 计算缩放因子
self.rotary_emb = RotaryEmbedding(dim_head) if rotary_emb else None # 创建旋转嵌入对象(如果需要)
self.to_qkv = nn.Sequential( # 创建 Q、K、V 投影模块
nn.Linear(dim, dim_inner * 3, bias=False), # 线性变换层
Rearrange('b n (qkv h d) -> qkv b h n d', h=heads, qkv=3) # 重排张量维度
)
self.data_dependent_rel_pos = data_dependent_rel_pos # 是否使用数据相关的相对位置编码
self.frac_gradient_data_dependent_rel_pos = frac_gradient_data_dependent_rel_pos # 数据相关的相对位置编码的梯度比例
if data_dependent_rel_pos: # 如果使用数据相关的相对位置编码
self.to_a = nn.Sequential( # 创建相对位置编码模块
nn.Linear(dim, dim_inner, bias=False), # 线性变换层
Rearrange('b n (h d c) -> b h n d c', h=heads, c=2) # 重排张量维度
)
self.to_gates = None # 初始化门控模块为 None
if add_swish_gating: # 如果添加 Swish 门控
self.to_gates = nn.Sequential( # 创建门控模块
nn.Linear(dim, dim_inner, bias=False), # 线性变换层
nn.SiLU(), # Swish 激活函数
Rearrange('b n (h d) -> b h n d', h=heads) # 重排张量维度
)
self.to_out = nn.Sequential( # 创建输出模块
Rearrange('b h n d -> b n (h d)'), # 重排张量维度
nn.Linear(dim_inner, dim) # 线性变换层
)
def forward(
self,
x,
ablate_complex=False,
ablate_state_transition=False
):
# 将输入 x 转换为查询 q、键 k、值 v
q, k, v = self.to_qkv(x)
# 如果存在旋转嵌入,则对查询和键进行旋转
if exists(self.rotary_emb):
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
# 缩放查询
q = q * self.scale
# 如果启用数据相关的相对位置编码,并且不禁用状态转换
if self.data_dependent_rel_pos and not ablate_state_transition:
# 获取数据相关的相对位置投影
frac_gradient = self.frac_gradient_data_dependent_rel_pos
# 计算相对位置投影
a = self.to_a(x)
# 允许数据相关的相对位置投影变化更慢
a = a * frac_gradient + a.detach() * (1 - frac_gradient)
# 将 a 转换为复数形式
a = torch.view_as_complex(a)
# 如果禁用复数计算
if ablate_complex:
a = a.real + 0.j
# 计算幅度和相位
magnitude, phase = a.abs(), a.angle()
a = torch.polar(magnitude.sigmoid(), phase)
# 重排形状
a = rearrange(a, '... -> ... 1')
a_cumprod = a.cumprod(dim=-2)
# 对实部进行截断
a_cumprod_real = a_cumprod.real.clamp(min=1e-10)
a_cumprod_real_inverse = 1. / a_cumprod_real
# 重排形状
q, k = map(lambda t: rearrange(t, '... (d c) -> ... d c', c=2), (q, k))
# 更新查询和键
q = q * a_cumprod_real
k = k * a_cumprod_real_inverse
# 重排形状
q, k = map(lambda t: rearrange(t, '... d c -> ... (d c)'), (q, k))
# 计算相似度
sim = einsum('b h i d, b h j d -> b h i j', q, k)
i, j = sim.shape[2:]
# 创建因果掩码
causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
# 如果启用 softmax 归一化
if self.softmax_normalize:
# 对相似度矩阵进行掩码处理
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# 计算注意力权重
attn = sim.softmax(dim=-1)
else:
# 对相似度矩阵进行掩码处理
attn = sim.masked_fill(causal_mask, 0.)
# 计算输出
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# 如果存在门控机制
if exists(self.to_gates):
# 应用门控机制
out = out * self.to_gates(x)
# 返回输出结果
return self.to_out(out)
# 定义一个函数,实现带有“gateloop操作符”的数据门控线性注意力
def gate_loop_operator(q, k, v, a):
"""
the pseudocode in section 3.2 of the paper
"""
# 计算 k 和 v 的张量积
kv = einsum('b n d, b n e -> b n d e', k, v)
# 将结果转换为复数张量
kv = kv + 0.j
# 定义一个二元操作符函数
def binary_operator(a, b):
a_i, kv_i = a
a_j, kv_j = b
return a_j * a_i, a_j * kv_i + kv_j
# 对二元操作符进行关联扫描
_, kv = associative_scan(binary_operator, (a, kv))
# 计算最终输出
return einsum('b n d, b n d e -> b n e', q, kv.real)
# GateLoopedAttention 类,继承自 Module 类
class GateLoopedAttention(Module):
def __init__(
self,
dim,
heads = None,
dim_inner = None,
checkpoint_gate_looped_attn = True,
add_swish_gating = True,
sub_ln = False,
frac_gradient_state_transition = 0.9
):
super().__init__()
self.frac_gradient_state_transition = frac_gradient_state_transition
self.checkpoint_gate_looped_attn = checkpoint_gate_looped_attn
dim_inner = default(dim_inner, dim)
heads = default(heads, dim_inner)
# 检查维度是否符合要求
assert (dim_inner % heads) == 0, f'dimension for gate looped attention {dim_inner} must be divisible by number of gate loop heads {heads}'
# 将输入张量按照头数进行分割
self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
# 线性变换,将输入转换为 Q、K、V
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
# 线性变换,将输入转换为注意力权重
self.to_a = nn.Sequential(
nn.Linear(dim, heads * 2),
Rearrange('b n (h c) -> (b h) n 1 1 c', h = heads, c = 2)
)
# 合并头部
self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
# 可选的 LayerNorm
self.maybe_sub_ln = nn.LayerNorm(dim_inner) if sub_ln else nn.Identity()
self.to_gates = None
# 添加 Swish 激活门控
if add_swish_gating:
self.to_gates = nn.Sequential(
nn.Linear(dim, dim_inner, bias = False),
nn.SiLU()
)
# 输出线性变换
self.to_out = nn.Linear(dim_inner, dim, bias = False) if dim_inner != dim or add_swish_gating else nn.Identity()
# 前向传播函数
def forward(
self,
x,
ablate_complex = False,
ablate_state_transition = False
):
frac_gradient = self.frac_gradient_state_transition
# 将输入 x 转换为 Q、K、V
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(self.split_heads, (q, k, v))
# 获取注意力权重
a = self.to_a(x)
a = a * frac_gradient + a.detach() * (1 - frac_gradient)
# 将注意力权重转换为复数张量
a = torch.view_as_complex(a)
# 如果 ablate_complex 为 True,则将注意力权重转换为实部
if ablate_complex:
a = a.real + 0.j
# 如果 ablate_state_transition 为 True,则将注意力权重设置为全 1
if ablate_state_transition:
a = torch.ones_like(a.real) + 0.j
else:
# 对状态转换的激活函数
# 使用 sigmoid 函数处理幅度,使用恒等函数处理相位
magnitude, phase = a.abs(), a.angle()
a = torch.polar(magnitude.sigmoid(), phase)
# 检查是否需要反向传播
need_backwards = any([t.requires_grad for t in (q, k, v, a)])
# 使用 partial 函数创建一个带有检查点的函数
fn = partial(checkpoint, gate_loop_operator) if need_backwards and self.checkpoint_gate_looped_attn else gate_loop_operator
# 计算输出
out = fn(q, k, v, a)
out = self.merge_heads(out)
out = self.maybe_sub_ln(out)
# 如果存在门控,则将门控应用到输出上
if exists(self.to_gates):
out = self.to_gates(x) * out
return self.to_out(out)
# Transformer 类,继承自 Module 类
class Transformer(Module):
def __init__(
self,
dim,
*,
num_tokens,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
checkpoint_gate_looped_attn = True,
use_gate_looped_attn = True,
gate_loop_heads = None,
attn_add_swish_gating = True,
dim_gate_looped_attn = None,
attn_softmax_normalize = None,
data_dependent_rel_pos = False,
frac_gradient_state_transition = 0.9,
ablate_complex = False,
ablate_state_transition = False,
rotary_emb = False,
post_ln_norm = False,
sub_ln = False
# 初始化函数,设置模型的参数
):
# 调用父类的初始化函数
super().__init__()
# 设置是否削弱复杂性和状态转换的参数
self.ablate_complex = ablate_complex
self.ablate_state_transition = ablate_state_transition
# 创建一个词嵌入层
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建一个模块列表用于存储每个层的注意力和前馈网络
layers = ModuleList([])
# 根据是否后层归一化选择层包装器
layer_wrapper = PreNorm if not post_ln_norm else PostNorm
# 循环创建指定深度的层
for _ in range(depth):
# 根据是否使用门控循环注意力选择空间混合器类型
if use_gate_looped_attn:
spatial_mixer = GateLoopedAttention(
dim = dim,
heads = gate_loop_heads,
dim_inner = dim_gate_looped_attn,
add_swish_gating = attn_add_swish_gating,
sub_ln = sub_ln,
checkpoint_gate_looped_attn = checkpoint_gate_looped_attn,
frac_gradient_state_transition = frac_gradient_state_transition
)
else:
spatial_mixer = CausalFullAttention(
dim = dim,
dim_head = dim_head,
heads = heads,
rotary_emb = rotary_emb,
add_swish_gating = attn_add_swish_gating,
softmax_normalize = attn_softmax_normalize,
data_dependent_rel_pos = data_dependent_rel_pos,
frac_gradient_data_dependent_rel_pos = frac_gradient_state_transition
)
# 创建通道混合器
channelwise_mixer = FeedForward(
dim = dim,
mult = ff_mult
)
# 将空间混合器和通道混合器添加到层列表中
layers.append(ModuleList([
layer_wrapper(dim, spatial_mixer),
layer_wrapper(dim, channelwise_mixer)
]))
# 将层列表转换为模块列表
self.layers = ModuleList(layers)
# 创建输出层,包括 RMS 归一化和线性层
self.to_logits = Sequential(
RMSNorm(dim) if not post_ln_norm else None,
nn.Linear(dim, num_tokens, bias = False)
)
# 前向传播函数
def forward(
self,
x,
return_loss = False,
ablate_complex = None,
ablate_state_transition = None
):
# 设置是否削弱复杂性和状态转换的参数
ablate_complex = default(ablate_complex, self.ablate_complex)
ablate_state_transition = default(ablate_state_transition, self.ablate_state_transition)
# 如果需要返回损失,则提取标签
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
# 对输入进行词嵌入
x = self.token_emb(x)
# 遍历每个层的注意力和前馈网络
for attn, ff in self.layers:
# 使用注意力层
x = attn(
x,
ablate_complex = ablate_complex,
ablate_state_transition = ablate_state_transition
)
# 使用前馈网络
x = ff(x)
# 获取最终输出
logits = self.to_logits(x)
# 如果不需要返回损失,则直接返回输出
if not return_loss:
return logits
# 重新排列输出并计算交叉熵损失
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels)
.\lucidrains\gateloop-transformer\gateloop_transformer\gateloop_transformer_jax.py
# 导入必要的模块和函数
from typing import List, Tuple, Callable
from jax import random, jit, nn, lax, numpy as np
from jax.lax import associative_scan
from equinox import Module, static_field
# linear
# 定义线性层模块
class Linear(Module):
weight: np.ndarray
bias: np.ndarray
def __init__(self, dim_in, dim_out, *, key):
# 使用随机数生成权重和偏置
weight_key, bias_key = random.split(key)
self.weight = random.normal(weight_key, (dim_in, dim_out))
self.bias = random.normal(bias_key, (dim_out,))
def __call__(self, x, *, key = None):
# 计算线性变换
return x @ self.weight + self.bias
# rmsnorm
# 定义 RMSNorm 模块
class RMSNorm(Module):
scale: float = static_field()
eps: float = static_field()
gamma: np.ndarray
def __init__(self, dim, eps = 1e-5):
# 初始化参数
self.eps = eps
self.scale = dim ** 0.5
self.gamma = np.ones((dim,))
def __call__(self, x):
# 计算 RMSNorm
sum_of_squares = np.sum(np.square(x), axis = -1, keepdims = True)
inv_norm = lax.rsqrt(sum_of_squares + self.eps)
return inv_norm * x * self.gamma * self.scale
# gate loop layer
# 定义门循环操作符
def gate_loop_operator(k, v, q, a):
kv = k * v + 0.j
def binary_operator(e_i, e_j):
a_i, kv_i = e_i
a_j, kv_j = e_j
return a_j * a_i, a_j * kv_i + kv_j
# 使用关联扫描计算门循环
_, y = associative_scan(binary_operator, (a, kv), axis = 1)
return q * np.real(y)
# 定义门循环模块
class GateLoop(Module):
norm: RMSNorm
wq: np.ndarray
wk: np.ndarray
wv: np.ndarray
wa: np.ndarray
wg: np.ndarray
wo: np.ndarray
def __init__(
self,
dim,
key
):
"""
q - query
k - key
v - value
a - state transition
g - gating with silu activation
o - output
"""
# 使用随机数生成参数
q_key, k_key, v_key, a_key, g_key, o_key = random.split(key, 6)
self.norm = RMSNorm(dim)
self.wq = random.normal(q_key, (dim, dim))
self.wk = random.normal(k_key, (dim, dim))
self.wv = random.normal(v_key, (dim, dim))
self.wa = random.normal(a_key, (dim, dim * 2))
self.wg = random.normal(g_key, (dim, dim))
self.wo = random.normal(o_key, (dim, dim))
def __call__(self, x):
x = self.norm(x)
q = x @ self.wq
k = x @ self.wk
v = x @ self.wv
a = x @ self.wa
g = x @ self.wg
# 构成复杂状态转换
a_real, a_imag = np.split(a, 2, axis = -1)
a_complex = lax.complex(a_real, a_imag)
magnitude, phase = np.abs(a_complex), np.angle(a_complex)
magnitude = nn.sigmoid(magnitude)
a_complex = magnitude * np.exp(1j * phase)
# 使用复杂状态进行关联扫描
y = gate_loop_operator(k, v, q, a_complex)
# 使用 ReTNet 的 silu gating
y = y * nn.silu(g)
o = y @ self.wo
return o
# basic feedforward with pre-rmsnorm
# 定义带有 RMSNorm 的基本前馈模块
class FeedForward(Module):
norm: RMSNorm
proj_in: Linear
proj_out: Linear
def __init__(
self,
*,
dim,
key,
mult = 4
):
self.norm = RMSNorm(dim)
self.proj_in = Linear(dim, dim * mult, key = key)
self.proj_out = Linear(dim * mult, dim, key = key)
def __call__(self, x):
x = self.norm(x)
x = self.proj_in(x)
x = nn.gelu(x)
x = self.proj_out(x)
return x
# main class
# 定义门循环变换器模块
class GateLoopTransformer(Module):
embedding: np.ndarray
norm: Module
layers: List[Tuple[GateLoop, FeedForward]]
def __init__(
self,
*,
num_tokens,
dim,
depth,
key,
ff_mult = 4
# 初始化嵌入矩阵,使用正态分布随机初始化,乘以0.02
self.embedding = random.normal(key, (num_tokens, dim)) * 0.02
# 初始化层列表
layers = []
# 循环创建深度次数的GateLoop和FeedForward层,并添加到层列表中
for _ in range(depth):
gateloop = GateLoop(dim = dim, key = key)
ff = FeedForward(dim = dim, mult = ff_mult, key = key)
layers.append((gateloop, ff))
# 将创建的层列表赋值给self.layers
self.layers = layers
# 初始化RMSNorm层
self.norm = RMSNorm(dim)
@jit
def __call__(self, x):
# 通过嵌入矩阵获取输入x的嵌入向量
x = self.embedding[x]
# 遍历每一层,依次进行GateLoop和FeedForward操作
for gateloop, ff in self.layers:
x = gateloop(x) + x
x = ff(x) + x
# 对输出进行归一化处理
x = self.norm(x)
# 计算logits,即输出结果
logits = x @ self.embedding.transpose()
return logits
# 如果当前脚本被直接运行
if __name__ == '__main__':
# 导入 jax 库
import jax
# 使用 PRNGKey 创建一个随机种子
key = jax.random.PRNGKey(0)
# 创建一个 GateLoopTransformer 模型实例
model = GateLoopTransformer(
num_tokens = 20000,
dim = 512,
depth = 12,
key = key
)
# 生成一个长度为 1024 的随机整数序列
seq = jax.random.randint(key, (1024,), 0, 20000)
# 使用模型对序列进行推理,得到输出 logits
logits = model(seq)
# 打印 logits 的形状
print(logits.shape) # (1024, 20000)
.\lucidrains\gateloop-transformer\gateloop_transformer\simplified_gate_loop.py
# 导入所需模块
from functools import partial
import torch
from torch import nn, Tensor
from torch.nn import Module
from typing import Tuple
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
from gateloop_transformer.gateloop_transformer import RMSNorm
from gateloop_transformer.associative_scan import associative_scan
# 检查变量是否存在的函数
def exists(v):
return v is not None
# 绝对值截断函数,用于处理小于给定阈值的值
def abs_clamp_eps(t, eps = 1e-20):
sign = torch.sign(t)
return sign * t.abs().clamp(min = eps)
# 使用 Heinsen 序列进行关联扫描
def heinsen_associative_scan(a, kv, eps = 1e-20):
log_a = a.clamp(min = eps).log()
log_kv = abs_clamp_eps(kv, eps = eps).to(dtype = torch.complex64).log()
a_star = torch.cumsum(log_a, dim = 1)
log_x0_plus_b_star = torch.logcumsumexp(log_kv - a_star, dim = 1)
log_x = a_star + log_x0_plus_b_star
return a_star.exp().real, log_x.exp().real
# 使用 TorchScript 实现的二进制运算函数
@torch.jit.script
def binary_operator(
a: Tuple[Tensor, Tensor],
b: Tuple[Tensor, Tensor]
):
a_i, kv_i = a
a_j, kv_j = b
return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
# 门循环操作符
def gate_loop_operator(q, kv, a, cache = None, heinsen = False):
if exists(cache):
cache_a, cache_kv = cache
a, a_ps = pack([cache_a, a], 'b * d')
kv, kv_ps = pack([cache_kv, kv], 'b * d')
if heinsen:
a, kv = heinsen_associative_scan(a, kv)
else:
a, kv = associative_scan(binary_operator, (a, kv))
if exists(cache):
_, a = unpack(a, a_ps, 'b * d')
_, kv = unpack(kv, kv_ps, 'b * d')
return q * kv, (a[:, -1], kv[:, -1])
# 使用 JAX 实现的门循环操作符
def get_jax_gate_loop_operator():
try:
from jax import jit, numpy as jnp
from jax.lax import associative_scan
from jax2torch import jax2torch
except ImportError as e:
print(f'jax and jax2torch must be installed - `pip install jax2torch`')
@jit
def jax_gate_loop_operator(q, kv, a, cache = None):
def binary_operator(e_i, e_j):
a_i, kv_i = e_i
a_j, kv_j = e_j
return a_j * a_i, a_j * kv_i + kv_j
if exists(cache):
cache_a, cache_kv = cache
a, a_ps = pack([cache_a, a], 'b * d')
kv, kv_ps = pack([cache_kv, kv], 'b * d')
_, y = associative_scan(binary_operator, (a, kv), axis = 1)
if exists(cache):
_, a = unpack(a, a_ps, 'b * d')
_, kv = unpack(kv, kv_ps, 'b * d')
return q * y, (a[:, -1], kv[:, -1])
return jax2torch(jax_gate_loop_operator)
# 简单的门循环层
class SimpleGateLoopLayer(Module):
"""
简化的门循环层,用于补充注意力机制
参考 https://github.com/lucidrains/mega-pytorch
"""
def __init__(
self,
dim,
prenorm = True,
use_heinsen = False,
use_jax_associative_scan = False,
post_ln = False,
reverse = False
):
# 调用父类的构造函数
super().__init__()
# 断言确保 use_heinsen 和 use_jax_associative_scan 中至多只有一个为真
assert (int(use_heinsen) + int(use_jax_associative_scan)) <= 1
# 如果 prenorm 为真,则使用 RMSNorm 进行归一化,否则使用 nn.Identity()
self.norm = RMSNorm(dim) if prenorm else nn.Identity()
# 初始化维度
self.dim = dim
# 将输入映射到 q, k, v,并进行线性变换
self.to_qkva = nn.Sequential(
nn.Linear(dim, dim * 3, bias = False),
Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
)
# 设置是否使用 Heinsen 或 JAX 的关联扫描
self.use_heinsen = use_heinsen
self.use_jax = use_jax_associative_scan
# 根据使用的扫描方式选择相应的 gate_loop_fn
if use_jax_associative_scan:
self.gate_loop_fn = get_jax_gate_loop_operator()
elif use_heinsen:
self.gate_loop_fn = partial(gate_loop_operator, heinsen = True)
else:
self.gate_loop_fn = gate_loop_operator
# 如果 post_ln 为真,则使用 nn.LayerNorm(dim) 进行归一化,否则使用 nn.Identity()
self.maybe_post_ln = nn.LayerNorm(dim) if post_ln else nn.Identity()
# 将输出进行头部分割
self.split_heads = Rearrange('(b d) n 1 -> b n d', d = dim)
# 设置是否反转序列
self.reverse = reverse
# 前向传播函数
def forward(
self,
x,
cache = None,
return_cache = False
):
# 如果需要反转序列,则对输入进行反转
if self.reverse:
x = torch.flip(x, dims = (-2,))
# 对输入进行归一化
x = self.norm(x)
# 将输入映射到 q, k, v
q, kv, a = self.to_qkva(x)
# 使用 gate_loop_fn 进行计算
out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)
# 将输出进行头部分割
out = self.split_heads(out)
# 对输出进行归一化
out = self.maybe_post_ln(out)
# 如果需要反转序列,则对输出进行反转
if self.reverse:
out = torch.flip(out, dims = (-2,))
# 如果不需要返回 cache,则直接返回输出
if not return_cache:
return out
# 断言确保只有在非反转序列时才能缓存
assert not self.reverse, 'caching only works with non-reversed seq'
# 返回输出和 cache
return out, cache
.\lucidrains\gateloop-transformer\gateloop_transformer\__init__.py
# 从 gateloop_transformer.gateloop_transformer 模块中导入 CausalFullAttention, GateLoopedAttention, Transformer 类
# 从 gateloop_transformer.simplified_gate_loop 模块中导入 SimpleGateLoopLayer 类
from gateloop_transformer.gateloop_transformer import (
CausalFullAttention,
GateLoopedAttention,
Transformer
)
from gateloop_transformer.simplified_gate_loop import (
SimpleGateLoopLayer
)

GateLoop Transformer
Implementation of GateLoop Transformer in Pytorch and Jax, to be tested on Enwik8 character level modeling.
Update: A transformer run with regular attention + data dependent xpos relative positions did not converge at all. Also, gate loop's associative scan also is not able to train on even sequence lengths of 128. I'm not sure if it can be done without a specialized CUDA kernel, much like autoregressive linear attention (RWKV and the like)
Update 2: Got a smaller GateLoop transformer (gate loop dimensions of 128) to run on sequence length of 256. It is converging very well with a quick eyeball. Will run some more rigorous experiments tomorrow.
Update 3: Fixed a misunderstanding and definitely seems to be converging better than vanilla linear attention (from my memories of those experiments).
Update 4: Ongoing experiments
Update 5: Author has reviewed the code, and there was another misunderstanding. They use maximum heads (heads == dimension). This is kind of a plot twist, as this is infeasible for normal attention. It also obviates the need a fused CUDA kernel as in autoregressive linear attention.
Update 6: Corrected gateloop transformer run looks amazing. Cautiously optimistic now.
Update 7: Ablating state transition shows expected negative result. Ablating complex valued states though, I see no difference, at least, early in the run.
Update 8: Directly projecting to kv with one projection for the max-heads setting (instead of keys and values separately followed by element-wise multiplication) yields similar results
Update 9: Head to head to 20k, just to make sure Gateloop doesn't get exceeded later on
Update 10: and it got passed by attention, at least, assuming the implementation in the repo is correct.
Update 11: I'm seeing a steady improvement increasing the head dimension, so I no longer believe max-heads is optimal. Increasing the head dimension brings us right back to linear attention and needing the fused CUDA kernel.
Update 12: Nikil spotted a potential error with the kv not being kept in complex (and real component taken at end). Rerunning experiments
Update 13: Still clearly worse
Update 14: See some synergy when mixing gateloop and attention on a small scale, when holding parameters constant. Will be adding a tiny bit of simplified gateloop layers to transformers to address a main weakness in attention for future projects.
Update 15: There may be a way to combine associative scan based works with the findings from the recently proposed taylor series linear attention. will carry out some independent research before end of January 2024 and share the results here.
Appreciation
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install gateloop-transformer
Usage
import torch
from gateloop_transformer import Transformer
model = Transformer(
num_tokens = 256,
dim = 624,
depth = 6,
use_gate_looped_attn = True
)
ids = torch.randint(0, 256, (1, 1024))
logits = model(ids) # (1, 1024, 256)
A simplified gate loop layer
import torch
from gateloop_transformer import SimpleGateLoopLayer
gateloop = SimpleGateLoopLayer(512)
x = torch.randn(1, 65536, 512)
x = gateloop(x) + x
Character-level Language Modeling
Install requirements
$ pip install -r requirements.txt
Then run the train.py script for autoregressive modeling on enwik8
$ python train.py
Todo
Citations
@inproceedings{Katsch2023GateLoopFD,
title = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
author = {Tobias Katsch},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:265018962}
}
@inproceedings{Heinsen2023EfficientPO,
title = {Efficient Parallelization of a Ubiquitous Sequential Computation},
author = {Franz A. Heinsen},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:265213659}
}
.\lucidrains\gateloop-transformer\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'gateloop-transformer', # 包名
packages = find_packages(exclude=[]), # 查找所有包
version = '0.2.4', # 版本号
license='MIT', # 许可证
description = 'GateLoop Transformer', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
long_description_content_type = 'text/markdown', # 长描述内容类型
url = 'https://github.com/lucidrains/gateloop-transformer', # 项目链接
keywords = [
'artificial intelligence', # 关键词
'deep learning', # 关键词
'gated linear attention' # 关键词
],
install_requires=[
'einops>=0.7.0', # 安装所需的依赖包
'rotary-embedding-torch', # 安装所需的依赖包
'torch>=2.1', # 安装所需的依赖包
],
classifiers=[
'Development Status :: 4 - Beta', # 分类器
'Intended Audience :: Developers', # 分类器
'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类器
'License :: OSI Approved :: MIT License', # 分类器
'Programming Language :: Python :: 3.6', # 分类器
],
)
.\lucidrains\gateloop-transformer\train.py
# 导入所需的库
import math
import gzip
import random
import tqdm
import numpy as np
from functools import wraps, partial
import torch
from torch.optim import Adam, AdamW
from torch import Tensor
from torch.nn import Module, functional as F
from torch.utils.data import DataLoader, Dataset
# 导入加速库
from accelerate import Accelerator
# 导入自定义的 Transformer 模型
from gateloop_transformer import Transformer
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 256
WANDB = True
PROJECT_NAME = 'gateloop'
RUN_NAME = 'baseline gateloop'
# 初始化加速器
accelerator = Accelerator(log_with='wandb' if WANDB else None)
# 辅助函数
def exists(v):
return v is not None
def cycle(loader):
while True:
for data in loader:
yield data
def decode_token(token):
return str(chr(max(32, token)))
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# 采样辅助函数
def log(t, eps=1e-20):
return torch.log(t.clamp(min=eps))
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature=1., dim=-1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
def top_k(logits, thres=0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(-1, ind, val)
return probs
def base_decoding(net: Module, prompt: Tensor, seq_len: int, temperature=1., filter_thres=0.9):
prompt_seq_len, out = prompt.shape[-1], prompt.clone()
sample_num_times = max(0, seq_len - prompt_seq_len)
for _ in range(sample_num_times):
logits = net(out)
logits = logits[:, -1]
logits = top_k(logits, thres=filter_thres)
sample = gumbel_sample(logits, temperature=temperature, dim=-1)
out = torch.cat((out, sample[..., None]), dim=-1)
return out[..., prompt_seq_len:]
# 优化器
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(params, lr=1e-4, wd=0., betas=(0.9, 0.99), eps=1e-8, group_wd_params=True, **kwargs):
opt_kwargs = dict(lr=lr, betas=betas, eps=eps)
if wd == 0:
return Adam(params, **opt_kwargs)
opt_kwargs = {'weight_decay': wd, **opt_kwargs}
if not group_wd_params:
return AdamW(params, **opt_kwargs)
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
return AdamW(params, **opt_kwargs)
# 实例化 Transformer 模型
hparams = dict(
num_tokens=256,
dim=512,
depth=6,
use_gate_looped_attn=True,
gate_loop_heads=512,
data_dependent_rel_pos=False,
attn_softmax_normalize=True,
ablate_complex=False,
ablate_state_transition=False,
rotary_emb=False,
post_ln_norm=True
)
model = Transformer(**hparams)
# 初始化实验跟踪
num_parameters = sum(p.numel() for p in model.parameters())
print(f'number of parameters: {num_parameters}')
wandb_config = {**hparams, 'num_parameters': num_parameters}
accelerator.init_trackers(PROJECT_NAME, config=wandb_config)
if WANDB and exists(RUN_NAME) and len(accelerator.trackers) > 0:
accelerator.trackers[0].run.name = RUN_NAME
# 准备 enwik8 数据
with gzip.open("./data/enwik8.gz") as file:
# 从文件中读取指定长度的数据,转换为 numpy 数组,数据类型为无符号整数8位,然后复制一份
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 将数据数组分割成训练集和验证集,分割点为第90e6个元素的位置
np_train, np_valid = np.split(data, [int(90e6)])
# 将 numpy 数组转换为 PyTorch 张量,分别赋值给训练集和验证集的变量
data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
# 定义一个自定义的数据集类,用于处理文本数据的采样
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 存储数据
self.seq_len = seq_len # 存储序列长度
def __getitem__(self, index):
# 随机生成起始位置
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
# 获取完整的序列数据
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
return full_seq
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练数据集和验证数据集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建训练数据加载器和验证数据加载器
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# 优化器
optim = get_optimizer(
model.parameters(),
lr = LEARNING_RATE,
wd = WEIGHT_DECAY
)
# 准备模型、优化器、训练数据加载器和验证数据加载器
(
model,
optim,
train_loader,
val_loader
) = accelerator.prepare(
model,
optim,
train_loader,
val_loader
)
# 将训练数据加载器和验证数据加载器转换为循环迭代器
train_loader = cycle(train_loader)
val_loader = cycle(val_loader)
# 训练过程
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
model.train()
for _ in range(GRAD_ACCUM_EVERY):
data = next(train_loader)
loss = model(data, return_loss = True)
accelerator.backward(loss / GRAD_ACCUM_EVERY)
print(f"training loss: {loss.item():.3f}")
accelerator.log(dict(loss = loss.item()), step = i)
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
accelerator.wait_for_everyone()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
valid_data = next(val_loader)
loss = model(valid_data, return_loss = True)
print(f"validation loss: {loss.item():.3f}")
accelerator.log(dict(valid_loss = loss.item()), step = i)
accelerator.wait_for_everyone()
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:PRIME_LENGTH]
inp = inp.to(accelerator.device)
prime = decode_tokens(inp)
print(f"%s \n\n %s", (prime, "*" * 100))
prompt = inp[None, ...]
sampled = base_decoding(model, prompt, GENERATE_LENGTH)
base_decode_output = decode_tokens(sampled[0])
print("\n\n", base_decode_output, "\n")
.\lucidrains\genetic-algorithm-pytorch\bcga.py
"""
Bee Colonies Genetic Algorithm
Here we simulate different colonies to maintain diversity. At each generation, one allow a small subset of bees from each colony to immigrate to another
"""
import torch
import einx
# constants
GOAL = 'Attention is all you need'
COLONIES = 10
POP_SIZE = 250
MUTATION_PROB = 0.05
STRONG_MUTATION_PROB = 0.15
NUM_TOURNAMENT_PARTICIPANTS = 25
MIGRATE_EVERY = 5
FRAC_BEES_MIGRANTS = 0.1
# encode and decode functions
# 检查变量是否存在
def exists(v):
return v is not None
# 将字符串编码为张量
def encode(s):
return torch.tensor([ord(c) for c in s])
# 将张量解码为字符串
def decode(t):
return ''.join([chr(i) for i in t.tolist()])
# 计算适应度函数
def calc_fitness(genes, target):
return 1. / (genes - target).square().sum(dim = -1)
# derived constants
# 目标字符串长度
gene_length = len(GOAL)
# 目标基因
target_gene = encode(GOAL)
# 计算基因突变数量
num_code_mutate = MUTATION_PROB * gene_length
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length
# 计算迁移的蜜蜂数量
num_bees_migrate = int((POP_SIZE - 1) * FRAC_BEES_MIGRANTS)
# queen bee genetic algorithm
generation = 1
# 初始化种群
colonies = torch.randint(0, 255, (COLONIES, POP_SIZE - 1, gene_length))
colonies_arange = torch.arange(COLONIES)[..., None]
# 初始化皇后蜜蜂
queens = torch.randint(0, 255, (COLONIES, gene_length))
queen_fitnesses = calc_fitness(queens, target_gene)
while True:
print(f"\n\ngeneration {generation}\n")
# sort population by fitness
# 计算种群适应度
colony_fitnesses = calc_fitness(colonies, target_gene)
# 按适应度降序排序种群
indices = colony_fitnesses.sort(descending = True).indices
colonies, colony_fitnesses = colonies[colonies_arange, indices], colony_fitnesses[colonies_arange, indices]
# display every generation
for i, (pool, fitnesses) in enumerate(zip(colonies[:, :10], colony_fitnesses[:, :10])):
print(f'\ncolony {i + 1}:\n')
if exists(queens):
queen, queen_fitness = queens[i], queen_fitnesses[i]
print(f"Q: {decode(queen)} ({queen_fitness.item():.3f})\n")
for gene, fitness in zip(pool, fitnesses):
print(f"{decode(gene)} ({fitness.item():.3f})")
# if one of the children has a better fitness than queen, that child becomes the new queen
# and the queen replaces the worst bee in the population, kept around for at least one generation more
has_new_queen = colony_fitnesses[:, 0] > queen_fitnesses
pop_arange = torch.arange(POP_SIZE)
pop_arange_with_offset = pop_arange + has_new_queen[:, None]
colonies = torch.cat((
queens[:, None, :],
colonies,
queens[:, None, :]
), dim = -2)
colony_fitnesses = torch.cat((
queen_fitnesses[:, None],
colony_fitnesses,
queen_fitnesses[:, None]
), dim = -1)
colonies = colonies[colonies_arange, pop_arange_with_offset]
colony_fitnesses = colony_fitnesses[colonies_arange, pop_arange_with_offset]
queens, colonies = colonies[:, 0], colonies[:, 1:]
queen_fitnesses, colony_fitnesses = colony_fitnesses[:, 0], colony_fitnesses[:, 1:]
# solved if any fitness is inf
if (queen_fitnesses == float('inf')).any():
print(f'\nsolved at generation {generation}')
break
# deterministic tournament selection - let top winner become parent with queen
colonies_arange_ = colonies_arange[..., None]
contender_ids = torch.randn((COLONIES, POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS]
participants, tournaments = colonies[colonies_arange_, contender_ids], colony_fitnesses[colonies_arange_, contender_ids]
top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices
parents = einx.get_at('... [t] g, ... 1 -> ... g', participants, top_winner)
# potential parents with queen is strongly mutated ("Mutant Bee")
strong_mutate_mask = torch.randn(parents.shape).argsort(dim = -1) < strong_num_code_mutate
noise = torch.randint(0, 2, parents.shape) * 2 - 1
mutated_parents = torch.where(strong_mutate_mask, parents + noise, parents)
mutated_parents.clamp_(0, 255)
# 随机进行50%的基因代码混合,而不是在中点处进行连续的交叉
# 生成一个随机的掩码,用于确定哪些基因需要进行混合
rand_mix_mask = torch.randn(mutated_parents.shape).argsort(dim=-1) < (gene_length // 2)
# 根据随机混合的掩码,将皇后和变异后的父代进行基因混合
colonies = einx.where('c p g, c g, c p g', rand_mix_mask, queens, mutated_parents)
# 对种群中的基因进行突变
# 生成一个用于确定哪些基因需要突变的掩码
mutate_mask = torch.randn(colonies.shape).argsort(dim=-1) < num_code_mutate
# 生成一个随机的噪声,用于基因突变
noise = torch.randint(0, 2, colonies.shape) * 2 - 1
# 根据突变掩码,对种群中的基因进行突变
colonies = torch.where(mutate_mask, colonies + noise, colonies)
# 将基因值限制在0到255之间
colonies.clamp_(0, 255)
# 允许一部分蜜蜂迁移到相邻的群落
# 如果当前代数是迁移周期的倍数,并且有蜜蜂需要迁移
if not (generation % MIGRATE_EVERY) and num_bees_migrate > 0:
# 将一部分蜜蜂迁移到相邻的群落
colonies, migrant_colonies = colonies[:, :-num_bees_migrate], colonies[:, -num_bees_migrate:]
# 将迁移的蜜蜂群落向右滚动一个位置
migrant_colonies = torch.roll(migrant_colonies, 1, dims=0)
# 将迁移后的蜜蜂群落合并回原始种群
colonies = torch.cat((colonies, migrant_colonies), dim=1)
# 增加代数计数
generation += 1
.\lucidrains\genetic-algorithm-pytorch\bega.py
"""
Queen-bee evolution for genetic algorithms - Jung 2003
Inspired by evolution of bees, the fittest solution is designated the "queen", and the rest of the population contends to mate with it. The strong exploitation is balanced by a higher than normal mutation rate.
For some problems, the paper claims convergence at 2-3 orders of magnitude faster
https://www.researchgate.net/publication/3385719_Queen-bee_evolution_for_genetic_algorithms
"""
import torch
from einops import repeat
from einx import get_at
# constants
GOAL = 'Attention is all you need' # 目标字符串
POP_SIZE = 100 # 种群大小
MUTATION_PROB = 0.04 # 突变概率
STRONG_MUTATION_RATE = 0.1 # 强突变率
STRONG_MUTATION_PROB = 0.25 # 强突变概率
NUM_TOURNAMENT_PARTICIPANTS = 25 # 锦标赛参与者数量
# encode and decode functions
def encode(s):
return torch.tensor([ord(c) for c in s]) # 将字符串编码为张量
def decode(t):
return ''.join([chr(i) for i in t.tolist()]) # 将张量解码为字符串
# derived constants
gene_length = len(GOAL) # 目标字符串长度
gene_midpoint = gene_length // 2 # 目标字符串中点位置
target_gene = encode(GOAL) # 目标字符串编码
strong_mutate_pool_size = STRONG_MUTATION_RATE * POP_SIZE # 强突变池大小
num_code_mutate = MUTATION_PROB * gene_length # 码位突变数量
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length # 强码位突变数量
# queen bee genetic algorithm
generation = 1 # 代数
pool = torch.randint(0, 255, (POP_SIZE, gene_length)) # 随机初始化种群
queen = queen_fitness = None # 初始化皇后和皇后适应度
while True:
print(f"\n\ngeneration {generation}\n") # 打印当前代数
# sort population by fitness
fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1) # 计算适应度
indices = fitnesses.sort(descending = True).indices # 根据适应度排序种群
pool, fitnesses = pool[indices], fitnesses[indices]
# display every generation
if queen is not None:
print("queen:")
print(f"{decode(queen)} ({queen_fitness.item():.3f})\n") # 打印皇后及其适应度
for gene, fitness in zip(pool, fitnesses):
print(f"{decode(gene)} ({fitness.item():.3f})") # 打印每个基因及其适应度
# if one of the children has a better fitness than queen, that child becomes the new queen
# and the queen replaces the worst bee in the population, kept around for at least one generation more
if queen is not None and queen_fitness < fitnesses[0]:
pool = torch.cat((pool, queen[None, :]), dim = 0) # 将皇后加入种群
fitnesses = torch.cat((fitnesses, queen_fitness[None]), dim = 0)
queen = queen_fitness = None
# separate the queen bee from the rest of the population
if queen is None:
queen, pool = pool[0], pool[1:] # 分离皇后和种群
queen_fitness, fitnesses = fitnesses[0], fitnesses[1:]
# solved if any queen fitness is inf
if (queen_fitness == float('inf')).any(): # 如果皇后适应度为无穷大,则问题已解决
break
# deterministic tournament selection - let top winner become parent with queen
contender_ids = torch.randn((POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS] # 锦标赛选择参与者
participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices # 选择最优参与者
parents = get_at('p [t] g, p 1 -> p g', participants, top_winner) # 获取父母基因
# cross over all chosen drones with the queen
queen_parents = repeat(queen, '... -> p ...', p = POP_SIZE - 1) # 重复皇后基因
queen_and_parents = torch.stack((queen_parents, parents), dim = 1) # 合并皇后和父母基因
rand_crossover_order = torch.randn(queen_and_parents.shape[:2]).argsort(dim = -1) # 随机交叉排序
batch_arange = torch.arange(POP_SIZE - 1)[..., None]
queen_and_parents = queen_and_parents[batch_arange, rand_crossover_order]
queen_parents, parents = queen_and_parents.unbind(dim = 1)
pool = torch.cat((queen_parents[:, :gene_midpoint], parents[:, gene_midpoint:]), dim = -1) # 交叉生成新种群
# mutate genes in population
mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_code_mutate # 生成突变掩码
noise = torch.randint(0, 2, pool.shape) * 2 - 1
mutated_pool = torch.where(mutate_mask, pool + noise, pool) # 码位突变
strong_mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < strong_num_code_mutate # 生成强突变掩码
noise = torch.randint(0, 2, pool.shape) * 2 - 1
strong_mutated_pool = torch.where(strong_mutate_mask, pool + noise, pool) # 强码位突变
# 生成一个布尔掩码,用于选择强变异池中的个体
strong_mutate_pool_mask = torch.randn(POP_SIZE - 1).argsort(dim=-1) < strong_mutate_pool_size
# 根据强变异池掩码,选择强变异池中的个体或者普通变异池中的个体,组成新的池
pool = torch.where(strong_mutate_pool_mask[:, None], strong_mutated_pool, mutated_pool)
# 将池中的值限制在0到255之间
pool.clamp_(0, 255)
# 增加一代
generation += 1
.\lucidrains\genetic-algorithm-pytorch\ga.py
"""
Genetic Algorithm - formalized by John H. Holland in 1992, but has been talked about since 1960-70s
https://www.researchgate.net/figure/Hollands-canonical-genetic-algorithm-Holland-1992_fig4_221174380
"""
import torch
from einx import get_at
# constants
GOAL = 'Attention is all you need' # 目标字符串
POP_SIZE = 100 # 种群大小
MUTATION_RATE = 0.04 # 变异率
FRAC_FITTEST_SURVIVE = 0.25 # 最适应个体存活比例
FRAC_TOURNAMENT = 0.25 # 锦标赛选择比例
ELITE_FRAC = 0.05 # 精英比例
# encode and decode functions
def encode(s):
return torch.tensor([ord(c) for c in s]) # 将字符串编码为张量
def decode(t):
return ''.join([chr(i) for i in t.tolist()]) # 将张量解码为字符串
# derived constants
gene_length = len(GOAL) # 目标字符串长度
gene_midpoint = gene_length // 2 # 目标字符串中点位置
target_gene = encode(GOAL) # 目标字符串编码
keep_fittest_len = int(POP_SIZE * FRAC_FITTEST_SURVIVE) # 保留最适应个体数量
num_elite = int(ELITE_FRAC * POP_SIZE) # 精英数量
num_repro_and_mutate = keep_fittest_len - num_elite # 繁殖和变异数量
num_tournament_contenders = int(num_repro_and_mutate * FRAC_TOURNAMENT) # 锦标赛参与者数量
num_children = POP_SIZE - keep_fittest_len # 子代数量
num_mutate = MUTATION_RATE * gene_length # 变异基因数量
assert num_tournament_contenders >= 2 # 断言确保锦标赛参与者数量大于等于2
# genetic algorithm
generation = 1 # 代数计数器
pool = torch.randint(0, 255, (POP_SIZE, gene_length)) # 初始化种群,随机生成基因
while True:
print(f"\n\ngeneration {generation}\n") # 打印当前代数
# sort population by fitness
fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1) # 计算适应度
indices = fitnesses.sort(descending = True).indices # 根据适应度对种群排序
pool, fitnesses = pool[indices], fitnesses[indices]
# keep the fittest
pool, fitnesses = pool[:keep_fittest_len], fitnesses[:keep_fittest_len] # 保留最适应个体
# display every generation
for gene, fitness in zip(pool, fitnesses):
print(f"{decode(gene)} ({fitness.item():.3f})") # 打印每个个体的基因和适应度
# solved if any fitness is inf
if (fitnesses == float('inf')).any(): # 如果有个体的适应度为无穷大,则问题已解决
break
# elites can pass directly to next generation
elites, pool = pool[:num_elite], pool[num_elite:] # 精英直接传递到下一代
elites_fitnesses, fitnesses = fitnesses[:num_elite], fitnesses[num_elite:]
# deterministic tournament selection - let top 2 winners become parents
contender_ids = torch.randn((num_children, num_repro_and_mutate)).argsort(dim = -1)[..., :num_tournament_contenders] # 锦标赛选择参与者
participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
top2_winners = tournaments.topk(2, dim = -1, largest = True, sorted = False).indices # 选择前两名作为父母
parents = get_at('p [t] g, p w -> p w g', participants, top2_winners) # 获取父母
# cross over recombination of parents
parent1, parent2 = parents.unbind(dim = 1) # 拆分父母
children = torch.cat((parent1[:, :gene_midpoint], parent2[:, gene_midpoint:]), dim = -1) # 交叉重组父母基因
pool = torch.cat((pool, children)) # 将子代加入种群
# mutate genes in population
mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_mutate # 生成变异掩码
noise = torch.randint(0, 2, pool.shape) * 2 - 1 # 生成变异噪声
pool = torch.where(mutate_mask, pool + noise, pool) # 变异
pool.clamp_(0, 255) # 限制基因值范围在0-255之间
# add back the elites
pool = torch.cat((elites, pool)) # 将精英加回种群
generation += 1 # 代数加一
.\lucidrains\genetic-algorithm-pytorch\inbreed.py
"""
Genetic Algorithm
but without first generation inbreeding
"""
import torch
import einx
from einx import get_at, rearrange
# constants
GOAL = 'Attention is all you need' # 目标字符串
POP_SIZE = 100 # 种群大小
MUTATION_RATE = 0.04 # 变异率
FRAC_FITTEST_SURVIVE = 0.25 # 存活最适应个体的比例
FRAC_TOURNAMENT = 0.25 # 锦标赛选择的比例
ELITE_FRAC = 0.05 # 精英个体的比例
# encode and decode functions
def encode(s):
return torch.tensor([ord(c) for c in s]) # 将字符串编码为张量
def decode(t):
return ''.join([chr(i) for i in t.tolist()]) # 将张量解码为字符串
# derived constants
gene_length = len(GOAL) # 目标字符串的长度
gene_midpoint = gene_length // 2 # 目标字符串的中点位置
target_gene = encode(GOAL) # 目标字符串的编码
keep_fittest_len = int(POP_SIZE * FRAC_FITTEST_SURVIVE) # 保留最适应个体的数量
num_elite = int(ELITE_FRAC * POP_SIZE) # 精英个体的数量
num_repro_and_mutate = keep_fittest_len - num_elite # 繁殖和变异的个体数量
num_tournament_contenders = int(num_repro_and_mutate * FRAC_TOURNAMENT) # 锦标赛的参与者数量
num_children = POP_SIZE - keep_fittest_len # 子代个体数量
num_mutate = MUTATION_RATE * gene_length # 变异的基因数量
assert num_tournament_contenders >= 2 # 断言确保锦标赛的参与者数量大于等于2
# genetic algorithm
generation = 1 # 代数
parent_ids = torch.full((POP_SIZE, 2), -1, dtype=torch.long) # 父母的ID
pool = torch.randint(0, 255, (POP_SIZE, gene_length)) # 种群中的个体
while True:
print(f"\n\ngeneration {generation}\n") # 打印当前代数
# sort population by fitness
fitnesses = 1. / torch.square(pool - target_gene).sum(dim=-1) # 计算适应度
indices = fitnesses.sort(descending=True).indices # 根据适应度对种群进行排序
pool, parent_ids, fitnesses = pool[indices], parent_ids[indices], fitnesses[indices]
# keep the fittest
pool, parent_ids, fitnesses = pool[:keep_fittest_len], parent_ids[:keep_fittest_len], fitnesses[:keep_fittest_len] # 保留最适应的个体
# display every generation
for gene, fitness in zip(pool, fitnesses):
print(f"{decode(gene)} ({fitness.item():.3f})") # 打印每个个体的基因和适应度
# solved if any fitness is inf
if (fitnesses == float('inf')).any(): # 如果任何适应度为无穷大,则问题已解决
break
# elites can pass directly to next generation
elites, pool = pool[:num_elite], pool[num_elite:] # 精英个体直接传递到下一代
elites_fitnesses, fitnesses = fitnesses[:num_elite], fitnesses[num_elite:]
elites_parent_ids, parent_ids = parent_ids[:num_elite], parent_ids[num_elite:]
elites_parent_ids.fill_(-1) # 将精英个体的父母ID填充为-1
# deterministic tournament selection
# 2 tournaments - the second tournament removes all contestants with shared parents with 1st winner
first_contender_ids = torch.randn((num_children, num_repro_and_mutate)).argsort(dim=-1)[..., :num_tournament_contenders] # 第一轮锦标赛的参与者ID
first_participants, participants_parent_ids, tournaments = pool[first_contender_ids], parent_ids[first_contender_ids], fitnesses[first_contender_ids]
first_winner = tournaments.topk(1, dim=-1, largest=True, sorted=False).indices # 第一轮锦标赛的获胜者
first_winner = rearrange('p 1 -> p', first_winner)
first_parent_ids = get_at('p [t] i, p -> p i', participants_parent_ids, first_winner) # 第一轮锦标赛的获胜者的父母ID
# second tournament, masking out any siblings to first winners
contender_scores = torch.randn((num_children, num_repro_and_mutate)) # 参与者得分
self_mask = rearrange('i -> i 1', first_winner) == torch.arange(num_repro_and_mutate) # 自身掩码
contender_scores = torch.where(self_mask, 1e6, contender_scores)
sibling_mask = (rearrange('p i -> p 1 i 1', first_parent_ids) == rearrange('c j -> 1 c 1 j', parent_ids)) # 兄弟掩码
valid_parent_mask = (rearrange('p i -> p 1 i 1', first_parent_ids) != -1) & (rearrange('c j -> 1 c 1 j', parent_ids) != -1) # 有效父母掩码
num_shared_parents = (sibling_mask & valid_parent_mask).float().sum(dim=(-1, -2)) # 共享父母的数量
contender_scores += num_shared_parents * 1e3
second_contender_ids = contender_scores.argsort(dim=-1)[..., :num_tournament_contenders] # 第二轮锦标赛的参与者ID
second_participants, second_tournaments = pool[second_contender_ids], fitnesses[second_contender_ids]
second_winner = second_tournaments.topk(1, dim=-1, largest=True, sorted=False).indices # 第二轮锦标赛的获胜者
second_winner = rearrange('p 1 -> p', second_winner)
# get parents
first_ids = get_at('p [t], p -> p', first_contender_ids, first_winner) # 第一轮锦标赛的获胜者的ID
second_ids = get_at('p [t], p -> p', second_contender_ids, second_winner) # 第二轮锦标赛的获胜者的ID
new_parent_ids = torch.stack((first_ids, second_ids), dim=-1) # 新的父母ID对
# 从第一组参与者和第一组获胜者中获取父母1
parent1 = get_at('p [t] g, p -> p g', first_participants, first_winner)
# 从第二组参与者和第二组获胜者中获取父母2
parent2 = get_at('p [t] g, p -> p g', second_participants, second_winner)
# 交叉重组父母的基因
# 将父母1的前半部分和父母2的后半部分连接起来形成子代
children = torch.cat((parent1[:, :gene_midpoint], parent2[:, gene_midpoint:]), dim=-1)
# 将子代添加到种群中
pool = torch.cat((pool, children))
# 重置父母ID数组并将新的父母ID添加到其中
parent_ids.fill_(-1)
parent_ids = torch.cat((parent_ids, new_parent_ids))
# 在种群中突变基因
# 生成一个用于确定哪些基因需要突变的掩码
mutate_mask = torch.randn(pool.shape).argsort(dim=-1) < num_mutate
# 生成一个随机噪声数组,用于基因突变
noise = torch.randint(0, 2, pool.shape) * 2 - 1
# 根据掩码决定是否对基因进行突变,并添加随机噪声
pool = torch.where(mutate_mask, pool + noise, pool)
# 将基因值限制在0到255之间
pool.clamp_(0, 255)
# 将精英个体重新添加到种群中
# 将精英个体添加回种群中
pool = torch.cat((elites, pool))
# 将精英个体的父母ID添加回父母ID数组中
parent_ids = torch.cat((elites_parent_ids, parent_ids))
# 递增代数计数器
generation += 1
.\lucidrains\genetic-algorithm-pytorch\qbmb.py
"""
Queen-bee and Mutant-bee evolution for genetic algorithms - Jung 2007
4 years after proposing the Queen bee evolution genetic algorithm, Jung proposes a simplification to get rid of a few hyperparameters
In the new scheme, all the selected bees to mate with the queen undergo strong mutation prior to crossover
This scheme therefore better preserves the queen's genetic code. He shows through various experiments that this performs just as well as the original algorithm while being simpler
https://www.researchgate.net/publication/290131255_Queen-bee_and_Mutant-bee_Evolution_for_Genetic_Algorithms
"""
import torch
from einops import repeat
from einx import get_at
# constants
GOAL = 'Attention is all you need'
POP_SIZE = 100
MUTATION_PROB = 0.04
STRONG_MUTATION_PROB = 0.25
NUM_TOURNAMENT_PARTICIPANTS = 25
# encode and decode functions
def encode(s):
return torch.tensor([ord(c) for c in s])
def decode(t):
return ''.join([chr(i) for i in t.tolist()])
# derived constants
gene_length = len(GOAL)
gene_midpoint = gene_length // 2
target_gene = encode(GOAL)
num_code_mutate = MUTATION_PROB * gene_length
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length
# queen bee genetic algorithm
generation = 1
pool = torch.randint(0, 255, (POP_SIZE, gene_length))
queen = queen_fitness = None
while True:
print(f"\n\ngeneration {generation}\n")
# sort population by fitness
fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1)
indices = fitnesses.sort(descending = True).indices
pool, fitnesses = pool[indices], fitnesses[indices]
# display every generation
if queen is not None:
print("queen:")
print(f"{decode(queen)} ({queen_fitness.item():.3f})\n")
for gene, fitness in zip(pool, fitnesses):
print(f"{decode(gene)} ({fitness.item():.3f})")
# if one of the children has a better fitness than queen, that child becomes the new queen
# and the queen replaces the worst bee in the population, kept around for at least one generation more
if queen is not None and queen_fitness < fitnesses[0]:
pool = torch.cat((pool, queen[None, :]), dim = 0)
fitnesses = torch.cat((fitnesses, queen_fitness[None]), dim = 0)
queen = queen_fitness = None
# separate the queen bee from the rest of the population
if queen is None:
queen, pool = pool[0], pool[1:]
queen_fitness, fitnesses = fitnesses[0], fitnesses[1:]
# solved if any fitness is inf
if (queen_fitness == float('inf')).any():
break
# deterministic tournament selection - let top winner become parent with queen
contender_ids = torch.randn((POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS]
participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices
parents = get_at('... [t] g, ... 1 -> ... g', participants, top_winner)
# potential parents with queen is strongly mutated ("Mutant Bee")
strong_mutate_mask = torch.randn(parents.shape).argsort(dim = -1) < strong_num_code_mutate
noise = torch.randint(0, 2, parents.shape) * 2 - 1
mutated_parents = torch.where(strong_mutate_mask, parents + noise, parents)
mutated_parents.clamp_(0, 255)
# cross over all chosen drones with the queen
queen_parents = repeat(queen, '... -> p ...', p = POP_SIZE - 1)
queen_and_parents = torch.stack((queen_parents, mutated_parents), dim = 1)
# in my experiments, the crossover point must be random between queen and drones for this to work
# todo: get caught up with all the different types of crossover operators
rand_crossover_order = torch.randn(queen_and_parents.shape[:2]).argsort(dim = -1)
batch_arange = torch.arange(POP_SIZE - 1)[..., None]
queen_and_parents = queen_and_parents[batch_arange, rand_crossover_order]
# 从 queen_and_parents 张量中解绑出 queen_parents 和 mutated_parents,沿着第一个维度进行解绑
queen_parents, mutated_parents = queen_and_parents.unbind(dim = 1)
# 将 queen_parents 和 mutated_parents 沿着最后一个维度拼接起来,形成新的 pool 张量
pool = torch.cat((queen_parents[:, :gene_midpoint], mutated_parents[:, gene_midpoint:]), dim = -1)
# 对种群中的基因进行变异
# 创建一个与 pool 张量相同形状的张量,其中的元素按照正态分布排序,小于 num_code_mutate 的元素为 True
mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_code_mutate
# 创建一个与 pool 张量相同形状的张量,元素为 0 或 1
noise = torch.randint(0, 2, pool.shape) * 2 - 1
# 根据 mutate_mask,对 pool 张量中的元素进行变异,如果 mutate_mask 中对应位置为 True,则加上 noise 中对应位置的值
pool = torch.where(mutate_mask, pool + noise, pool)
# 将 pool 张量中的元素限制在 0 到 255 之间
pool.clamp_(0, 255)
# 增加一代
generation += 1
genetic-algorithm-pytorch
a simple genetic algorithm written in Pytorch
running
$ python ga.py
.\lucidrains\geometric-vector-perceptron\examples\data_handler.py
# 作者:Eric Alcaide
# 从 https://github.com/jonathanking/sidechainnet 借用了大部分代码
# 下面是其许可证:
# 版权所有 2020 Jonathan King
# 允许以源代码和二进制形式重新分发和使用,无论是否进行修改,只要满足以下条件:
#
# 1. 源代码的再分发必须保留上述版权声明、此条件列表和以下免责声明。
#
# 2. 以二进制形式再分发时,必须在提供的文档和/或其他材料中复制上述版权声明、此条件列表和以下免责声明。
#
# 3. 未经特定事先书面许可,不得使用版权持有人或其贡献者的名称来认可或推广从本软件派生的产品。
#
# 版权持有人和贡献者提供的本软件是按原样提供的,不提供任何明示或暗示的担保,包括但不限于对适销性和特定用途的暗示担保。
# 在任何情况下,无论是在合同、严格责任还是侵权(包括疏忽或其他方式)的情况下,版权持有人或贡献者均不对任何直接、间接、附带、特殊、惩罚性或后果性损害(包括但不限于替代商品或服务的采购、使用、数据或利润损失或业务中断)负责,即使已被告知可能发生此类损害。
import warnings
warnings.filterwarnings("ignore")
import torch
import numpy as np
from einops import repeat, rearrange
######################
## structural utils ##
######################
def get_dihedral(c1, c2, c3, c4):
""" 返回弯曲角度(弧度)。
将使用来自以下链接的 atan2 公式:
https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
输入:
* c1: (batch, 3) 或 (3,)
* c2: (batch, 3) 或 (3,)
* c3: (batch, 3) 或 (3,)
* c4: (batch, 3) 或 (3,)
"""
u1 = c2 - c1
u2 = c3 - c2
u3 = c4 - c3
return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) ,
( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) )
def get_angle(c1, c2, c3):
""" 返回角度(弧度)。
输入:
* c1: (batch, 3) 或 (3,)
* c2: (batch, 3) 或 (3,)
* c3: (batch, 3) 或 (3,)
"""
u1 = c2 - c1
u2 = c3 - c2
# 不使用传统的 arccos,因为它得到的是“最小角度”,不一定是我们想要的
# return torch.acos( (u1*u2).sum(dim=-1) / (torch.norm(u1, dim=-1)*torch.norm(u2, dim=-1) )+
# 更好地使用 atan2 公式:atan2(cross, dot) 来自这里:
# https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html
# 添加负号,因为我们希望角度是反向的 - sidechainnet 问题
return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1),
-(u1*u2).sum(dim=-1) )
def kabsch_torch(X, Y):
""" 将 X 对齐到 Y 的 Kabsch 对齐。
假设 X、Y 都是 (D, N) 的形式 - 通常是 (3, N)
"""
# 将 X 和 Y 居中到原点
X_ = X - X.mean(dim=-1, keepdim=True)
Y_ = Y - Y.mean(dim=-1, keepdim=True)
# 计算协方差矩阵(对于每个批次中的蛋白质)
C = torch.matmul(X_, Y_.t())
# 通过 SVD 计算最佳旋转矩阵 - 警告!W 必须被转置
V, S, W = torch.svd(C.detach())
# 方向校正的行列式符号
d = (torch.det(V) * torch.det(W)) < 0.0
if d:
S[-1] = S[-1] * (-1)
V[:, -1] = V[:, -1] * (-1)
# 创建旋转矩阵 U
U = torch.matmul(V, W.t())
# 计算旋转
X_ = torch.matmul(X_.t(), U).t()
# 返回居中和对齐后的 X_ 和 Y_
return X_, Y_
# 计算两个张量之间的均方根偏差,假设 X 和 Y 的形状都是 (batch, d, n),通常是 (batch, 3, N)
def rmsd_torch(X, Y):
""" Assumes x,y are both (batch, d, n) - usually (batch, 3, N). """
return torch.sqrt( torch.mean((X - Y)**2, axis=(-1, -2)) )
############
### INFO ###
############
# 包含了不同氨基酸的构建信息的字典
SC_BUILD_INFO = {
'A': {
'angles-names': ['N-CA-CB'],
'angles-types': ['N -CX-CT'],
'angles-vals': [1.9146261894377796],
'atom-names': ['CB'],
'bonds-names': ['CA-CB'],
'bonds-types': ['CX-CT'],
'bonds-vals': [1.526],
'torsion-names': ['C-N-CA-CB'],
'torsion-types': ['C -N -CX-CT'],
'torsion-vals': ['p']
},
'R': {
'angles-names': [
'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-NE', 'CD-NE-CZ', 'NE-CZ-NH1',
'NE-CZ-NH2'
],
'angles-types': [
'N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-N2', 'C8-N2-CA', 'N2-CA-N2',
'N2-CA-N2'
],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.9408061282176945,
2.150245638457014, 2.0943951023931953, 2.0943951023931953
],
'atom-names': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-NE', 'NE-CZ', 'CZ-NH1', 'CZ-NH2'],
'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-N2', 'N2-CA', 'CA-N2', 'CA-N2'],
'bonds-vals': [1.526, 1.526, 1.526, 1.463, 1.34, 1.34, 1.34],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-NE', 'CG-CD-NE-CZ',
'CD-NE-CZ-NH1', 'CD-NE-CZ-NH2'
],
'torsion-types': [
'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-N2', 'C8-C8-N2-CA',
'C8-N2-CA-N2', 'C8-N2-CA-N2'
],
'torsion-vals': ['p', 'p', 'p', 'p', 'p', 'p', 'i']
},
'N': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-ND2'],
'angles-types': ['N -CX-2C', 'CX-2C-C ', '2C-C -O ', '2C-C -N '],
'angles-vals': [
1.9146261894377796, 1.9390607989657, 2.101376419401173, 2.035053907825388
],
'atom-names': ['CB', 'CG', 'OD1', 'ND2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-ND2'],
'bonds-types': ['CX-2C', '2C-C ', 'C -O ', 'C -N '],
'bonds-vals': [1.526, 1.522, 1.229, 1.335],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-ND2'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-C ', 'CX-2C-C -O ', 'CX-2C-C -N '],
'torsion-vals': ['p', 'p', 'p', 'i']
},
'D': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-OD2'],
'angles-types': ['N -CX-2C', 'CX-2C-CO', '2C-CO-O2', '2C-CO-O2'],
'angles-vals': [
1.9146261894377796, 1.9390607989657, 2.0420352248333655, 2.0420352248333655
],
'atom-names': ['CB', 'CG', 'OD1', 'OD2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-OD2'],
'bonds-types': ['CX-2C', '2C-CO', 'CO-O2', 'CO-O2'],
'bonds-vals': [1.526, 1.522, 1.25, 1.25],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-OD2'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-CO', 'CX-2C-CO-O2', 'CX-2C-CO-O2'],
'torsion-vals': ['p', 'p', 'p', 'i']
},
'C': {
'angles-names': ['N-CA-CB', 'CA-CB-SG'],
'angles-types': ['N -CX-2C', 'CX-2C-SH'],
'angles-vals': [1.9146261894377796, 1.8954275676658419],
'atom-names': ['CB', 'SG'],
'bonds-names': ['CA-CB', 'CB-SG'],
'bonds-types': ['CX-2C', '2C-SH'],
'bonds-vals': [1.526, 1.81],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-SG'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-SH'],
'torsion-vals': ['p', 'p']
},
'Q': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-NE2'],
'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-C ', '2C-C -O ', '2C-C -N '],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.101376419401173,
2.035053907825388
],
'atom-names': ['CB', 'CG', 'CD', 'OE1', 'NE2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-NE2'],
'bonds-types': ['CX-2C', '2C-2C', '2C-C ', 'C -O ', 'C -N '],
'bonds-vals': [1.526, 1.526, 1.522, 1.229, 1.335],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-NE2'
],
'torsion-types': [
'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-C ', '2C-2C-C -O ', '2C-2C-C -N '
],
'torsion-vals': ['p', 'p', 'p', 'p', 'i']
},
'E': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-OE2'],
'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-CO', '2C-CO-O2', '2C-CO-O2'],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.0420352248333655,
2.0420352248333655
],
'atom-names': ['CB', 'CG', 'CD', 'OE1', 'OE2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-OE2'],
'bonds-types': ['CX-2C', '2C-2C', '2C-CO', 'CO-O2', 'CO-O2'],
'bonds-vals': [1.526, 1.526, 1.522, 1.25, 1.25],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-OE2'
],
'torsion-types': [
'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-CO', '2C-2C-CO-O2', '2C-2C-CO-O2'
],
'torsion-vals': ['p', 'p', 'p', 'p', 'i']
},
'G': {
'angles-names': [],
'angles-types': [],
'angles-vals': [],
'atom-names': [],
'bonds-names': [],
'bonds-types': [],
'bonds-vals': [],
'torsion-names': [],
'torsion-types': [],
'torsion-vals': []
},
'H': {
'angles-names': [
'N-CA-CB', 'CA-CB-CG', 'CB-CG-ND1', 'CG-ND1-CE1', 'ND1-CE1-NE2', 'CE1-NE2-CD2'
],
'angles-types': [
'N -CX-CT', 'CX-CT-CC', 'CT-CC-NA', 'CC-NA-CR', 'NA-CR-NB', 'CR-NB-CV'
],
'angles-vals': [
1.9146261894377796, 1.9739673840055867, 2.0943951023931953,
1.8849555921538759, 1.8849555921538759, 1.8849555921538759
],
'atom-names': ['CB', 'CG', 'ND1', 'CE1', 'NE2', 'CD2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-ND1', 'ND1-CE1', 'CE1-NE2', 'NE2-CD2'],
'bonds-types': ['CX-CT', 'CT-CC', 'CC-NA', 'NA-CR', 'CR-NB', 'NB-CV'],
'bonds-vals': [1.526, 1.504, 1.385, 1.343, 1.335, 1.394],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-ND1', 'CB-CG-ND1-CE1', 'CG-ND1-CE1-NE2',
'ND1-CE1-NE2-CD2'
],
'torsion-types': [
'C -N -CX-CT', 'N -CX-CT-CC', 'CX-CT-CC-NA', 'CT-CC-NA-CR', 'CC-NA-CR-NB',
'NA-CR-NB-CV'
],
'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0]
},
'I': {
'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CB-CG1-CD1', 'CA-CB-CG2'],
'angles-types': ['N -CX-3C', 'CX-3C-2C', '3C-2C-CT', 'CX-3C-CT'],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
],
'atom-names': ['CB', 'CG1', 'CD1', 'CG2'],
'bonds-names': ['CA-CB', 'CB-CG1', 'CG1-CD1', 'CB-CG2'],
'bonds-types': ['CX-3C', '3C-2C', '2C-CT', '3C-CT'],
'bonds-vals': [1.526, 1.526, 1.526, 1.526],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'CA-CB-CG1-CD1', 'N-CA-CB-CG2'],
'torsion-types': ['C -N -CX-3C', 'N -CX-3C-2C', 'CX-3C-2C-CT', 'N -CX-3C-CT'],
'torsion-vals': ['p', 'p', 'p', 'p']
},
'L': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CB-CG-CD2'],
'angles-types': ['N -CX-2C', 'CX-2C-3C', '2C-3C-CT', '2C-3C-CT'],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
],
'atom-names': ['CB', 'CG', 'CD1', 'CD2'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD1', 'CG-CD2'],
'bonds-types': ['CX-2C', '2C-3C', '3C-CT', '3C-CT'],
'bonds-vals': [1.526, 1.526, 1.526, 1.526],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CA-CB-CG-CD2'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-3C', 'CX-2C-3C-CT', 'CX-2C-3C-CT'],
'torsion-vals': ['p', 'p', 'p', 'p']
},
'K': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-CE', 'CD-CE-NZ'],
'angles-types': ['N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-C8', 'C8-C8-N3'],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791,
1.9408061282176945
],
'atom-names': ['CB', 'CG', 'CD', 'CE', 'NZ'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-CE', 'CE-NZ'],
'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-C8', 'C8-N3'],
'bonds-vals': [1.526, 1.526, 1.526, 1.526, 1.471],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-CE', 'CG-CD-CE-NZ'
],
'torsion-types': [
'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-C8', 'C8-C8-C8-N3'
],
'torsion-vals': ['p', 'p', 'p', 'p', 'p']
},
'M': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-SD', 'CG-SD-CE'],
'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-S ', '2C-S -CT'],
'angles-vals': [
1.9146261894377796, 1.911135530933791, 2.0018926520374962, 1.726130630222392
],
'atom-names': ['CB', 'CG', 'SD', 'CE'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-SD', 'SD-CE'],
'bonds-types': ['CX-2C', '2C-2C', '2C-S ', 'S -CT'],
'bonds-vals': [1.526, 1.526, 1.81, 1.81],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-SD', 'CB-CG-SD-CE'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-S ', '2C-2C-S -CT'],
'torsion-vals': ['p', 'p', 'p', 'p']
},
'F': {
'angles-names': [
'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-CE2',
'CZ-CE2-CD2'
],
'angles-types': [
'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CA',
'CA-CA-CA'
],
'angles-vals': [
1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
2.0943951023931953, 2.0943951023931953, 2.0943951023931953, 2.0943951023931953
],
'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'CE2', 'CD2'],
'bonds-names': [
'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-CE2', 'CE2-CD2'
],
'bonds-types': ['CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA'],
'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.4, 1.4, 1.4],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
],
'torsion-types': [
'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-CA',
'CA-CA-CA-CA', 'CA-CA-CA-CA'
],
'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0, 0.0]
},
'P': {
'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD'],
'angles-types': ['N -CX-CT', 'CX-CT-CT', 'CT-CT-CT'],
'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
'atom-names': ['CB', 'CG', 'CD'],
'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD'],
'bonds-types': ['CX-CT', 'CT-CT', 'CT-CT'],
'bonds-vals': [1.526, 1.526, 1.526],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD'],
'torsion-types': ['C -N -CX-CT', 'N -CX-CT-CT', 'CX-CT-CT-CT'],
'torsion-vals': ['p', 'p', 'p']
},
# 定义了氨基酸"S"的键值对,包含了该氨基酸的角度、键和扭转信息
'S': {
'angles-names': ['N-CA-CB', 'CA-CB-OG'],
'angles-types': ['N -CX-2C', 'CX-2C-OH'],
'angles-vals': [1.9146261894377796, 1.911135530933791],
'atom-names': ['CB', 'OG'],
'bonds-names': ['CA-CB', 'CB-OG'],
'bonds-types': ['CX-2C', '2C-OH'],
'bonds-vals': [1.526, 1.41],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG'],
'torsion-types': ['C -N -CX-2C', 'N -CX-2C-OH'],
'torsion-vals': ['p', 'p']
},
# 定义了氨基酸"T"的键值对,包含了该氨基酸的角度、键和扭转信息
'T': {
'angles-names': ['N-CA-CB', 'CA-CB-OG1', 'CA-CB-CG2'],
'angles-types': ['N -CX-3C', 'CX-3C-OH', 'CX-3C-CT'],
'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
'atom-names': ['CB', 'OG1', 'CG2'],
'bonds-names': ['CA-CB', 'CB-OG1', 'CB-CG2'],
'bonds-types': ['CX-3C', '3C-OH', '3C-CT'],
'bonds-vals': [1.526, 1.41, 1.526],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG1', 'N-CA-CB-CG2'],
'torsion-types': ['C -N -CX-3C', 'N -CX-3C-OH', 'N -CX-3C-CT'],
'torsion-vals': ['p', 'p', 'p']
},
# 定义了氨基酸"W"的键值对,包含了该氨基酸的角度、键和扭转信息
'W': {
'angles-names': [
'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-NE1', 'CD1-NE1-CE2',
'NE1-CE2-CZ2', 'CE2-CZ2-CH2', 'CZ2-CH2-CZ3', 'CH2-CZ3-CE3', 'CZ3-CE3-CD2'
],
'angles-types': [
'N -CX-CT', 'CX-CT-C*', 'CT-C*-CW', 'C*-CW-NA', 'CW-NA-CN', 'NA-CN-CA',
'CN-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CB'
],
'angles-vals': [
1.9146261894377796, 2.0176006153054447, 2.181661564992912, 1.8971728969178363,
1.9477874452256716, 2.3177972466484698, 2.0943951023931953,
2.0943951023931953, 2.0943951023931953, 2.0943951023931953
],
'atom-names': [
'CB', 'CG', 'CD1', 'NE1', 'CE2', 'CZ2', 'CH2', 'CZ3', 'CE3', 'CD2'
],
'bonds-names': [
'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-NE1', 'NE1-CE2', 'CE2-CZ2', 'CZ2-CH2',
'CH2-CZ3', 'CZ3-CE3', 'CE3-CD2'
],
'bonds-types': [
'CX-CT', 'CT-C*', 'C*-CW', 'CW-NA', 'NA-CN', 'CN-CA', 'CA-CA', 'CA-CA',
'CA-CA', 'CA-CB'
],
'bonds-vals': [1.526, 1.495, 1.352, 1.381, 1.38, 1.4, 1.4, 1.4, 1.4, 1.404],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-NE1', 'CG-CD1-NE1-CE2',
'CD1-NE1-CE2-CZ2', 'NE1-CE2-CZ2-CH2', 'CE2-CZ2-CH2-CZ3', 'CZ2-CH2-CZ3-CE3',
'CH2-CZ3-CE3-CD2'
],
'torsion-types': [
'C -N -CX-CT', 'N -CX-CT-C*', 'CX-CT-C*-CW', 'CT-C*-CW-NA', 'C*-CW-NA-CN',
'CW-NA-CN-CA', 'NA-CN-CA-CA', 'CN-CA-CA-CA', 'CA-CA-CA-CA', 'CA-CA-CA-CB'
],
'torsion-vals': [
'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 3.141592653589793,
0.0, 0.0, 0.0
]
},
'Y': {
'angles-names': [
'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-OH',
'CE1-CZ-CE2', 'CZ-CE2-CD2'
],
'angles-types': [
'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-C ', 'CA-C -OH',
'CA-C -CA', 'C -CA-CA'
],
'angles-vals': [
1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
2.0943951023931953, 2.0943951023931953, 2.0943951023931953,
2.0943951023931953, 2.0943951023931953
],
'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'OH', 'CE2', 'CD2'],
'bonds-names': [
'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-OH', 'CZ-CE2', 'CE2-CD2'
],
'bonds-types': [
'CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-C ', 'C -OH', 'C -CA', 'CA-CA'
],
'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.409, 1.364, 1.409, 1.4],
'torsion-names': [
'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
'CD1-CE1-CZ-OH', 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
],
'torsion-types': [
'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-C ',
'CA-CA-C -OH', 'CA-CA-C -CA', 'CA-C -CA-CA'
],
'torsion-vals': [
'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 0.0, 0.0
]
},
'V': {
'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CA-CB-CG2'],
'angles-types': ['N -CX-3C', 'CX-3C-CT', 'CX-3C-CT'],
'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
'atom-names': ['CB', 'CG1', 'CG2'],
'bonds-names': ['CA-CB', 'CB-CG1', 'CB-CG2'],
'bonds-types': ['CX-3C', '3C-CT', '3C-CT'],
'bonds-vals': [1.526, 1.526, 1.526],
'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'N-CA-CB-CG2'],
'torsion-types': ['C -N -CX-3C', 'N -CX-3C-CT', 'N -CX-3C-CT'],
'torsion-vals': ['p', 'p', 'p']
},
'_': {
'angles-names': [],
'angles-types': [],
'angles-vals': [],
'atom-names': [],
'bonds-names': [],
'bonds-types': [],
'bonds-vals': [],
'torsion-names': [],
'torsion-types': [],
'torsion-vals': []
}
# 闭合 BB_BUILD_INFO 字典
}
# 定义 BB_BUILD_INFO 字典,包含键值对表示不同键的键值
BB_BUILD_INFO = {
"BONDLENS": {
# 更新的值是根据来自 1DPE_1_A 的晶体数据进行的验证
# 注释的值是 sidechainnet 的值
'n-ca': 1.4664931, # 1.442,
'ca-c': 1.524119, # 1.498,
'c-n': 1.3289373, # 1.379,
'c-o': 1.229, # 来自 parm10.dat || 根据结构有很大的变化
# 我们从 1DPE_1_A 得到 1.3389416,但也从 2F2H_d2f2hf1 得到 1.2289
'c-oh': 1.364
}, # 对于 OXT 来自 parm10.dat
# 用于放置氧原子
"BONDANGS": {
'ca-c-o': 2.0944, # 近似为 2pi / 3; parm10.dat 表示为 2.0350539
'ca-c-oh': 2.0944
}, # 等同于 'ca-c-o',对于 OXT
"BONDTORSIONS": {
'n-ca-c-n': -0.785398163
} # 一个简单的近似,不打算精确
}
# 定义函数 make_cloud_mask,返回一个包含相关点为 1,填充点为 0 的数组
def make_cloud_mask(aa):
""" relevent points will be 1. paddings will be 0. """
mask = np.zeros(14)
if aa != "_":
n_atoms = 4+len( SC_BUILD_INFO[aa]["atom-names"] )
mask[:n_atoms] = 1
return mask
# 定义函数 make_bond_mask,返回一个包含每个原子起始键长的数组
def make_bond_mask(aa):
""" Gives the length of the bond originating each atom. """
mask = np.zeros(14)
# backbone
mask[0] = BB_BUILD_INFO["BONDLENS"]['c-n']
mask[1] = BB_BUILD_INFO["BONDLENS"]['n-ca']
mask[2] = BB_BUILD_INFO["BONDLENS"]['ca-c']
mask[3] = BB_BUILD_INFO["BONDLENS"]['c-o']
# sidechain - except padding token
if aa in SC_BUILD_INFO.keys():
for i,bond in enumerate(SC_BUILD_INFO[aa]['bonds-vals']):
mask[4+i] = bond
return mask
# 定义函数 make_theta_mask,返回一个包含每个原子起始键角度的数组
def make_theta_mask(aa):
""" Gives the theta of the bond originating each atom. """
mask = np.zeros(14)
# backbone
#
# sidechain
for i,theta in enumerate(SC_BUILD_INFO[aa]['angles-vals']):
mask[4+i] = theta
return mask
# 定义函数 make_torsion_mask,返回一个包含每个原子起始键二面角的数组
def make_torsion_mask(aa):
""" Gives the dihedral of the bond originating each atom. """
mask = np.zeros(14)
# backbone
#
# sidechain
for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-vals']):
if torsion == 'p':
mask[4+i] = np.nan
elif torsion == "i":
# https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L372
mask[4+i] = 999 # anotate to change later # mask[4+i-1] - np.pi
else:
mask[4+i] = torsion
return mask
# 定义函数 make_idx_mask,返回一个包含前三个点的索引的数组
def make_idx_mask(aa):
""" Gives the idxs of the 3 previous points. """
mask = np.zeros((11, 3))
# backbone
mask[0, :] = np.arange(3)
# sidechain
mapper = {"N": 0, "CA": 1, "C":2, "CB": 4}
for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-names']):
# 获取形成二面角的所有原子
torsions = [x.rstrip(" ") for x in torsion.split("-")]
# 对于每个原子
for n, torsion in enumerate(torsions[:-1]):
# 获取坐标数组中原子的索引
loc = mapper[torsion] if torsion in mapper.keys() else 4 + SC_BUILD_INFO[aa]['atom-names'].index(torsion)
# 设置位置为索引
mask[i+1][n] = loc
return mask
# 定义 SUPREME_INFO 字典,包含各种信息的字典
SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k),
"bond_mask": make_bond_mask(k),
"theta_mask": make_theta_mask(k),
"torsion_mask": make_torsion_mask(k),
"idx_mask": make_idx_mask(k),
}
for k in "ARNDCQEGHILKMFPSTWYV_"}
# 定义函数 scn_cloud_mask,获取原子位置的布尔掩码
def scn_cloud_mask(seq, coords=None):
""" Gets the boolean mask atom positions (not all aas have same atoms).
Inputs:
* seqs: (length) iterable of 1-letter aa codes of a protein
* coords: optional .(batch, lc, 3). sidechainnet coords.
returns the true mask (solves potential atoms that might not be provided)
Outputs: (length, 14) boolean mask
"""
# 如果坐标不为空
if coords is not None:
# 重新排列坐标张量的维度,将最后一维拆分为两个维度
# 检查是否等于0,然后按最后一个维度求和
# 检查是否小于坐标张量的最后一个维度的长度
# 转换为浮点数并移动到 CPU 上
return ((rearrange(coords, '... (l c) d -> ... l c d', c=14) == 0).sum(dim=-1) < coords.shape[-1]).float().cpu()
# 如果坐标为空
# 返回一个张量,其中包含序列中每个氨基酸的云掩码信息
return torch.tensor([SUPREME_INFO[aa]['cloud_mask'] for aa in seq])
# 定义函数,根据氨基酸序列生成键长掩码
def scn_bond_mask(seq):
""" Inputs:
* seqs: (length). iterable of 1-letter aa codes of a protein
Outputs: (L, 14) maps point to bond length
"""
# 返回键长掩码的张量
return torch.tensor([SUPREME_INFO[aa]['bond_mask'] for aa in seq])
# 定义函数,根据氨基酸序列和角度生成角度掩码
def scn_angle_mask(seq, angles):
""" Inputs:
* seq: (length). iterable of 1-letter aa codes of a protein
* angles: (length, 12). [phi, psi, omega, b_angle(n_ca_c), b_angle(ca_c_n), b_angle(c_n_ca), 6_scn_torsions]
Outputs: (L, 14) maps point to theta and dihedral.
first angle is theta, second is dihedral
"""
# 获取设备和精度
device, precise = angles.device, angles.type()
angles = angles
# 获取角度掩码
theta_mask = torch.tensor([SUPREME_INFO[aa]['theta_mask'] for aa in seq]).type(precise)
torsion_mask = torch.tensor([SUPREME_INFO[aa]['torsion_mask'] for aa in seq]).type(precise)
# 填充掩码与角度值
theta_mask[:, 0] = angles[:, 4] # ca_c_n
theta_mask[1:, 1] = angles[:-1, 5] # c_n_ca
theta_mask[:, 2] = angles[:, 3] # n_ca_c
theta_mask[:, 3] = BB_BUILD_INFO["BONDANGS"]["ca-c-o"]
torsion_mask[:, 0] = angles[:, 1] # n determined by psi of previous
torsion_mask[1:, 1] = angles[:-1, 2] # ca determined by omega of previous
torsion_mask[:, 2] = angles[:, 0] # c determined by phi
torsion_mask[:, 3] = angles[:, 1] - np.pi
torsion_mask[-1, 3] += np.pi
to_fill = torsion_mask != torsion_mask
to_pick = torsion_mask == 999
for i in range(len(seq)):
number = to_fill[i].long().sum()
torsion_mask[i, to_fill[i]] = angles[i, 6:6+number]
for j, val in enumerate(to_pick[i]):
if val:
torsion_mask[i, j] = torsion_mask[i, j-1] - np.pi
return torch.stack([theta_mask, torsion_mask], dim=0).to(device)
# 定义函数,根据氨基酸序列生成索引掩码
def scn_index_mask(seq):
""" Inputs:
* seq: (length). iterable of 1-letter aa codes of a protein
Outputs: (L, 11, 3) maps point to theta and dihedral.
first angle is theta, second is dihedral
"""
# 获取索引掩码
idxs = torch.tensor([SUPREME_INFO[aa]['idx_mask'] for aa in seq])
return rearrange(idxs, 'l s d -> d l s')
# 定义函数,根据氨基酸序列和角度生成蛋白质骨架
def build_scaffolds_from_scn_angles(seq, angles, coords=None, device="auto"):
""" Builds scaffolds for fast access to data
Inputs:
* seq: string of aas (1 letter code)
* angles: (L, 12) tensor containing the internal angles.
Distributed as follows (following sidechainnet convention):
* (L, 3) for torsion angles
* (L, 3) bond angles
* (L, 6) sidechain angles
* coords: (L, 3) sidechainnet coords. builds the mask with those instead
(better accuracy if modified residues present).
Outputs:
* cloud_mask: (L, 14 ) mask of points that should be converted to coords
* point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
previous 3 points in the coords array
* angles_mask: (2, L, 14) maps point to theta and dihedral
* bond_mask: (L, 14) gives the length of the bond originating that atom
"""
precise = angles.type()
if device == "auto":
device = angles.device
if coords is not None:
cloud_mask = scn_cloud_mask(seq, coords=coords)
else:
cloud_mask = scn_cloud_mask(seq)
cloud_mask = torch.tensor(cloud_mask).bool().to(device)
# 生成点云索引掩码,将其转换为长整型张量,并移动到指定设备上
point_ref_mask = torch.tensor(scn_index_mask(seq)).long().to(device)
# 生成角度掩码,将其转换为指定精度类型的张量,并移动到指定设备上
angles_mask = torch.tensor(scn_angle_mask(seq, angles)).type(precise).to(device)
# 生成键合掩码,将其转换为指定精度类型的张量,并移动到指定设备上
bond_mask = torch.tensor(scn_bond_mask(seq)).type(precise).to(device)
# 将所有结果以字典形式返回
return {"cloud_mask": cloud_mask,
"point_ref_mask": point_ref_mask,
"angles_mask": angles_mask,
"bond_mask": bond_mask }
#############################
####### ENCODERS ############
#############################
# 修改蛋白质支架的坐标信息
def modify_scaffolds_with_coords(scaffolds, coords):
""" Gets scaffolds and fills in the right data.
Inputs:
* scaffolds: dict. as returned by `build_scaffolds_from_scn_angles`
* coords: (L, 14, 3). sidechainnet tensor. same device as scaffolds
Outputs: corrected scaffolds
"""
# 计算距离并更新:
# N, CA, C
scaffolds["bond_mask"][1:, 0] = torch.norm(coords[1:, 0] - coords[:-1, 2], dim=-1) # N
scaffolds["bond_mask"][:, 1] = torch.norm(coords[:, 1] - coords[:, 0], dim=-1) # CA
scaffolds["bond_mask"][:, 2] = torch.norm(coords[:, 2] - coords[:, 1], dim=-1) # C
# O, CB, 侧链
selector = np.arange(len(coords))
for i in range(3, 14):
# 获取索引
idx_a, idx_b, idx_c = scaffolds["point_ref_mask"][:, :, i-3] # (3, L, 11) -> 3 * (L, 11)
# 修正距离
scaffolds["bond_mask"][:, i] = torch.norm(coords[:, i] - coords[selector, idx_c], dim=-1)
# 获取角度
scaffolds["angles_mask"][0, :, i] = get_angle(coords[selector, idx_b],
coords[selector, idx_c],
coords[:, i])
# 处理 C-beta,其中请求的 C 来自前一个氨基酸
if i == 4:
# 对于第一个氨基酸,使用第二个氨基酸的 N 位置
first_next_n = coords[1, :1] # 1, 3
# 请求的 C 来自前一个氨基酸
main_c_prev_idxs = coords[selector[:-1], idx_a[1:]] # (L-1), 3
# 连接
coords_a = torch.cat([first_next_n, main_c_prev_idxs])
else:
coords_a = coords[selector, idx_a]
# 获取二面角
scaffolds["angles_mask"][1, :, i] = get_dihedral(coords_a,
coords[selector, idx_b],
coords[selector, idx_c],
coords[:, i])
# 为主链修正角度和二面角
scaffolds["angles_mask"][0, :-1, 0] = get_angle(coords[:-1, 1], coords[:-1, 2], coords[1:, 0]) # ca_c_n
scaffolds["angles_mask"][0, 1:, 1] = get_angle(coords[:-1, 2], coords[1:, 0], coords[1:, 1]) # c_n_ca
scaffolds["angles_mask"][0, :, 2] = get_angle(coords[:, 0], coords[:, 1], coords[:, 2]) # n_ca_c
# N 由前一个 psi 决定 = f(n, ca, c, n+1)
scaffolds["angles_mask"][1, :-1, 0] = get_dihedral(coords[:-1, 0], coords[:-1, 1], coords[:-1, 2], coords[1:, 0])
# CA 由 omega 决定 = f(ca, c, n+1, ca+1)
scaffolds["angles_mask"][1, 1:, 1] = get_dihedral(coords[:-1, 1], coords[:-1, 2], coords[1:, 0], coords[1:, 1])
# C 由 phi 决定 = f(c-1, n, ca, c)
scaffolds["angles_mask"][1, 1:, 2] = get_dihedral(coords[:-1, 2], coords[1:, 0], coords[1:, 1], coords[1:, 2])
return scaffolds
if __name__ == "__main__":
print(scn_cloud_mask("AAAA"))


浙公网安备 33010602011771号