Lucidrains-系列项目源码解析-一-
Lucidrains 系列项目源码解析(一)
.\lucidrains\Adan-pytorch\adan_pytorch\adan.py
import math
import torch
from torch.optim import Optimizer
# Helper: report whether a value was actually supplied (i.e. is not None).
def exists(val):
    """Return ``True`` for every value except ``None``."""
    return not (val is None)
# Adan optimizer, implemented as a subclass of torch.optim.Optimizer.
# NOTE(review): this listing has lost all indentation, so the nesting below is
# shown flat. In particular it cannot be determined from this text whether the
# "main algorithm" moment updates sit inside `if step > 0:` or run on every
# step, nor whether the second `grad_step_` call belongs to the restart branch.
# Confirm against the upstream file before re-indenting.
class Adan(Optimizer):
# Constructor: accepts the hyper-parameters and supplies their default values
def __init__(
self,
params,
lr = 1e-3,
betas = (0.02, 0.08, 0.01),
eps = 1e-8,
weight_decay = 0,
restart_cond: callable = None
):
# exactly three coefficients are required; `step` unpacks them as beta1, beta2, beta3
assert len(betas) == 3
# collect the hyper-parameters into the per-group defaults dict
defaults = dict(
lr = lr,
betas = betas,
eps = eps,
weight_decay = weight_decay,
restart_cond = restart_cond
)
# hand the parameters and defaults to the base Optimizer constructor
super().__init__(params, defaults)
# Perform one optimisation step over every parameter that has a gradient
def step(self, closure = None):
loss = None
# if a closure is supplied, call it; its result is what `step` returns
if exists(closure):
loss = closure()
# iterate over the parameter groups, reading each group's hyper-parameters
for group in self.param_groups:
lr = group['lr']
beta1, beta2, beta3 = group['betas']
weight_decay = group['weight_decay']
eps = group['eps']
restart_cond = group['restart_cond']
# iterate over the parameters of the group, skipping those without a gradient
for p in group['params']:
if not exists(p.grad):
continue
data, grad = p.data, p.grad.data
# sparse gradients are rejected
assert not grad.is_sparse
state = self.state[p]
# lazily initialise the per-parameter state: a step counter plus zero
# buffers for the previous gradient and the three running averages m, v, n
if len(state) == 0:
state['step'] = 0
state['prev_grad'] = torch.zeros_like(grad)
state['m'] = torch.zeros_like(grad)
state['v'] = torch.zeros_like(grad)
state['n'] = torch.zeros_like(grad)
step, m, v, n, prev_grad = state['step'], state['m'], state['v'], state['n'], state['prev_grad']
# NOTE(review): this re-read of prev_grad is redundant with the line above
if step > 0:
prev_grad = state['prev_grad']
# main algorithm: in-place exponential moving averages
# m <- (1 - beta1) * m + beta1 * grad
m.mul_(1 - beta1).add_(grad, alpha = beta1)
grad_diff = grad - prev_grad
# v <- (1 - beta2) * v + beta2 * (grad - prev_grad)
v.mul_(1 - beta2).add_(grad_diff, alpha = beta2)
next_n = (grad + (1 - beta2) * grad_diff) ** 2
# n <- (1 - beta3) * n + beta3 * (grad + (1 - beta2) * grad_diff) ** 2
n.mul_(1 - beta3).add_(next_n, alpha = beta3)
# bias-correction factors 1 / (1 - (1 - beta) ** step), one per beta
step += 1
correct_m, correct_v, correct_n = map(lambda n: 1 / (1 - (1 - n) ** step), (beta1, beta2, beta3))
# gradient step: updates `data` in place as
# data <- (data - lr / (sqrt(n * correct_n) + eps) * (m * correct_m + (1 - beta2) * v * correct_v)) / (1 + weight_decay * lr)
def grad_step_(data, m, v, n):
weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)
denom = 1 + weight_decay * lr
data.addcmul_(weighted_step_size, (m * correct_m + (1 - beta2) * v * correct_v), value = -1.).div_(denom)
grad_step_(data, m, v, n)
# restart condition: when restart_cond(state) is truthy, reset m to grad,
# v to zero and n to grad ** 2, then apply grad_step_ again
if exists(restart_cond) and restart_cond(state):
m.data.copy_(grad)
v.zero_()
n.data.copy_(grad ** 2)
grad_step_(data, m, v, n)
# remember this gradient and the incremented step count for the next call
prev_grad.copy_(grad)
state['step'] = step
return loss
.\lucidrains\Adan-pytorch\adan_pytorch\__init__.py
# 从 adan_pytorch.adan 模块中导入 Adan 类
from adan_pytorch.adan import Adan
Adan - Pytorch
Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch.
Explanation from Davis Blalock
Install
$ pip install adan-pytorch
Usage
from adan_pytorch import Adan
# mock model
import torch
from torch import nn
model = torch.nn.Sequential(
nn.Linear(16, 16),
nn.GELU()
)
# instantiate Adan with model parameters
optim = Adan(
model.parameters(),
lr = 1e-3, # learning rate (can be much higher than Adam, up to 5-10x)
betas = (0.02, 0.08, 0.01), # beta 1-2-3 as described in paper - author says most sensitive to beta3 tuning
weight_decay = 0.02 # weight decay 0.02 is optimal per author
)
# train
for _ in range(10):
loss = model(torch.randn(16)).sum()
loss.backward()
optim.step()
optim.zero_grad()
Citations
@article{Xie2022AdanAN,
title = {Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
author = {Xingyu Xie and Pan Zhou and Huan Li and Zhouchen Lin and Shuicheng Yan},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.06677}
}
.\lucidrains\Adan-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata for the adan-pytorch distribution
setup(
    name = 'adan-pytorch',
    packages = find_packages(exclude=[]),  # include every discovered package
    version = '0.1.0',
    license='MIT',
    description = 'Adan - (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/Adan-pytorch',
    # search keywords
    keywords = [
        'artificial intelligence',
        'deep learning',
        'optimizer',
    ],
    # runtime dependencies
    install_requires=[
        'torch>=1.6',
    ],
    # trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\adjacent-attention-network\adjacent_attention_network\adjacent_attention_network.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from isab_pytorch import ISAB
# helpers
# Helper: tell whether a value is present (anything other than None).
def exists(val):
    """Return ``True`` for every value except ``None``."""
    return not (val is None)
# Helper: per-batch row selection from `values` according to `indices`.
def batched_index_select(values, indices):
    """Gather entries of ``values`` along dim 1 at the positions in ``indices``.

    ``indices`` gets a trailing axis added at position 2 and is expanded to
    the size of the last dimension of ``values``, so every selected position
    brings its whole feature vector with it.
    """
    feature_dim = values.size(-1)
    gather_index = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return torch.gather(values, 1, gather_index)
# helper classes
# Residual (skip) connection wrapper
class Residual(nn.Module):
    """Call the wrapped ``fn`` and add its input back onto the result."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # keyword arguments are forwarded to the wrapped callable untouched
        branch = self.fn(x, **kwargs)
        return branch + x
# Pre-normalisation wrapper
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before calling the wrapped ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # only the first positional input is normalised; kwargs pass straight through
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# Position-wise feed-forward network
class FeedForward(nn.Module):
    """Linear -> GELU -> Dropout -> Linear, widening by ``mult`` in the middle."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        return self.net(x)
# adjacent attention class
# Adjacent attention: each node attends to its gathered neighbours plus a learned null slot
class AdjacentAttention(nn.Module):
    """Multi-head attention restricted to the neighbours listed in ``adj_kv_indices``.

    A learned null key / value pair is prepended to every node's neighbour
    list so that a node is able to attend to nothing.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # one learned null key and null value per head
        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        adj_kv_indices,
        mask
    ):
        # x must be 3-dimensional; the einops patterns below label it (b, n, dim)
        batch, num_nodes, _ = x.shape
        heads = self.heads

        # neighbour indices flattened per (batch, head) so they can drive a gather
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = heads)

        # derive queries, keys and values
        queries, keys, values = self.to_qkv(x).chunk(3, dim = -1)
        queries = rearrange(queries, 'b n (h d) -> b h n d', h = heads)

        def gather_neighbors(t):
            # select each node's neighbour rows, then restore the per-node layout
            t = rearrange(t, 'b n (h d) -> (b h) n d', h = heads)
            t = batched_index_select(t, flat_indices)
            return rearrange(t, '(b h) (n a) d -> b h n a d', h = heads, n = num_nodes)

        keys = gather_neighbors(keys)
        values = gather_neighbors(values)

        def expand_null(t):
            # broadcast the per-head null vector to one slot per node
            return rearrange(t, 'h d -> () h () () d').expand(batch, -1, num_nodes, 1, -1)

        # prepend the null key / value and mark that slot as always valid
        keys = torch.cat((expand_null(self.null_k), keys), dim = -2)
        values = torch.cat((expand_null(self.null_v), values), dim = -2)
        mask = F.pad(mask, (1, 0), value = 1)

        # similarity of each node to its neighbours
        sim = einsum('b h n d, b h n a d -> b h n a', queries, keys) * self.scale

        # neighbours that are only padding get the most negative finite value
        fill_value = -torch.finfo(sim.dtype).max
        keep = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~keep, fill_value)

        # attention weights, with dropout
        attn = self.dropout(sim.softmax(dim = -1))

        # weighted average of neighbour values, heads merged, then projected out
        out = einsum('b h n a, b h n a d -> b h n d', attn, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# adjacent network (layers of adjacent attention)
# Adjacent attention network: a stack of adjacent-attention / feed-forward layers
class AdjacentAttentionNetwork(nn.Module):
    """Stack of ``depth`` layers of neighbour-restricted attention.

    Each layer is: residual adjacent attention, an optional ISAB-based global
    attention (only built when ``num_global_nodes > 0``), and a residual
    feed-forward.

    Args:
        dim: feature dimension of the nodes.
        depth: number of layers.
        dim_head: per-head dimension of the adjacent attention.
        heads: number of attention heads.
        num_neighbors_cutoff: if set, nodes with more neighbours than this are
            limited to a random sample of that size.
        num_global_nodes: number of induced points given to ISAB.
        attn_dropout: dropout on the adjacent-attention weights.
        ff_dropout: dropout inside the feed-forward.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 4,
        num_neighbors_cutoff = None,
        num_global_nodes = 0,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, ISAB(
                dim = dim,
                heads = heads,
                num_induced_points = num_global_nodes
            )) if num_global_nodes > 0 else None

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, AdjacentAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        """Run the network.

        ``adjacency_mat`` is combined with the identity (and with ``mask`` when
        given) out of place, so the tensor passed by the caller is left
        unmodified and can be reused across calls.

        Returns the transformed node features ``x``.
        """
        device, n = x.device, x.shape[1]

        # nodes should pay attention to themselves (self-interacting);
        # out-of-place OR so the caller's adjacency matrix is not altered
        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat = adjacency_mat | diag

        # zero out entries of the adjacency matrix where either node is padding
        if exists(mask):
            adjacency_mat = adjacency_mat & (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()

        # if there is no hard limit on the number of neighbours:
        #   - take the maximum neighbour count and pad nodes that have fewer
        # else:
        #   - randomly sample `num_neighbors_cutoff` neighbours for nodes that exceed it
        #   - this is similar to random sparse attention (bigbird)

        # maximum number of neighbours of any node
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # small uniform noise makes topk pick a random subset of the true neighbours
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise

            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)

            # cast the noisy scores back to a 0 / 1 mask
            adj_mask = (adj_mask > 0.5).float()
        else:
            # todo - get distribution of number of neighbors, and strategically break up attention (message passing) to multiple steps
            #      - start with a bimodal num neighbors test case, then generalize

            # topk returns all neighbours; the returned values double as the mask,
            # since some selected slots are padding rather than real neighbours
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        for attn, global_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )

            if exists(global_attn):
                out, _ = global_attn(x, mask = mask)
                x = x + out

            x = ff(x)

        return x
.\lucidrains\adjacent-attention-network\adjacent_attention_network\__init__.py
# 从相邻注意力网络模块中导入相邻注意力网络类
from adjacent_attention_network.adjacent_attention_network import AdjacentAttentionNetwork
Adjacent Attention Network
An implementation of a simple transformer that is equivalent to a graph neural network where the message passing is done with multi-head attention at each successive layer. Since Graph Attention Network is already taken, I decided to name it Adjacent Attention Network instead. The design will be more transformer-centric. Instead of using the square root inverse adjacency matrix trick by Kipf and Welling, in this framework the adjacency matrix will simply be translated to the proper attention mask at each layer.
This repository is for my own exploration into the graph neural network field. My gut tells me the transformers architecture can generalize and outperform graph neural networks.
Install
$ pip install adjacent-attention-network
Usage
Basically a transformer where each node pays attention to its neighbors as defined by the adjacency matrix. Complexity is O(n * max_neighbors), where max_neighbors is the maximum number of neighbors of any node in the adjacency matrix.
The following example will have a complexity of ~ 1024 * 100
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4
)
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1) < 0.1
nodes = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
If the number of neighbors contains outliers, then the above will lead to wasteful computation, since a lot of nodes will be doing attention on padding. You can use the following stop-gap measure to account for these outliers.
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4,
num_neighbors_cutoff = 100
).cuda()
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()
# for some reason, one of the nodes is fully connected to all others
adj_mat[:, 0] = 1.
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
For non-local attention, I've decided to use a trick from the Set Transformers paper, the Induced Set Attention Block (ISAB). From the lens of graph neural net literature, this would be analogous to having global nodes for message passing non-locally.
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4,
num_global_nodes = 5
).cuda()
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
.\lucidrains\adjacent-attention-network\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# Package metadata
setup(
    name = 'adjacent-attention-pytorch',
    packages = find_packages(),  # include every discovered package
    version = '0.0.12',
    license='MIT',
    description = 'Adjacent Attention Network - Pytorch',
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/adjacent-attention-pytorch',
    # search keywords
    keywords = [
        'artificial intelligence',
        'attention mechanism',
        'graph neural network',
        'transformers'
    ],
    # runtime dependencies
    install_requires=[
        'einops>=0.3',
        'torch>=1.6',
        'isab-pytorch<0.2'
    ],
    # trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\agent_attention_pytorch.py
# 导入 torch 库
import torch
# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module
# 从 torch 模块中导入 nn、einsum、Tensor
from torch import nn, einsum, Tensor
# 从 einops 库中导入 rearrange、repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 中导入 Rearrange 类
# 定义函数
# Helper: tell whether a value is present (anything other than None).
def exists(v):
    """Return ``True`` for every value except ``None``."""
    return not (v is None)
# 主要类
# Agent self-attention with learned per-head agent tokens
class AgentSelfAttention(Module):
    """Self-attention routed through a small set of "agent" tokens.

    The agent tokens first attend to the keys and gather values; the queries
    then attend to the agents to read those gathered values back.

    Args:
        dim: input / output feature dimension.
        num_agent_tokens: number of learned agent tokens per head.
        dim_head: per-head dimension.
        heads: number of heads.
        dropout: dropout applied to both attention maps.
        talking_heads: mix attention maps across heads with a 1x1 ``Conv2d``.
        gate: multiply the output by a per-head sigmoid gate computed from ``x``.
        combine_agent_tokens: accepted but not used anywhere in this class.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        combine_agent_tokens = False
    ):
        super().__init__()
        # `Rearrange` is not imported at module level in this file, so it is
        # imported here to avoid a NameError when the layers below are built
        from einops.layers.torch import Rearrange

        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        # project the input to queries, keys and values, split per head
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # optional per-head output gates
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        # learned agent tokens, initialised from a normal with std 0.02
        self.agent_tokens = nn.Parameter(torch.zeros(heads, num_agent_tokens, dim_head))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        # talking heads for the query->agent and agent->key attention maps
        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        # dropout for the two attention maps
        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        # merge heads and project back to `dim`
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        agent_tokens = None,
        return_agent_tokens = False
    ):
        """Apply agent attention to ``x``.

        Args:
            x: input; the einops patterns label it ``(b, n, dim)``.
            mask: optional boolean mask labelled ``(b, n)``; False positions are
                excluded as keys and zeroed in the output.
            agent_tokens: optional agent tokens to use instead of the learned ones.
            return_agent_tokens: also return the values gathered by the agents.

        Returns:
            ``out``, or ``(out, agent_gathered_tokens)`` when requested.
        """
        batch = x.shape[0]

        q, k, v = self.to_qkv(x)

        if exists(agent_tokens):
            a = agent_tokens
        else:
            a = repeat(self.agent_tokens, 'h m d -> b h m d', b = batch)

        # NOTE(review): the source listing lost its indentation; the scaling is
        # taken to apply to supplied and learned agent tokens alike - confirm upstream
        a = a * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, a)
        ak_sim = einsum('b h i d, b h j d -> b h i j', a, k)

        if exists(mask):
            # masked-out keys get the most negative finite value before softmax
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        # agents gather values, then queries read from the agents
        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)

        out = self.to_out(out)

        if not return_agent_tokens:
            return out

        return out, agent_gathered_tokens
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\agent_transformer.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.nn 中导入 Module 和 ModuleList
from torch.nn import Module, ModuleList
# 从 torch 中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 einops 中导入 rearrange, repeat, pack, unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 中导入 Rearrange
# 定义函数
# Helper: tell whether a value is present (anything other than None).
def exists(v):
    """Return ``True`` for every value except ``None``."""
    return not (v is None)
# 归一化函数
# RMS normalisation layer
class RMSNorm(Module):
    """L2-normalise the last dimension, then scale by ``sqrt(dim)`` and a learned gain."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # learned per-feature gain, initialised to one
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        rescaled = unit * self.scale
        return rescaled * self.gamma
# 前馈网络函数
# Feed-forward block factory
def FeedForward(dim, mult = 4):
    """Build RMSNorm -> Linear -> GELU -> Linear with hidden width ``int(dim * mult)``."""
    hidden_dim = int(dim * mult)
    stages = (
        RMSNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )
    return nn.Sequential(*stages)
# 主类
# Agent self-attention used inside the agent transformer
class AgentSelfAttention(Module):
    """Agent attention in which agent tokens are supplied by the caller.

    Inputs and agent tokens share one normalisation and one QKV projection.
    The agents attend to the input keys and gather values; the input queries
    then attend to the agent keys to read those gathered values back.

    Args:
        dim: input / output feature dimension.
        num_agent_tokens: accepted for interface parity; not used in this class.
        dim_head: per-head dimension.
        heads: number of heads.
        dropout: dropout applied to both attention maps.
        talking_heads: mix attention maps across heads with a 1x1 ``Conv2d``.
        gate: multiply both outputs by per-head sigmoid gates.
        sub_layernorm: apply ``LayerNorm(dim_head)`` before merging heads.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        sub_layernorm = False
    ):
        super().__init__()
        # `Rearrange` is not imported at module level in this file, so it is
        # imported here to avoid a NameError when the layers below are built
        from einops.layers.torch import Rearrange

        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        # shared projection to queries, keys and values, split per head
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # optional per-head output gates
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        # separate output projections for the agent tokens and for the inputs
        self.to_agent_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        *,
        agent_tokens,
        mask = None,
        return_agent_tokens = False
    ):
        """Apply agent attention.

        Args:
            x: input; the pack pattern labels it ``(b, *, d)``.
            agent_tokens: agent tokens, packed with ``x`` along the middle axis.
            mask: optional boolean mask labelled ``(b, n)`` over the input positions.
            return_agent_tokens: also return the projected agent output.

        Returns:
            ``out``, or ``(out, agent_out)`` when requested.
        """
        x = self.norm(x)
        a = self.norm(agent_tokens)

        # project agents and inputs together, then split them apart again
        x_and_agents, xa_ps = pack([a, x], 'b * d')
        qkv = self.to_qkv(x_and_agents)

        qkv_agent, qkv_input = unpack(qkv, xa_ps, 'qkv b h * d')

        q, k, v = qkv_input
        agent_queries, agent_keys, _ = qkv_agent

        q = q * self.scale
        agent_queries = agent_queries * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, agent_keys)
        ak_sim = einsum('b h i d, b h j d -> b h i j', agent_queries, k)

        if exists(mask):
            # masked-out keys get the most negative finite value before softmax
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        # agents gather values, then inputs read from the agents
        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)
            agent_out = agent_out * self.to_gates(a)

        out = self.to_out(out)
        agent_out = self.to_agent_out(agent_out)

        if not return_agent_tokens:
            return out

        return out, agent_out
# Transformer built from agent-attention layers
class AgentTransformer(Module):
    """Stack of ``depth`` (agent attention, feed-forward) layers sharing learned agent tokens.

    Extra keyword arguments are forwarded to every ``AgentSelfAttention``.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        final_norm = True,
        **attn_kwargs: dict
    ):
        super().__init__()

        # learned agent tokens, initialised from a normal with std 0.02
        self.agent_tokens = nn.Parameter(torch.zeros(num_agent_tokens, dim))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        def make_layer():
            attention = AgentSelfAttention(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                num_agent_tokens = num_agent_tokens,
                **attn_kwargs
            )
            feedforward = FeedForward(dim = dim, mult = ff_mult)
            return ModuleList([attention, feedforward])

        self.layers = ModuleList([make_layer() for _ in range(depth)])

        self.final_norm = RMSNorm(dim) if final_norm else None

    def forward(
        self,
        x,
        mask = None,
        return_agent_tokens = False
    ):
        """Run all layers; return ``x``, or ``(x, agent_tokens)`` when requested."""
        agents = repeat(self.agent_tokens, 'm d -> b m d', b = x.shape[0])

        for attention, feedforward in self.layers:
            attn_out, agent_out = attention(
                x,
                agent_tokens = agents,
                mask = mask,
                return_agent_tokens = True
            )

            # residual updates for both streams
            agents = agents + agent_out
            x = x + attn_out

            # run the feed-forward over agents and inputs in a single pass
            packed, packed_shapes = pack([agents, x], 'b * d')
            packed = feedforward(packed) + packed
            agents, x = unpack(packed, packed_shapes, 'b * d')

        if exists(self.final_norm):
            x = self.final_norm(x)
            agents = self.final_norm(agents)

        if return_agent_tokens:
            return x, agents

        return x
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\__init__.py
# 从 agent_attention_pytorch 包中导入 AgentSelfAttention 类
from agent_attention_pytorch.agent_attention_pytorch import (
AgentSelfAttention
)
# 从 agent_attention_pytorch 包中导入 AgentTransformer 类
from agent_attention_pytorch.agent_transformer import (
AgentTransformer
)
Agent Attention - Pytorch
Implementation of Agent Attention in Pytorch.
This work seems to be an elegant simplification of ISAB
architecture from the Set Transformers paper (requires only one attention block rather than two). While ISAB works, I have found it to be a bit unstable, thus wondering if the simplification in this work resolves that issue.
This repository will add support for variable sequence lengths (masking) and post-softmax talking heads.
Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install agent-attention-pytorch
Usage
import torch
from agent_attention_pytorch import AgentSelfAttention
attn = AgentSelfAttention(
dim = 512,
num_agent_tokens = 256, # number of "agent" tokens
dim_head = 64, # attention head dimension
heads = 8 # number of heads
)
x = torch.randn(2, 65536, 512)
mask = torch.ones(2, 65536).bool()
out = attn(x, mask = mask)
assert out.shape == x.shape
For a full fledged linear transformer based on agent tokens, just import AgentTransformer
import torch
from agent_attention_pytorch import AgentTransformer
transformer = AgentTransformer(
dim = 512,
depth = 6,
num_agent_tokens = 128,
dim_head = 64,
heads = 8
)
x = torch.randn(2, 65536, 512)
mask = torch.ones(2, 65536).bool()
out, agent_tokens = transformer(x, mask = mask, return_agent_tokens = True)
# (2, 65536, 512), (2, 128, 512)
assert out.shape == x.shape
Citations
@inproceedings{Han2023AgentAO,
title = {Agent Attention: On the Integration of Softmax and Linear Attention},
author = {Dongchen Han and Tianzhu Ye and Yizeng Han and Zhuofan Xia and Shiji Song and Gao Huang},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266210414}
}
@misc{shazeer2020talkingheads,
title = {Talking-Heads Attention},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@article{Wang2022FoundationT,
title = {Foundation Transformers},
author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.06423},
url = {https://api.semanticscholar.org/CorpusID:252846241}
}
.\lucidrains\agent-attention-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata
setup(
    name = 'agent-attention-pytorch',
    packages = find_packages(exclude=[]),  # include every discovered package
    version = '0.1.7',
    license='MIT',
    description = 'Agent Attention - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/agent-attention-pytorch',
    # search keywords
    keywords = [
        'artificial intelligence',
        'deep learning',
        'attention',
        'linear attention'
    ],
    # runtime dependencies
    install_requires=[
        'einops>=0.7.0',
        'torch>=2.0'
    ],
    # trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\all-normalization-transformer\all_normalization_transformer\all_normalization_transformer.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# Cumulative mean along the last dimension
def cum_mean(t):
    """Return the running mean of ``t`` along its last dimension.

    Element ``i`` of the result is the mean of elements ``0..i`` of ``t``.
    """
    # 1, 2, ..., n: how many elements each prefix sum covers
    counts = torch.arange(1, t.shape[-1] + 1, device = t.device)
    return t.cumsum(dim = -1) / counts
# Normalisation over the last dimension
def normalize(t, eps=1e-8):
    """Centre ``t`` over its last dimension and scale it to unit variance.

    The input is not modified: centring is done out of place, so the caller's
    tensor keeps its values and leaf tensors that require grad are accepted.

    Args:
        t: input tensor.
        eps: added to the variance before the reciprocal square root.

    Returns:
        A new tensor with the same shape as ``t``.
    """
    # subtract the mean
    centred = t - t.mean(dim=-1, keepdim=True)
    # mean of squares of the centred values, i.e. the (biased) variance
    variance = (centred ** 2).mean(dim=-1, keepdim=True)
    return centred * torch.rsqrt(variance + eps)
# Causal normalisation over the last two dimensions
def causal_normalize(t, eps=1e-8):
    """Normalise each row of ``t`` with statistics taken from its causal prefix.

    For row ``i`` of the last two dimensions, the mean and the mean of squares
    come from ``cum_mean(...)[..., i, i]``, i.e. from columns ``0..i`` of that
    row only. The input is not modified: centring is done out of place.

    Args:
        t: input tensor; its last two dimensions are used as (row, column).
        eps: added to the mean of squares before the reciprocal square root.

    Returns:
        A new tensor with the same shape as ``t``.
    """
    # subtract the causal mean (diagonal of the running mean, one value per row)
    causal_mean = cum_mean(t).diagonal(dim1=-2, dim2=-1)[..., None]
    centred = t - causal_mean
    # causal mean of squares of the centred values
    causal_var = cum_mean(centred ** 2).diagonal(dim1=-2, dim2=-1)[..., None]
    return centred * torch.rsqrt(causal_var + eps)
# Residual (skip) connection wrapper
class Residual(nn.Module):
    """Call the wrapped ``fn`` and add its first input back onto the result."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # all extra arguments are forwarded to the wrapped callable untouched
        branch = self.fn(x, *args, **kwargs)
        return branch + x
# Post-normalisation wrapper
class PostNorm(nn.Module):
    """Call the wrapped ``fn`` first, then apply ``nn.LayerNorm(dim)`` to its output."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        transformed = self.fn(x)
        return self.norm(transformed)
# Pre-normalisation wrapper
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call the wrapped ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)
# Position-wise feed-forward network
class FeedForward(nn.Module):
    """Linear -> GELU -> Linear with a hidden width of ``dim * mult``.

    Args:
        dim: input / output feature dimension.
        mult: expansion factor of the hidden layer. The default of 4 matches
            the width that was previously hard-coded.
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)
# Attention that normalises its scores instead of applying softmax
class Attention(nn.Module):
    """Multi-head attention whose scores are normalised and given a per-head affine.

    With ``shared_kv`` the keys double as the values, so only two projections
    are produced. With ``causal`` the strictly upper-triangular scores are
    zeroed before and after normalisation.
    """

    def __init__(self, dim, heads = 8, causal = False, shared_kv = False):
        super().__init__()
        self.causal = causal
        self.heads = heads
        self.scale = dim ** -0.5
        self.shared_kv = shared_kv
        self.num_qkv = 2 if shared_kv else 3

        self.to_qkv = nn.Linear(dim, dim * self.num_qkv, bias = False)
        self.to_out = nn.Linear(dim, dim)

        # per-head gain and bias applied to the normalised scores
        self.norm_g = nn.Parameter(torch.ones(1, heads, 1, 1))
        self.norm_b = nn.Parameter(torch.zeros(1, heads, 1, 1))

    def forward(self, x):
        # x must be 3-dimensional; the middle axis is the sequence length
        _, seq_len, _ = x.shape
        heads, device = self.heads, x.device

        projected = rearrange(
            self.to_qkv(x),
            'b n (qkv h d) -> qkv b h n d',
            qkv = self.num_qkv,
            h = heads
        )

        if self.shared_kv:
            q, k = projected
            v = k
        else:
            q, k, v = projected

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        causal_mask = None
        if self.causal:
            # True above the diagonal, i.e. at future positions
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device = device), diagonal = 1).bool()
            dots.masked_fill_(causal_mask, 0.)

        if self.causal:
            normed_attn = causal_normalize(dots)
        else:
            normed_attn = normalize(dots)

        attn = normed_attn * self.norm_g + self.norm_b

        if self.causal:
            # the affine can move masked entries off zero, so zero them again
            attn.masked_fill_(causal_mask, 0.)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer stack
class Transformer(nn.Module):
    """``depth`` layers of post-norm residual attention followed by a feed-forward.

    With ``only_norm`` the feed-forward of every layer is replaced by ``nn.Identity``.
    """

    def __init__(self, dim, depth, heads = 8, causal = False, only_norm = False, shared_kv = False):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = Residual(PostNorm(dim, Attention(dim, heads, causal = causal, shared_kv = shared_kv)))
            feedforward = nn.Identity() if only_norm else Residual(PreNorm(dim, FeedForward(dim)))
            blocks.append(nn.ModuleList([attention, feedforward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = feedforward(attention(x))
        return x
# Transformer language model
class TransformerLM(nn.Module):
    """Token + learned positional embeddings, a ``Transformer`` stack and a logit head."""

    def __init__(self, *, num_tokens, dim, depth, max_seq_len, heads = 8, causal = False, only_norm = False, shared_kv = False):
        super().__init__()
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.transformer = Transformer(dim, depth, heads, causal = causal, only_norm = only_norm, shared_kv = shared_kv)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, **kwargs):
        # x must be 2-dimensional; extra keyword arguments are accepted and ignored
        _, seq_len = x.shape

        tokens = self.token_emb(x)
        positions = self.pos_emb(torch.arange(seq_len, device = tokens.device))

        hidden = self.transformer(tokens + positions)
        return self.to_logits(hidden)
.\lucidrains\all-normalization-transformer\all_normalization_transformer\autoregressive_wrapper.py
# 导入必要的库
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# Helper: fall back to a default when no value was given
def default(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``."""
    if value is None:
        return default
    return value
# Helper: logarithm with a small offset
def log(t, eps=1e-9):
    """Return the natural log of ``t + eps``; the offset keeps ``log(0)`` finite."""
    shifted = t + eps
    return shifted.log()
# Nucleus filtering: keep the most probable logits until the cumulative probability passes the cut-off
def top_p(logits, thres = 0.9):
    """Set logits outside the nucleus to ``-inf``.

    Logits are sorted in descending order; a position is dropped when the
    cumulative probability of the positions before it already exceeds
    ``1 - thres``. The first (most probable) position is always kept.
    Expects 2-dimensional ``logits``.
    """
    sorted_logits, sorted_indices = logits.sort(descending = True)
    cum_probs = sorted_logits.softmax(dim = -1).cumsum(dim = -1)

    over_budget = cum_probs > 1.0 - thres

    # shift right by one so the position that crosses the cut-off survives,
    # and never drop the top position
    never_drop = torch.zeros_like(over_budget[:, :1])
    remove = torch.cat((never_drop, over_budget[:, :-1]), dim = -1)

    filtered = sorted_logits.masked_fill(remove, float('-inf'))

    # put the filtered logits back at their original positions
    return filtered.scatter(1, sorted_indices, filtered)
# Top-k filtering: keep only the k largest logits
def top_k(logits, thres = 0.9):
    """Keep the ``k`` largest logits per row and set the rest to ``-inf``.

    ``k`` is ``int((1 - thres) * vocab_size)``, clamped to at least 1. Without
    the clamp a small vocabulary or float rounding (for example
    ``int((1 - 0.9) * 10) == 0``) masks every logit, which leaves nothing to
    sample from.

    Args:
        logits: 2-dimensional tensor of logits.
        thres: fraction of the vocabulary to filter out.

    Returns:
        A tensor shaped like ``logits`` with the non-selected entries at ``-inf``.
    """
    k = max(1, int((1 - thres) * logits.shape[-1]))
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
# Autoregressive wrapper around a language model
class AutoregressiveWrapper(nn.Module):
    """Adds next-token training loss and sampling to a network.

    Args:
        net: the wrapped model; it must expose ``max_seq_len``.
        ignore_index: target value ignored by the loss; defaults to ``pad_value``.
        pad_value: padding token value.
    """

    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = net
        self.max_seq_len = net.max_seq_len

    # Sequence generation
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to ``seq_len`` tokens continuing ``start_tokens``.

        Args:
            start_tokens: 1-D or 2-D prompt; a 1-D prompt is given a batch axis
                and the result is squeezed back.
            seq_len: maximum number of tokens to generate.
            eos_token: stop early once every sequence in the batch sampled it.
            temperature: divisor applied to the filtered logits.
            filter_logits_fn: logit filter called as ``fn(logits, thres = ...)``.
            filter_thres: threshold passed to the filter.
            **kwargs: forwarded to the network; ``src_mask`` is popped and used
                as the initial mask.

        Returns:
            Only the newly generated tokens (the prompt is stripped).
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('src_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            # keep only the most recent max_seq_len positions of tokens and mask
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]

            logits = self.net(x, src_mask=input_mask, **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)

            # Gumbel-max sampling
            gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
            sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)

            out = torch.cat((out, sample[:, None]), dim=-1)
            # the new token is appended on the right, so its mask entry must be
            # appended on the right too; padding on the left would shift a
            # caller-supplied mask out of alignment with the tokens
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        # restore the training mode the network had on entry
        self.net.train(was_training)
        return out

    # Forward pass
    def forward(self, x, *args, **kwargs):
        """Return the next-token cross-entropy loss for the sequences in ``x``.

        The network sees ``x[:, :-1]`` and is trained to predict ``x[:, 1:]``.
        An ``input_mask`` keyword, if given, must match ``x`` in its first two
        dimensions and is truncated the same way before being forwarded.
        """
        m = kwargs.pop('input_mask', None)
        xi, xo = x[:, :-1], x[:, 1:]

        if m is not None:
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            kwargs.update(input_mask = m[:, :-1])

        out = self.net(xi, *args, **kwargs)

        # cross_entropy expects the class dimension second
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
.\lucidrains\all-normalization-transformer\all_normalization_transformer\__init__.py
# 从 all_normalization_transformer 包中导入 TransformerLM 类
from all_normalization_transformer.all_normalization_transformer import TransformerLM
# 从 all_normalization_transformer 包中导入 AutoregressiveWrapper 类
from all_normalization_transformer.autoregressive_wrapper import AutoregressiveWrapper
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Transformer with Normalized Attention
A Transformer that consists of only normalization as its sole non-linearity, as proposed in the paper Normalized Attention Without Probability Cage. This repository will build on the paper's contributions and attempt to make it work for the auto-regressive case.
Update - It works. You can have an entire language model built on only matrix multiplies and normalization.
Pre-requisites
$ pip install -r requirements.txt
Train
$ python train_enwik8.py
Citations
@misc{richter2020normalized,
title={Normalized Attention Without Probability Cage},
author={Oliver Richter and Roger Wattenhofer},
year={2020},
eprint={2005.09561},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
.\lucidrains\all-normalization-transformer\train_enwik8.py
# 导入所需的模块
from all_normalization_transformer import TransformerLM
from all_normalization_transformer.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Training hyperparameters and schedule constants
NUM_BATCHES = int(1e5)  # iterations of the outer training loop
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # backward passes accumulated per optimizer step
LEARNING_RATE = 3e-4
VALIDATE_EVERY = 100  # print a validation loss every N iterations
GENERATE_EVERY = 500  # sample text every N iterations
GENERATE_LENGTH = 512  # number of tokens requested from generate()
SEQ_LEN = 512  # used as the model max_seq_len and the dataset window length
# 定义辅助函数
# 从 token 解码为字符
def decode_token(token):
    """Map a token id to a character, clamping ids below 32 to a space."""
    code_point = token if token > 32 else 32
    return chr(code_point)
# 从 tokens 解码为字符串
def decode_tokens(tokens):
    """Decode an iterable of token ids into a single string."""
    return ''.join(decode_token(tok) for tok in tokens)
# Instantiate the model
# Build the TransformerLM; 256 token ids matches the uint8 data loaded below
model = TransformerLM(
num_tokens = 256,
dim = 512,
depth = 12,
max_seq_len = SEQ_LEN,
heads = 8,
causal = True,
only_norm = True,
shared_kv = True
)
# Wrap the model in AutoregressiveWrapper (adds the loss and generate())
model = AutoregressiveWrapper(model)
# Move the model to the GPU
model.cuda()
# Prepare the enwik8 data
# Read the first 95e6 bytes of the gzip archive: 90e6 for training, the rest for validation
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary input; frombuffer yields a
    # read-only view, so copy it to get a writable array for torch.from_numpy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义自定义数据集类
class TextSamplerDataset(Dataset):
    """Dataset that serves random fixed-length windows from a 1-D token tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # `index` is ignored: every access draws a fresh random window
        upper = self.data.size(0) - self.seq_len - 1
        start = torch.randint(0, upper, (1,))
        stop = start + self.seq_len + 1
        # seq_len + 1 tokens, so input and shifted target can both be cut from it
        window = self.data[start:stop].long()
        return window.cuda()

    def __len__(self):
        total = self.data.size(0)
        return total // self.seq_len
# Build the training and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)


# `cycle` is used below but is neither defined nor imported in this script
def cycle(loader):
    """Yield batches from `loader` forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data


# Build endlessly repeating loaders for the training and validation sets
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# Define the optimizer
# NOTE: this rebinds the name `optim`, shadowing the `torch.optim as optim` import;
# the name is kept because the training loop refers to it
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Train the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several batches before one optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # reports the loss of the last accumulated batch only
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:SEQ_LEN]
        prime = decode_tokens(inp)
        # the original passed the tuple as a second print argument, so the
        # literal format string was printed; apply the formatting instead
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
.\lucidrains\alphafold2\alphafold2_pytorch\alphafold2.py
import torch
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from inspect import isfunction
from functools import partial
from dataclasses import dataclass
import torch.nn.functional as F
from math import sqrt
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from alphafold2_pytorch.utils import *
import alphafold2_pytorch.constants as constants
from alphafold2_pytorch.mlm import MLM
# structure module
from invariant_point_attention import IPABlock
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
# constants
# Bundle of three tensors; judging by the `recyclables` / `return_recyclables`
# arguments of Alphafold2.forward it is passed between forward passes — TODO confirm.
@dataclass
class Recyclables:
coords: torch.Tensor
single_msa_repr_row: torch.Tensor
pairwise_repr: torch.Tensor
# Result container; every field defaults to None, so only the outputs that
# were actually produced need to be set.
@dataclass
class ReturnValues:
distance: torch.Tensor = None
theta: torch.Tensor = None
phi: torch.Tensor = None
omega: torch.Tensor = None
msa_mlm_loss: torch.Tensor = None
recyclables: Recyclables = None
# helpers
def exists(val):
    """Return True for any value other than None (falsy values included)."""
    return not (val is None)
def default(val, d):
    """Return `val` unless it is None; otherwise return the fallback `d`.

    When `d` is a plain Python function it is called to produce the fallback.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def cast_tuple(val, depth = 1):
    """Return `val` unchanged if it is a tuple, else a tuple repeating it `depth` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
def init_zero_(layer):
    """Zero a layer's weight in place, and its bias too when it has one."""
    nn.init.constant_(layer.weight, 0.)
    if not exists(layer.bias):
        return
    nn.init.constant_(layer.bias, 0.)
# helper classes
class Always(nn.Module):
    """Module that ignores its input and returns a fixed value."""

    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, x):
        # the input is deliberately unused
        constant = self.val
        return constant
# feed forward
class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last dimension of its input."""

    def forward(self, x):
        # first half carries the values, second half the gate pre-activations
        values, gate_logits = x.chunk(2, dim = -1)
        return values * F.gelu(gate_logits)
class FeedForward(nn.Module):
    """Pre-normalised GEGLU feed-forward block.

    The final projection is zero-initialised, so the block outputs zeros
    until it has been trained.
    """

    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        hidden = dim * mult

        self.norm = nn.LayerNorm(dim)

        # GEGLU halves the width, so the input projection is twice as wide
        proj_in = nn.Linear(dim, hidden * 2)
        gate = GEGLU()
        drop = nn.Dropout(dropout)
        proj_out = nn.Linear(hidden, dim)

        self.net = nn.Sequential(proj_in, gate, drop, proj_out)
        init_zero_(proj_out)

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        normed = self.norm(x)
        return self.net(normed)
# attention
# Multi-head attention with a sigmoid output gate. Attends over `context` when
# one is given (otherwise over `x` itself) and, when `tie_dim` is given,
# averages the queries across the tied rows before computing similarities.
class Attention(nn.Module):
def __init__(
self,
dim,
seq_len = None,
heads = 8,
dim_head = 64,
dropout = 0.,
gating = True
):
# NOTE(review): `gating` is accepted but never read in this class; the gate is always applied.
super().__init__()
inner_dim = dim_head * heads
# NOTE(review): `seq_len` is only stored; forward() does not read it.
self.seq_len = seq_len
self.heads= heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.gating = nn.Linear(dim, inner_dim)
# the gate starts input-independent: zero weight and bias 1, i.e. sigmoid(1) everywhere
nn.init.constant_(self.gating.weight, 0.)
nn.init.constant_(self.gating.bias, 1.)
self.dropout = nn.Dropout(dropout)
# zero the output projection so the module initially outputs zeros
init_zero_(self.to_out)
# x is unpacked as (batch, n, dim) by the rearrange below; `mask` and
# `context_mask` are indexed as 2-D (batch, length) tensors.
def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)
# self-attention unless a separate context is supplied
context = default(context, x)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
i, j = q.shape[-2], k.shape[-2]
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# query / key similarities
if exists(tie_dim):
# as in the paper, for the extra MSAs
# they average the queries along the rows of the MSAs
# they named this particular module MSAColumnGlobalAttention
q, k = map(lambda t: rearrange(t, '(b r) ... -> b r ...', r = tie_dim), (q, k))
q = q.mean(dim = 1)
dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
dots = rearrange(dots, 'b r ... -> (b r) ...')
else:
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# add attention bias, if supplied (for pairwise to msa attention communication)
if exists(attn_bias):
dots = dots + attn_bias
# masking
if exists(mask):
# `mask` is known to exist here, so this default() returns it unchanged
mask = default(mask, lambda: torch.ones(1, i, device = device).bool())
# without a context the key mask is the query mask; with one, a missing context_mask means "attend to everything"
context_mask = mask if not has_context else default(context_mask, lambda: torch.ones(1, k.shape[-2], device = device).bool())
mask_value = -torch.finfo(dots.dtype).max
# outer product of query and key masks, broadcast over heads
mask = mask[:, None, :, None] * context_mask[:, None, None, :]
dots = dots.masked_fill(~mask, mask_value)
# attention
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# gating
gates = self.gating(x)
out = out * gates.sigmoid()
# combine to out
out = self.to_out(out)
return out
# Applies Attention along one axis of a 4-D (b, h, w, d) input by folding the
# other axis into the batch. Exactly one of row_attn / col_attn must be set
# when forward() is called.
class AxialAttention(nn.Module):
def __init__(
self,
dim,
heads,
row_attn = True,
col_attn = True,
accept_edges = False,
global_query_attn = False,
**kwargs
):
super().__init__()
assert not (not row_attn and not col_attn), 'row or column attention must be turned on'
self.row_attn = row_attn
self.col_attn = col_attn
self.global_query_attn = global_query_attn
self.norm = nn.LayerNorm(dim)
# remaining keyword arguments (e.g. dim_head) are forwarded to Attention
self.attn = Attention(dim = dim, heads = heads, **kwargs)
# optional projection of edge features to one attention bias per head
self.edges_to_attn_bias = nn.Sequential(
nn.Linear(dim, heads, bias = False),
Rearrange('b i j h -> b h i j')
) if accept_edges else None
def forward(self, x, edges = None, mask = None):
# NOTE: both flags default to True in __init__, so callers must disable one explicitly
assert self.row_attn ^ self.col_attn, 'has to be either row or column attention, but not both'
b, h, w, d = x.shape
x = self.norm(x)
# axial attention
if self.col_attn:
# fold the w axis into the batch and attend along h
axial_dim = w
mask_fold_axial_eq = 'b h w -> (b w) h'
input_fold_eq = 'b h w d -> (b w) h d'
output_fold_eq = '(b w) h d -> b h w d'
elif self.row_attn:
# fold the h axis into the batch and attend along w
axial_dim = h
mask_fold_axial_eq = 'b h w -> (b h) w'
input_fold_eq = 'b h w d -> (b h) w d'
output_fold_eq = '(b h) w d -> b h w d'
x = rearrange(x, input_fold_eq)
if exists(mask):
mask = rearrange(mask, mask_fold_axial_eq)
attn_bias = None
if exists(self.edges_to_attn_bias) and exists(edges):
attn_bias = self.edges_to_attn_bias(edges)
# repeat the bias once per folded slice so it matches the folded batch
attn_bias = repeat(attn_bias, 'b h i j -> (b x) h i j', x = axial_dim)
# with global_query_attn, Attention averages queries over the folded axis
tie_dim = axial_dim if self.global_query_attn else None
out = self.attn(x, mask = mask, attn_bias = attn_bias, tie_dim = tie_dim)
# unfold back to (b, h, w, d)
out = rearrange(out, output_fold_eq, h = h, w = w)
return out
class TriangleMultiplicativeModule(nn.Module):
    """Gated multiplicative update over a square pairwise feature map.

    Two gated projections of the normalised input are contracted over a
    shared index with an einsum, then normalised, gated and projected back.
    """

    def __init__(
        self,
        *,
        dim,
        hidden_dim = None,
        mix = 'ingoing'
    ):
        """
        Args:
            dim: feature dimension of the input and output.
            hidden_dim: width of the projections; defaults to `dim`.
            mix: 'ingoing' or 'outgoing'; selects which index is contracted.
        """
        super().__init__()
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        # hidden width falls back to the input width
        hidden_dim = default(hidden_dim, dim)
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)

        self.left_gate = nn.Linear(dim, hidden_dim)
        self.right_gate = nn.Linear(dim, hidden_dim)
        self.out_gate = nn.Linear(dim, hidden_dim)

        # gates start input-independent: zero weight and bias 1, i.e. sigmoid(1)
        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        # choose the contraction pattern for the requested mixing direction
        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        """Apply the update to `x`; dimensions 1 and 2 of `x` must be equal.

        `mask`, if given, is a (b, i, j) tensor multiplied into both
        projections.
        """
        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'

        if exists(mask):
            # add a trailing axis so the mask broadcasts over features
            mask = rearrange(mask, 'b i j -> b i j ()')

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        if exists(mask):
            left = left * mask
            right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        # contract the two gated projections over the shared index
        out = einsum(self.mix_einsum_eq, left, right)

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
# 定义 OuterMean 类,用于计算两个输入的外积均值
class OuterMean(nn.Module):
    """Outer product of two projections of the input, averaged over dimension 1.

    The input is unpacked as (b, m, i, d); the output is (b, i, j, dim).
    """

    def __init__(
        self,
        dim,
        hidden_dim = None,
        eps = 1e-5
    ):
        """
        Args:
            dim: feature dimension of the input and output.
            hidden_dim: width of the two projections; defaults to `dim`.
            eps: added to the masked count to avoid division by zero.
        """
        super().__init__()
        self.eps = eps
        self.norm = nn.LayerNorm(dim)
        hidden_dim = default(hidden_dim, dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        """Return the projected mean outer product; `mask` is a boolean (b, m, i) tensor."""
        x = self.norm(x)
        left = self.left_proj(x)
        right = self.right_proj(x)
        outer = rearrange(left, 'b m i d -> b m i () d') * rearrange(right, 'b m j d -> b m () j d')

        if exists(mask):
            # masked mean, for padding in the rows of the MSA
            mask = rearrange(mask, 'b m i -> b m i () ()') * rearrange(mask, 'b m j -> b m () j ()')
            outer = outer.masked_fill(~mask, 0.)
            # sum of valid entries over the valid count; the original took
            # .mean() here, which divided by the row count a second time
            outer = outer.sum(dim = 1) / (mask.sum(dim = 1) + self.eps)
        else:
            outer = outer.mean(dim = 1)

        return self.proj_out(outer)
# 定义 PairwiseAttentionBlock 类,用于计算两个输入的注意力
class PairwiseAttentionBlock(nn.Module):
    """Pairwise-representation update.

    Optionally adds an outer-mean term computed from the MSA, then applies
    two triangle multiplicative updates and two axial attentions, each with
    a residual connection.
    """

    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.,
        global_column_attn = False
    ):
        super().__init__()
        # `seq_len` and `dropout` are accepted but not used by this block
        attn_kwargs = dict(dim = dim, heads = heads, dim_head = dim_head, accept_edges = True)

        self.outer_mean = OuterMean(dim)

        self.triangle_attention_outgoing = AxialAttention(row_attn = True, col_attn = False, **attn_kwargs)
        self.triangle_attention_ingoing = AxialAttention(row_attn = False, col_attn = True, global_query_attn = global_column_attn, **attn_kwargs)

        self.triangle_multiply_outgoing = TriangleMultiplicativeModule(dim = dim, mix = 'outgoing')
        self.triangle_multiply_ingoing = TriangleMultiplicativeModule(dim = dim, mix = 'ingoing')

    def forward(
        self,
        x,
        mask = None,
        msa_repr = None,
        msa_mask = None
    ):
        # inject MSA information when an MSA representation is supplied
        if exists(msa_repr):
            x = x + self.outer_mean(msa_repr, mask = msa_mask)

        # multiplicative updates: outgoing first, then ingoing
        for multiply in (self.triangle_multiply_outgoing, self.triangle_multiply_ingoing):
            x = multiply(x, mask = mask) + x

        # attention updates; the current x also serves as the edge features
        for attend in (self.triangle_attention_outgoing, self.triangle_attention_ingoing):
            x = attend(x, edges = x, mask = mask) + x

        return x
# 定义 MsaAttentionBlock 类,用于计算 MSA 的注意力
class MsaAttentionBlock(nn.Module):
    """Row attention (optionally biased by pairwise edges) followed by column attention."""

    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.
    ):
        super().__init__()
        # `seq_len` and `dropout` are accepted but not used by this block
        shared = dict(dim = dim, heads = heads, dim_head = dim_head)
        self.row_attn = AxialAttention(row_attn = True, col_attn = False, accept_edges = True, **shared)
        self.col_attn = AxialAttention(row_attn = False, col_attn = True, **shared)

    def forward(
        self,
        x,
        mask = None,
        pairwise_repr = None
    ):
        # row attention with residual; the pairwise representation supplies the bias
        row_out = self.row_attn(x, mask = mask, edges = pairwise_repr)
        x = row_out + x

        # column attention with residual
        col_out = self.col_attn(x, mask = mask)
        return col_out + x
# 定义 EvoformerBlock 类,包含 PairwiseAttentionBlock、FeedForward 和 MsaAttentionBlock
class EvoformerBlock(nn.Module):
    """One Evoformer layer: an MSA update followed by a pairwise update.

    Takes and returns a 4-tuple `(x, m, mask, msa_mask)` so that layers can
    be chained through `checkpoint_sequential`.
    """

    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout,
        global_column_attn = False
    ):
        super().__init__()
        pairwise_attn = PairwiseAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, global_column_attn = global_column_attn)
        pairwise_ff = FeedForward(dim = dim, dropout = ff_dropout)
        msa_attn = MsaAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
        msa_ff = FeedForward(dim = dim, dropout = ff_dropout)

        # order matters: forward() unpacks the list positionally
        self.layer = nn.ModuleList([pairwise_attn, pairwise_ff, msa_attn, msa_ff])

    def forward(self, inputs):
        x, m, mask, msa_mask = inputs
        pairwise_attn, pairwise_ff, msa_attn, msa_ff = self.layer

        # msa attention and transition
        m = msa_attn(m, mask = msa_mask, pairwise_repr = x)
        m = msa_ff(m) + m

        # pairwise attention and transition, fed by the updated msa
        x = pairwise_attn(x, mask = mask, msa_repr = m, msa_mask = msa_mask)
        x = pairwise_ff(x) + x

        return x, m, mask, msa_mask
# 定义 Evoformer 类,包含多个 EvoformerBlock
class Evoformer(nn.Module):
    """Stack of `depth` EvoformerBlocks run through `checkpoint_sequential`."""

    def __init__(
        self,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        # every block is built from the same keyword arguments
        blocks = [EvoformerBlock(**kwargs) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        x,
        m,
        mask = None,
        msa_mask = None
    ):
        packed = (x, m, mask, msa_mask)
        # a single checkpoint segment spans the whole stack
        outputs = checkpoint_sequential(self.layers, 1, packed)
        x, m, *_ = outputs
        return x, m
# 定义 Alphafold2 类,包含各种模型参数和结构相关的参数
class Alphafold2(nn.Module):
# Builds all sub-modules of the model: embeddings, the extra-MSA and main
# Evoformer trunks, template embedders, optional angle heads, the MLM head,
# the distogram head, structure-module layers and recycling layers.
def __init__(
self,
*,
dim,
max_seq_len = 2048,
depth = 6,
heads = 8,
dim_head = 64,
max_rel_dist = 32,
num_tokens = constants.NUM_AMINO_ACIDS,
num_embedds = constants.NUM_EMBEDDS_TR,
max_num_msas = constants.MAX_NUM_MSA,
max_num_templates = constants.MAX_NUM_TEMPLATES,
extra_msa_evoformer_layers = 4,
attn_dropout = 0.,
ff_dropout = 0.,
templates_dim = 32,
templates_embed_layers = 4,
templates_angles_feats_dim = 55,
predict_angles = False,
symmetrize_omega = False,
predict_coords = False, # structure module related keyword arguments below
structure_module_depth = 4,
structure_module_heads = 1,
structure_module_dim_head = 4,
disable_token_embed = False,
mlm_mask_prob = 0.15,
mlm_random_replace_token_prob = 0.1,
mlm_keep_token_same_prob = 0.1,
mlm_exclude_token_ids = (0,),
recycling_distance_buckets = 32
):
# NOTE(review): max_num_msas, max_num_templates and structure_module_dim_head are not read anywhere in this constructor
super().__init__()
self.dim = dim
# token embedding
# one extra embedding row beyond num_tokens (the MLM below uses id `num_tokens` as its mask id);
# when token embedding is disabled, a constant 0 is returned instead
self.token_emb = nn.Embedding(num_tokens + 1, dim) if not disable_token_embed else Always(0)
# linear layer producing twice the feature dimension
self.to_pairwise_repr = nn.Linear(dim, dim * 2)
self.disable_token_embed = disable_token_embed
# positional embedding
self.max_rel_dist = max_rel_dist
# table of 2 * max_rel_dist + 1 entries
self.pos_emb = nn.Embedding(max_rel_dist * 2 + 1, dim)
# extra msa embedding
# Evoformer stack for the extra MSAs, with global column attention enabled
self.extra_msa_evoformer = Evoformer(
dim = dim,
depth = extra_msa_evoformer_layers,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
global_column_attn = True
)
# template embedding
# projects template features of width templates_dim to the model dimension
self.to_template_embed = nn.Linear(templates_dim, dim)
self.templates_embed_layers = templates_embed_layers
# pairwise attention block for templates
self.template_pairwise_embedder = PairwiseAttentionBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
seq_len = max_seq_len
)
# pointwise attention for templates
self.template_pointwise_attn = Attention(
dim = dim,
dim_head = dim_head,
heads = heads,
dropout = attn_dropout
)
# MLP for template angle features
self.template_angle_mlp = nn.Sequential(
nn.Linear(templates_angles_feats_dim, dim),
nn.GELU(),
nn.Linear(dim, dim)
)
# projection for angles, if needed
self.predict_angles = predict_angles
self.symmetrize_omega = symmetrize_omega
if predict_angles:
# one linear head per angle, sized by the bucket counts in constants
self.to_prob_theta = nn.Linear(dim, constants.THETA_BUCKETS)
self.to_prob_phi = nn.Linear(dim, constants.PHI_BUCKETS)
self.to_prob_omega = nn.Linear(dim, constants.OMEGA_BUCKETS)
# custom embedding projection
self.embedd_project = nn.Linear(num_embedds, dim)
# main trunk modules
self.net = Evoformer(
dim = dim,
depth = depth,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
)
# MSA SSL MLM
# masked-language-model head for self-supervised training on the MSA
self.mlm = MLM(
dim = dim,
num_tokens = num_tokens,
mask_id = num_tokens, # the last embedding row is used as the mask token
mask_prob = mlm_mask_prob,
keep_token_same_prob = mlm_keep_token_same_prob,
random_replace_token_prob = mlm_random_replace_token_prob,
exclude_token_ids = mlm_exclude_token_ids
)
# calculate distogram logits
self.to_distogram_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, constants.DISTOGRAM_BUCKETS)
)
# to coordinate output
self.predict_coords = predict_coords
self.structure_module_depth = structure_module_depth
self.msa_to_single_repr_dim = nn.Linear(dim, dim)
self.trunk_to_pairwise_repr_dim = nn.Linear(dim, dim)
# NOTE(review): indentation is lost in this listing, so which of the following
# statements sit inside the float32 default-dtype context cannot be told from here
with torch_default_dtype(torch.float32):
# IPA block
self.ipa_block = IPABlock(
dim = dim,
heads = structure_module_heads,
)
self.to_quaternion_update = nn.Linear(dim, 6)
# zero the IPA attention output projection
init_zero_(self.ipa_block.attn.to_out)
self.to_points = nn.Linear(dim, 3)
# aux confidence measure
self.lddt_linear = nn.Linear(dim, 1)
# recycling params
self.recycling_msa_norm = nn.LayerNorm(dim)
self.recycling_pairwise_norm = nn.LayerNorm(dim)
self.recycling_distance_embed = nn.Embedding(recycling_distance_buckets, dim)
self.recycling_distance_buckets = recycling_distance_buckets
def forward(
self,
seq,
msa = None,
mask = None,
msa_mask = None,
extra_msa = None,
extra_msa_mask = None,
seq_index = None,
seq_embed = None,
msa_embed = None,
templates_feats = None,
templates_mask = None,
templates_angles = None,
embedds = None,
recyclables = None,
return_trunk = False,
return_confidence = False,
return_recyclables = False,
return_aux_logits = False
.\lucidrains\alphafold2\alphafold2_pytorch\constants.py
import torch
# Constants
MAX_NUM_MSA = 20
MAX_NUM_TEMPLATES = 10
NUM_AMINO_ACIDS = 21
NUM_EMBEDDS_TR = 1280 # best esm model
NUM_EMBEDDS_T5 = 1024 # best t5 model
NUM_COORDS_PER_RES = 14
DISTOGRAM_BUCKETS = 37
THETA_BUCKETS = 25
PHI_BUCKETS = 13
OMEGA_BUCKETS = 25
# Embedding-related constants
MSA_EMBED_DIM = 768
# each *_MODEL_PATH is a [repo, model name] pair, unpacked into torch.hub.load by the embed wrappers
MSA_MODEL_PATH = ["facebookresearch/esm", "esm_msa1_t12_100M_UR50S"]
ESM_EMBED_DIM = 1280
ESM_MODEL_PATH = ["facebookresearch/esm", "esm1b_t33_650M_UR50S"]
PROTTRAN_EMBED_DIM = 1024
# Default device: CUDA when available, otherwise CPU
DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(DEVICE_NAME)
# Amino acid data, keyed by one-letter residue code ('_' has no bonds).
# Each entry lists 'bonds' as index pairs — presumably atom indices within
# the residue; TODO confirm against the code that consumes this table.
AA_DATA = {
'A': {
'bonds': [[0,1], [1,2], [2,3], [1,4]]
},
'R': {
'bonds': [[0,1], [1,2], [2,3], [2,4], [4,5], [5,6],
[6,7], [7,8], [8,9], [8,10]]
},
'N': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[5,7]]
},
'D': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[5,7]]
},
'C': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
},
'Q': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [6,8]]
},
'E': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8]]
},
'G': {
'bonds': [[0,1], [1,2], [2,3]]
},
'H': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8], [8,9], [5,9]]
},
'I': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[4,7]]
},
'L': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[5,7]]
},
'K': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8]]
},
'M': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7]]
},
'F': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8], [8,9], [9,10], [5,10]]
},
'P': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[0,6]]
},
'S': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
},
'T': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
},
'W': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8], [8,9], [9,10], [10,11], [11,12],
[12, 13], [5,13], [8,13]]
},
'Y': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
[6,7], [7,8], [8,9], [8,10], [10,11], [5,11]]
},
'V': {
'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
},
'_': {
'bonds': []
}
}
.\lucidrains\alphafold2\alphafold2_pytorch\embeds.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch 中导入 nn 模块
from torch import nn
# 从 alphafold2_pytorch.utils 中导入 get_msa_embedd, get_esm_embedd, get_prottran_embedd, exists 函数
from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd, get_prottran_embedd, exists
# 从 alphafold2_pytorch.constants 中导入 MSA_MODEL_PATH, MSA_EMBED_DIM, ESM_MODEL_PATH, ESM_EMBED_DIM, PROTTRAN_EMBED_DIM 常量
from alphafold2_pytorch.constants import MSA_MODEL_PATH, MSA_EMBED_DIM, ESM_MODEL_PATH, ESM_EMBED_DIM, PROTTRAN_EMBED_DIM
# 从 einops 中导入 rearrange 函数
from einops import rearrange
# 定义 ProtTranEmbedWrapper 类,继承自 nn.Module
class ProtTranEmbedWrapper(nn.Module):
    """Wraps an alphafold2 model and feeds it projected `Rostlab/prot_bert` embeddings."""

    def __init__(self, *, alphafold2):
        super().__init__()
        # imported lazily, so `transformers` is only needed when this wrapper is used
        from transformers import AutoTokenizer, AutoModel

        pretrained_name = 'Rostlab/prot_bert'

        self.alphafold2 = alphafold2
        # maps embeddings of width PROTTRAN_EMBED_DIM to the model dimension
        self.project_embed = nn.Linear(PROTTRAN_EMBED_DIM, alphafold2.dim)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name, do_lower_case=False)
        self.model = AutoModel.from_pretrained(pretrained_name)

    def forward(self, seq, msa, msa_mask = None, **kwargs):
        device = seq.device
        num_msa = msa.shape[1]

        # fold the MSA rows into the batch so each row is embedded on its own
        msa_flat = rearrange(msa, 'b m n -> (b m) n')

        raw_seq_embed = get_prottran_embedd(seq, self.model, self.tokenizer, device = device)
        raw_msa_embed = get_prottran_embedd(msa_flat, self.model, self.tokenizer, device = device)

        # project both embeddings to the model dimension
        seq_embed = self.project_embed(raw_seq_embed)
        msa_embed = self.project_embed(raw_msa_embed)

        # restore the (batch, msa row) layout
        msa_embed = rearrange(msa_embed, '(b m) n d -> b m n d', m = num_msa)

        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)
# 定义 MSAEmbedWrapper 类,继承自 nn.Module
# Wraps an alphafold2 model and feeds it embeddings from the hub model named
# by MSA_MODEL_PATH, computed jointly over the sequence and its MSA.
class MSAEmbedWrapper(nn.Module):
# Initialisation
def __init__(self, *, alphafold2):
super().__init__()
self.alphafold2 = alphafold2
# load the MSA model and its alphabet through torch.hub
model, alphabet = torch.hub.load(*MSA_MODEL_PATH)
batch_converter = alphabet.get_batch_converter()
self.model = model
self.batch_converter = batch_converter
# project only when the embedding width differs from the model dimension
self.project_embed = nn.Linear(MSA_EMBED_DIM, alphafold2.dim) if MSA_EMBED_DIM != alphafold2.dim else nn.Identity()
# Forward pass
def forward(self, seq, msa, msa_mask = None, **kwargs):
# the sequence and the MSA must have the same length
assert seq.shape[-1] == msa.shape[-1], 'sequence and msa must have the same length if you wish to use MSA transformer embeddings'
model, batch_converter, device = self.model, self.batch_converter, seq.device
# prepend the sequence as row 0 of the MSA
seq_and_msa = torch.cat((seq.unsqueeze(1), msa), dim = 1)
if exists(msa_mask):
# handle MSA rows that are entirely padding: embed each batch element
# separately using only its leading rows, then zero-pad back to num_rows
# number of rows with at least one unmasked position, per batch element
num_msa = msa_mask.any(dim = -1).sum(dim = -1).tolist()
seq_and_msa_list = seq_and_msa.unbind(dim = 0)
num_rows = seq_and_msa.shape[1]
embeds = []
for num, batch_el in zip(num_msa, seq_and_msa_list):
batch_el = rearrange(batch_el, '... -> () ...')
# NOTE(review): `num` is counted from msa_mask, but batch_el also holds the
# sequence as row 0; if msa_mask covers only the MSA rows this drops the
# last valid row — confirm whether `num + 1` is intended
batch_el = batch_el[:, :num]
embed = get_msa_embedd(batch_el, model, batch_converter, device = device)
# pad the row dimension (third from last) back to num_rows with zeros
embed = F.pad(embed, (0, 0, 0, 0, 0, num_rows - num), value = 0.)
embeds.append(embed)
embeds = torch.cat(embeds, dim = 0)
else:
embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
# project the embeddings to the model dimension
embeds = self.project_embed(embeds)
# row 0 is the sequence embedding, the remaining rows the MSA embeddings
seq_embed, msa_embed = embeds[:, 0], embeds[:, 1:]
# run the wrapped alphafold2 model
return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)
# 定义 ESMEmbedWrapper 类,继承自 nn.Module
class ESMEmbedWrapper(nn.Module):
    """Wraps an alphafold2 model and feeds it embeddings from the hub model named by ESM_MODEL_PATH."""

    def __init__(self, *, alphafold2):
        super().__init__()
        self.alphafold2 = alphafold2

        # load the ESM model and its alphabet through torch.hub
        model, alphabet = torch.hub.load(*ESM_MODEL_PATH)
        batch_converter = alphabet.get_batch_converter()

        self.model = model
        self.batch_converter = batch_converter
        # project only when the embedding width differs from the model dimension
        self.project_embed = nn.Linear(ESM_EMBED_DIM, alphafold2.dim) if ESM_EMBED_DIM != alphafold2.dim else nn.Identity()

    def forward(self, seq, msa=None, **kwargs):
        """Embed `seq` (and `msa`, if given) and call the wrapped model.

        Args:
            seq: sequence tokens; its `.device` is used for the embedding call.
            msa: optional MSA tokens laid out as (b, m, n).
            **kwargs: forwarded to the wrapped alphafold2 model.
        """
        model, batch_converter, device = self.model, self.batch_converter, seq.device

        seq_embeds = get_esm_embedd(seq, model, batch_converter, device = device)
        seq_embeds = self.project_embed(seq_embeds)

        if exists(msa):
            # fold the MSA rows into the batch so each row is embedded on its own
            flat_msa = rearrange(msa, 'b m n -> (b m) n')
            msa_embeds = get_esm_embedd(flat_msa, model, batch_converter, device = device)
            # einops cannot split '(b m)' unless one of the two sizes is given
            msa_embeds = rearrange(msa_embeds, '(b m) n d -> b m n d', b = msa.shape[0])
            msa_embeds = self.project_embed(msa_embeds)
        else:
            msa_embeds = None

        return self.alphafold2(seq, msa, seq_embed = seq_embeds, msa_embed = msa_embeds, **kwargs)
.\lucidrains\alphafold2\alphafold2_pytorch\mlm.py
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from alphafold2_pytorch import constants
from einops import rearrange
# 导入所需的库和模块
# MSA MLM
# 定义函数,根据给定的掩码和概率获取子集掩码
def get_mask_subset_with_prob(mask, prob):
    """Randomly select a subset of the True positions in each row of `mask`.

    Per row, at most ceil(prob * number of True entries) positions are
    chosen, and never more than ceil(prob * seq_len).

    Args:
        mask: 2-D boolean tensor, (batch, seq_len).
        prob: fraction of the True positions to select.

    Returns:
        Boolean tensor shaped like `mask`, True only at selected positions.
    """
    batch, seq_len = mask.shape
    device = mask.device
    max_masked = math.ceil(prob * seq_len)

    # per-row quota, derived from how many positions are eligible
    valid_count = mask.sum(dim=-1, keepdim=True)
    quota = (valid_count * prob).ceil()
    over_quota = (mask.cumsum(dim=-1) > quota)[:, :max_masked]

    # random scores; ineligible positions get a very low score so they rank last
    scores = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    picks = scores.topk(max_masked, dim=-1).indices

    # shift indices by one so that slot 0 can absorb the picks beyond the quota
    picks = (picks + 1).masked_fill_(over_quota, 0)

    selected = torch.zeros((batch, seq_len + 1), device=device)
    selected.scatter_(-1, picks, 1)

    # drop the dump slot again
    return selected[:, 1:].bool()
# 定义 MLM 类
class MLM(nn.Module):
    """Masked-language-model head and noising routine for MSA tokens."""

    def __init__(
        self,
        dim,
        num_tokens,
        mask_id,
        mask_prob = 0.15,
        random_replace_token_prob = 0.1,
        keep_token_same_prob = 0.1,
        exclude_token_ids = (0,)
    ):
        """
        Args:
            dim: width of the embeddings passed to `forward`.
            num_tokens: number of output classes of the logits layer.
            mask_id: token id written into masked positions.
            mask_prob: fraction of eligible tokens chosen as prediction targets.
            random_replace_token_prob: share of targets replaced by a random token.
            keep_token_same_prob: share of targets left unchanged.
            exclude_token_ids: token ids that are never selected as targets.
        """
        super().__init__()
        self.to_logits = nn.Linear(dim, num_tokens)
        self.mask_id = mask_id

        self.mask_prob = mask_prob
        self.exclude_token_ids = exclude_token_ids
        self.keep_token_same_prob = keep_token_same_prob
        self.random_replace_token_prob = random_replace_token_prob

    def noise(self, seq, mask):
        """Corrupt `seq` for MLM training.

        Returns:
            `(noised_seq, mlm_mask)`, both in the leading (b, n) layout of
            `seq`; `mlm_mask` marks the positions to be predicted.
        """
        num_msa = seq.shape[1]
        seq = rearrange(seq, 'b n ... -> (b n) ...')
        mask = rearrange(mask, 'b n ... -> (b n) ...')

        # positions eligible for masking: unmasked and not an excluded token
        excluded_tokens_mask = mask

        for token_id in self.exclude_token_ids:
            excluded_tokens_mask = excluded_tokens_mask & (seq != token_id)

        mlm_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.mask_prob)

        # keep some tokens the same
        replace_token_with_mask = get_mask_subset_with_prob(mlm_mask, 1. - self.keep_token_same_prob)

        # replace with mask: only the subset chosen above. The original filled
        # with `mlm_mask`, which left `keep_token_same_prob` without effect.
        seq = seq.masked_fill(replace_token_with_mask, self.mask_id)

        # generate random tokens
        random_replace_token_prob_mask = get_mask_subset_with_prob(mlm_mask, (1 - self.keep_token_same_prob) * self.random_replace_token_prob)

        random_tokens = torch.randint(1, constants.NUM_AMINO_ACIDS, seq.shape).to(seq.device)

        for token_id in self.exclude_token_ids:
            # never replace a token with an excluded token type (padding, start, end)
            random_replace_token_prob_mask = random_replace_token_prob_mask & (random_tokens != token_id)

        # noised sequence
        noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)

        noised_seq = rearrange(noised_seq, '(b n) ... -> b n ...', n = num_msa)
        mlm_mask = rearrange(mlm_mask, '(b n) ... -> b n ...', n = num_msa)

        return noised_seq, mlm_mask

    def forward(self, seq_embed, original_seq, mask):
        """Return the mean cross-entropy over the positions selected by `mask`."""
        logits = self.to_logits(seq_embed)
        seq_logits = logits[mask]
        seq_labels = original_seq[mask]

        loss = F.cross_entropy(seq_logits, seq_labels, reduction = 'mean')
        return loss
.\lucidrains\alphafold2\alphafold2_pytorch\reversible.py
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
from contextlib import contextmanager
from einops import reduce
# helpers
# 检查值是否存在
def exists(val):
    """Return True for any value other than None (falsy values included)."""
    return not (val is None)
# 上下文管理器,不执行任何操作
@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit and yields None."""
    yield None
# 在指定维度上按索引分割张量
def split_at_index(dim, index, t):
    """Split `t` in two along axis `dim` at position `index`.

    Returns:
        `(head, tail)`: `head` holds the entries before `index` along `dim`,
        `tail` the entries from `index` onward.
    """
    # full slices for every axis in front of `dim`
    leading = (slice(None),) * dim
    head = t[leading + (slice(None, index),)]
    tail = t[leading + (slice(index, None),)]
    return head, tail
# 用于反向传播确定性的函数包装器
# Wraps a module so that the RNG state seen during one call can be recorded
# and later restored for a second call on the same inputs.
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
# RNG snapshots; filled in by record_rng()
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
# Record the random number generator state
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
# NOTE: relies on the private attribute torch.cuda._initialized
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
# record_rng: snapshot the RNG state before running the net
# set_rng: run the net under the previously recorded RNG state
def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
# fork_rng is used so the surrounding RNG state is put back after the block
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# 可逆自注意力块
# Reversible residual block over two streams (x and m). Each stream is split
# in half along dim 2; backward_pass() reconstructs the inputs from the
# outputs and computes the gradients by hand.
class ReversibleSelfAttnBlock(nn.Module):
def __init__(self, f, g, j, k):
super().__init__()
# f and g act on the x stream, j and k on the m stream; each is wrapped so
# its RNG state can be replayed in backward_pass()
self.f = Deterministic(f)
self.g = Deterministic(g)
self.j = Deterministic(j)
self.k = Deterministic(k)
# Computes y1 = x1 + f(x2), y2 = x2 + g(y1), n1 = m1 + j(m2), n2 = m2 + k(n1)
# and returns the concatenated (y, n). With _reverse=True (the default) this
# runs under torch.no_grad().
def forward(self, x, m, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_pos_emb = None, msa_pos_emb = None, _reverse = True, **kwargs):
x1, x2 = torch.chunk(x, 2, dim = 2)
m1, m2 = torch.chunk(m, 2, dim = 2)
y1, y2, n1, n2 = None, None, None, None
context = torch.no_grad if _reverse else null_context
# RNG state is recorded only when training in reversible mode
record_rng = self.training and _reverse
with context():
y1 = x1 + self.f(x2, shape = seq_shape, record_rng = record_rng, mask = mask, rotary_emb = seq_pos_emb)
y2 = x2 + self.g(y1, shape = seq_shape, record_rng = record_rng)
n1 = m1 + self.j(m2, shape = msa_shape, record_rng = record_rng, mask = msa_mask, rotary_emb = msa_pos_emb)
n2 = m2 + self.k(n1, record_rng = record_rng)
return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)
# Given outputs (y, n) and their gradients (dy, dn), returns the reconstructed
# inputs and the input gradients as (x, m, dx, dm). Parameter gradients are
# accumulated by the torch.autograd.backward calls.
def backward_pass(self, y, n, dy, dn, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_pos_emb = None, msa_pos_emb = None, **kwargs):
# ---- x stream ----
y1, y2 = torch.chunk(y, 2, dim = 2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim = 2)
del dy
# recompute g(y1) with the recorded RNG state and backprop dy2 through it
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, shape = seq_shape, set_rng = True)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
# invert y2 = x2 + g(y1)
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
# recompute f(x2) and backprop dx1 through it
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, shape = seq_shape, set_rng = True, mask = mask, rotary_emb = seq_pos_emb)
torch.autograd.backward(fx2, dx1, retain_graph = True)
with torch.no_grad():
# invert y1 = x1 + f(x2)
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim = 2)
dx = torch.cat([dx1, dx2], dim = 2)
# ---- m stream: same procedure with k and j ----
n1, n2 = torch.chunk(n, 2, dim = 2)
del n
dn1, dn2 = torch.chunk(dn, 2, dim = 2)
del dn
with torch.enable_grad():
n1.requires_grad = True
gn1 = self.k(n1, set_rng = True)
torch.autograd.backward(gn1, dn2)
with torch.no_grad():
# invert n2 = m2 + k(n1)
m2 = n2 - gn1
del n2, gn1
dm1 = dn1 + n1.grad
del dn1
n1.grad = None
with torch.enable_grad():
m2.requires_grad = True
fm2 = self.j(m2, shape = msa_shape, set_rng = True, mask = msa_mask, rotary_emb = msa_pos_emb)
torch.autograd.backward(fm2, dm1, retain_graph=True)
with torch.no_grad():
# invert n1 = m1 + j(m2)
m1 = n1 - fm2
del n1, fm2
dm2 = dn2 + m2.grad
del dn2
m2.grad = None
m = torch.cat([m1, m2.detach()], dim = 2)
dm = torch.cat([dm1, dm2], dim = 2)
return x, m, dx, dm
# 可逆交叉注意力块
class ReversibleCrossAttnBlock(nn.Module):
    """Reversible residual block in which the ``x`` and ``m`` streams interact.

    Each stream is split in half along dim 2. The forward pass computes
        y1 = x1 + f(x2, m2),  y2 = x2 + k(y1)
        n1 = m1 + j(m2, y2),  n2 = m2 + g(n1)
    and ``backward_pass`` inverts these equations in the opposite order.
    """
    def __init__(self, f, g, j, k):
        super().__init__()
        # Wrapped in Deterministic so the RNG state of the forward call can be
        # replayed when the sub-network is re-run in backward_pass.
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)

    def forward(self, x, m, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_to_msa_pos_emb = None, msa_to_seq_pos_emb = None, _reverse = True, **kwargs):
        """Return the transformed ``(x, m)`` pair, each re-joined along dim 2."""
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # When _reverse is set, no autograd graph is built here.
        context = torch.no_grad if _reverse else null_context
        # RNG state is only recorded while training in reversible mode.
        record_rng = self.training and _reverse

        with context():
            y1 = x1 + self.f(x2, m2, record_rng = record_rng, mask = mask, context_mask = msa_mask, shape = seq_shape, context_shape = msa_shape, rotary_emb = seq_to_msa_pos_emb)
            y2 = x2 + self.k(y1, shape = seq_shape, record_rng = record_rng)
            n1 = m1 + self.j(m2, y2, record_rng = record_rng, mask = msa_mask, context_mask = mask, shape = msa_shape, context_shape = seq_shape, rotary_emb = msa_to_seq_pos_emb)
            n2 = m2 + self.g(n1, record_rng = record_rng)

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    # Manual backward: rebuilds the block inputs from its outputs and
    # back-propagates the incoming gradients through f, g, j and k.
    def backward_pass(self, y, n, dy, dn, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_to_msa_pos_emb = None, msa_to_seq_pos_emb = None, **kwargs):
        """Return ``(x, m, dx, dm)`` given outputs ``y``/``n`` and gradients ``dy``/``dn``."""
        # Split n into two halves along dim 2
        n1, n2 = torch.chunk(n, 2, dim = 2)
        # Drop the local reference to n
        del n

        # Split dn into two halves along dim 2
        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        # Drop the local reference to dn
        del dn

        # Split y into two halves along dim 2
        y1, y2 = torch.chunk(y, 2, dim = 2)
        # Drop the local reference to y
        del y

        # Split dy into two halves along dim 2
        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        # Drop the local reference to dy
        del dy

        # ---- undo n2 = m2 + g(n1) ----
        with torch.enable_grad():
            # Track gradients with respect to n1
            n1.requires_grad = True
            # Re-run g with the recorded RNG state and back-propagate dn2 through it
            gn1 = self.g(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            # Recover m2, then drop n2 and gn1
            m2 = n2 - gn1
            del n2, gn1

            # Gradient reaching n1: direct part plus the part that flowed through g
            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # ---- undo n1 = m1 + j(m2, y2) ----
        with torch.enable_grad():
            # Track gradients with respect to both inputs of j
            m2.requires_grad = True
            y2.requires_grad = True
            # Re-run j with the recorded RNG state and back-propagate dm1 through it
            fm2 = self.j(m2, y2, set_rng=True, mask = msa_mask, context_mask = mask, shape = msa_shape, context_shape = seq_shape, rotary_emb = msa_to_seq_pos_emb)
            torch.autograd.backward(fm2, dm1)

        with torch.no_grad():
            # Recover m1, then drop n1 and fm2
            m1 = n1 - fm2
            del n1, fm2

            # Accumulate the gradients that j sent to m2 and y2
            dm2 = dn2 + m2.grad
            dx2 = dy2 + y2.grad
            del dn2
            del dy2
            m2.grad = None
            y2.grad = None

        # ---- undo y2 = x2 + k(y1) ----
        with torch.enable_grad():
            # Track gradients with respect to y1
            y1.requires_grad = True
            # Re-run k with the recorded RNG state and back-propagate dx2 through it
            gy1 = self.k(y1, shape = seq_shape, set_rng = True)
            torch.autograd.backward(gy1, dx2)

        with torch.no_grad():
            # Recover x2, then drop y2 and gy1
            x2 = y2 - gy1
            del y2, gy1

            # Gradient reaching y1: direct part plus the part that flowed through k
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # ---- undo y1 = x1 + f(x2, m2) ----
        with torch.enable_grad():
            # Track gradients with respect to both inputs of f
            x2.requires_grad = True
            m2.requires_grad = True
            # Re-run f with the recorded RNG state and back-propagate dx1 through it
            fx2 = self.f(x2, m2, set_rng = True, mask = mask, context_mask = msa_mask, shape = seq_shape, context_shape = msa_shape, rotary_emb = seq_to_msa_pos_emb)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            # Recover x1, then drop y1 and fx2
            x1 = y1 - fx2
            del y1, fx2

            # Add the gradients that f sent to x2 and m2, then clear them
            dx2 = dx2 + x2.grad
            dm2 = dm2 + m2.grad
            x2.grad = None
            m2.grad = None

        with torch.no_grad():
            # Re-join the halves of the m stream and of its gradient
            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

            # Re-join the halves of the x stream and of its gradient
            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        # Reconstructed inputs and their gradients
        return x, m, dx, dm
# 定义可逆和不可逆函数
class ReversibleFunction(Function):
    """Autograd function that runs a stack of reversible blocks.

    ``forward`` saves only the final outputs; ``backward`` walks the blocks in
    reverse order and calls each block's ``backward_pass`` to rebuild the
    earlier activations and gradients.
    """
    @staticmethod
    def forward(ctx, inp, ind, blocks, kwargs):
        # Split the packed input into the two streams along dim 1 at index `ind`
        x, m = split_at_index(1, ind, inp)

        # Run every block in order, in reversible mode
        for block in blocks:
            x, m = block(x, m, _reverse = True, **kwargs)

        # Stash what backward needs: the blocks, their kwargs, the split index
        # and the detached final outputs
        ctx.blocks = blocks
        ctx.kwargs = kwargs
        ctx.ind = ind
        ctx.save_for_backward(x.detach(), m.detach())
        return torch.cat((x, m), dim = 1)

    @staticmethod
    def backward(ctx, d):
        ind = ctx.ind
        blocks = ctx.blocks
        kwargs = ctx.kwargs
        # Split the incoming gradient the same way the input was split
        dy, dn = split_at_index(1, ind, d)
        y, n = ctx.saved_tensors

        # Walk the blocks last-to-first, reconstructing inputs and gradients
        for block in blocks[::-1]:
            y, n, dy, dn = block.backward_pass(y, n, dy, dn, **kwargs)

        d = torch.cat((dy, dn), dim = 1)
        # Only `inp` receives a gradient; ind, blocks and kwargs do not
        return d, None, None, None

# Callable entry point for the reversible path
reversible_apply = ReversibleFunction.apply
def irreversible_apply(inputs, ind, blocks, kwargs):
    """Run ``blocks`` in order with ``_reverse = False``.

    ``inputs`` is split along dim 1 at index ``ind`` into the two streams,
    every block is applied to the pair, and the results are concatenated
    back along dim 1.
    """
    seq_stream, msa_stream = split_at_index(1, ind, inputs)
    for blk in blocks:
        seq_stream, msa_stream = blk(seq_stream, msa_stream, _reverse = False, **kwargs)
    packed = torch.cat((seq_stream, msa_stream), dim = 1)
    return packed
# 主要的可逆序列类
class ReversibleSequence(nn.Module):
    """Stack of reversible blocks applied to a sequence stream and an MSA stream.

    Args:
        input_blocks: iterable of argument tuples; each tuple is unpacked into
            the constructor of the block class chosen by its block type.
        block_types: iterable of ``'self'``, ``'cross'`` or ``'conv'``, paired
            with ``input_blocks``.

    Raises:
        ValueError: if a block type is not one of the supported names.
    """
    def __init__(self, input_blocks, block_types):
        super().__init__()
        self.block_types = block_types
        blocks = nn.ModuleList([])

        for block, block_type in zip(input_blocks, block_types):
            if block_type == 'self':
                reversible_klass = ReversibleSelfAttnBlock
            elif block_type == 'cross':
                reversible_klass = ReversibleCrossAttnBlock
            elif block_type == 'conv':
                reversible_klass = ReversibleSelfAttnBlock
            else:
                # Without this branch an unknown type silently reused the class
                # chosen for the previous block (or failed on the first one).
                raise ValueError(f"unknown block type {block_type!r}; expected 'self', 'cross' or 'conv'")

            blocks.append(reversible_klass(*block))

        self.blocks = blocks

    def forward(
        self,
        seq,
        msa,
        seq_shape = None,
        msa_shape = None,
        mask = None,
        msa_mask = None,
        seq_pos_emb = None,
        msa_pos_emb = None,
        seq_to_msa_pos_emb = None,
        msa_to_seq_pos_emb = None,
        reverse = True
    ):
        """Run all blocks and return ``[seq, msa]``.

        With ``reverse = True`` the blocks run through ``reversible_apply``,
        otherwise through ``irreversible_apply``.
        """
        assert exists(msa), 'reversibility does not work with no MSA sequences yet'

        blocks = self.blocks
        # Duplicate the features along the last axis so each block can split
        # the stream into two halves.
        seq, msa = list(map(lambda t: torch.cat((t, t), dim = -1), (seq, msa)))
        kwargs = {'mask': mask, 'msa_mask': msa_mask, 'seq_shape': seq_shape, 'msa_shape': msa_shape, 'seq_pos_emb': seq_pos_emb, 'msa_pos_emb': msa_pos_emb, 'seq_to_msa_pos_emb': seq_to_msa_pos_emb, 'msa_to_seq_pos_emb': msa_to_seq_pos_emb}

        fn = reversible_apply if reverse else irreversible_apply
        # Both streams are packed into one tensor along dim 1; `ind` marks the boundary.
        ind = seq.shape[1]
        inp = torch.cat((seq, msa), dim = 1)
        out = fn(inp, ind, blocks, kwargs)
        seq, msa = split_at_index(1, ind, out)
        # Average the two halves of the last axis back into one.
        return list(map(lambda t: reduce(t, 'b n (c d) -> b n d', 'mean', c = 2), (seq, msa)))
.\lucidrains\alphafold2\alphafold2_pytorch\rotary.py
# 从 math 模块中导入 log, sqrt, pi 函数
# 导入 torch 模块
# 从 torch 模块中导入 nn, einsum 函数
# 从 einops 模块中导入 rearrange, repeat 函数
# 旋转嵌入的辅助函数
def rotate_every_two(x):
    """Map each consecutive pair ``(a, b)`` on the last axis to ``(-b, a)``."""
    # View the last axis as pairs.
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    first, second = pairs.unbind(dim = -1)
    # Swap the pair members and negate the one that moves to the front.
    rotated = torch.stack((-second, first), dim = -1)
    # Flatten the pairs back into a single axis.
    return rearrange(rotated, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sinu_pos):
    """Apply rotary position embeddings to the leading features of ``x``.

    ``sinu_pos`` is a ``(sin, cos)`` pair. Only the first ``sin.shape[-1]``
    features of ``x`` are rotated; the remaining features pass through
    unchanged and are concatenated back on.
    """
    # Insert a singleton axis after the first axis of both tables.
    sin, cos = (rearrange(table, 'b ... -> b () ...') for table in sinu_pos)
    rot_dim = sin.shape[-1]
    to_rotate = x[..., :rot_dim]
    passthrough = x[..., rot_dim:]
    rotated = to_rotate * cos + rotate_every_two(to_rotate) * sin
    return torch.cat((rotated, passthrough), dim = -1)
# 位置嵌入
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution followed by a kernel-size-1 convolution.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels of the second convolution.
        kernel_size: kernel size of the first convolution.
        padding, stride: passed to the first convolution.
        bias: whether both convolutions use a bias term.
        groups: groups of the first convolution; defaults to ``dim_in``.
    """
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True, groups = None):
        super().__init__()
        # Resolved inline: this module neither defines nor imports a `default`
        # helper, so calling one here raised NameError.
        if groups is None:
            groups = dim_in
        self.net = nn.Sequential(
            nn.Conv1d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = bias),
            nn.Conv1d(dim_in, dim_out, 1, bias = bias)
        )

    def forward(self, x):
        """Apply both convolutions to ``x``."""
        return self.net(x)
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal position tables built from fixed inverse frequencies."""

    def __init__(self, dim):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i / dim) for the even indices 2i < dim.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, n, device):
        """Return ``[sin, cos]`` tables for positions ``0 .. n - 1``."""
        positions = torch.arange(n, device = device).type_as(self.inv_freq)
        # Outer product position x frequency.
        angles = einsum('i , j -> i j', positions, self.inv_freq)
        # Repeat every frequency twice and add a leading singleton axis.
        angles = repeat(angles, 'i j -> () i (j r)', r = 2)
        return [angles.sin(), angles.cos()]
class AxialRotaryEmbedding(nn.Module):
    """Sinusoidal tables for an ``n x n`` grid, half the features per axis.

    ``max_freq`` is accepted but not used by this implementation.
    """

    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim // 2
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, n, device):
        """Return ``[sin, cos]`` tables covering all ``n * n`` grid positions."""
        positions = torch.arange(n, device = device).type_as(self.inv_freq)
        # Both axes use the same position/frequency angles.
        angles = einsum('n, d -> n d', positions, self.inv_freq)

        # Broadcast along the other axis: rows vary with i, columns with j.
        row_angles = repeat(angles, 'i d -> i j d', j = n)
        col_angles = repeat(angles, 'j d -> i j d', i = n)

        sin = torch.cat((row_angles.sin(), col_angles.sin()), dim = -1)
        cos = torch.cat((row_angles.cos(), col_angles.cos()), dim = -1)

        # Flatten the grid and repeat every feature twice.
        sin, cos = (repeat(table, 'i j d -> () (i j) (d r)', r = 2) for table in (sin, cos))
        return [sin, cos]
.\lucidrains\alphafold2\alphafold2_pytorch\utils.py
# 导入必要的库
import os
import re
import numpy as np
import torch
import contextlib
from functools import wraps
from einops import rearrange, repeat
# import torch_sparse # only needed for sparse nth_deg adj calculation
# 导入生物信息学相关库
from Bio import SeqIO
import itertools
import string
# 导入sidechainnet相关库
from sidechainnet.utils.sequence import ProteinVocabulary, ONE_TO_THREE_LETTER_MAP
from sidechainnet.utils.measure import GLOBAL_PAD_CHAR
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES, BB_BUILD_INFO, SC_BUILD_INFO
from sidechainnet.structure.StructureBuilder import _get_residue_build_iter
# 导入自定义库
import mp_nerf
# 构建蛋白质词汇表
# Module-wide ProteinVocabulary instance; the helpers below use it to map
# integer residue ids to their one-letter codes.
VOCAB = ProteinVocabulary()
# 常量
import alphafold2_pytorch.constants as constants
# 辅助函数
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True
# 常量:与alphafold2.py中相同
# Evenly spaced thresholds from 2 to 20, one per distogram bucket
# (constants.DISTOGRAM_BUCKETS of them).
DISTANCE_THRESHOLDS = torch.linspace(2, 20, steps = constants.DISTOGRAM_BUCKETS)
# 距离分箱函数
def get_bucketed_distance_matrix(coords, mask, num_buckets = constants.DISTOGRAM_BUCKETS, ignore_index = -100):
    """Discretise pairwise Euclidean distances into bucket indices.

    Distances between all pairs of points in ``coords`` are bucketed against
    ``num_buckets`` evenly spaced boundaries between 2 and 20 (the last
    boundary is dropped). Pairs where either point is masked out are set to
    ``ignore_index``.
    """
    pairwise = torch.cdist(coords, coords, p=2)
    boundaries = torch.linspace(2, 20, steps = num_buckets, device = coords.device)
    bucket_ids = torch.bucketize(pairwise, boundaries[:-1])
    # A pair is valid only if both of its points are valid.
    valid_pairs = mask[..., None] & mask[..., None, :]
    bucket_ids.masked_fill_(~valid_pairs, ignore_index)
    return bucket_ids
# 装饰器
def set_backend_kwarg(fn):
    """Decorator that resolves ``backend='auto'`` before calling ``fn``.

    With ``'auto'`` (the default) the backend becomes ``'torch'`` when the
    first positional argument is a ``torch.Tensor`` and ``'numpy'`` otherwise.
    The resolved value is passed to ``fn`` as the ``backend`` keyword.
    """
    @wraps(fn)
    def inner(*args, backend = 'auto', **kwargs):
        if backend == 'auto':
            if isinstance(args[0], torch.Tensor):
                backend = 'torch'
            else:
                backend = 'numpy'
        kwargs['backend'] = backend
        return fn(*args, **kwargs)
    return inner
def expand_dims_to(t, length = 3):
    """Prepend ``length`` singleton axes to ``t`` (torch tensor or numpy array).

    ``t`` itself is returned when ``length`` is 0.
    """
    if length == 0:
        return t
    target_shape = (1,) * length + tuple(t.shape)
    # reshape(*ints) is accepted by both torch and numpy.
    return t.reshape(*target_shape)
def expand_arg_dims(dim_len = 3):
    """Decorator factory that pads two array arguments to ``dim_len`` axes.

    The wrapped function is called as ``fn(x, y, **kwargs)`` after singleton
    axes have been prepended to ``x`` and ``y`` so that each has ``dim_len``
    axes. Both inputs must have the same number of axes.
    """
    def outer(fn):
        @wraps(fn)
        def inner(x, y, **kwargs):
            assert len(x.shape) == len(y.shape), "Shapes of A and B must match."
            missing = dim_len - len(x.shape)
            x_padded = expand_dims_to(x, length = missing)
            y_padded = expand_dims_to(y, length = missing)
            return fn(x_padded, y_padded, **kwargs)
        return inner
    return outer
def invoke_torch_or_numpy(torch_fn, numpy_fn):
    """Decorator factory that dispatches to a torch or numpy implementation.

    The decorated function prepares the call: it returns the positional
    arguments, optionally followed by a dict of keyword arguments as the last
    element. The required ``backend`` keyword selects ``torch_fn`` (when it
    equals ``'torch'``) or ``numpy_fn`` (any other value).
    """
    def outer(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            backend = kwargs.pop('backend')
            call_args = list(fn(*args, **kwargs))
            # A trailing dict is treated as keyword arguments.
            call_kwargs = call_args.pop() if isinstance(call_args[-1], dict) else {}
            if backend == 'torch':
                return torch_fn(*call_args, **call_kwargs)
            return numpy_fn(*call_args, **call_kwargs)
        return inner
    return outer
@contextlib.contextmanager
def torch_default_dtype(dtype):
    """Temporarily set torch's default dtype to ``dtype``.

    The previous default is restored when the ``with`` block exits, including
    when the block raises.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without try/finally an exception in the block left the global
        # default dtype changed.
        torch.set_default_dtype(prev_dtype)
# 预处理数据
def get_atom_ids_dict():
    """Build a dict mapping every atom name to an integer token.

    The names are the empty string, the four names "N", "CA", "C", "O", and
    every entry of "atom-names" in SC_BUILD_INFO; tokens follow sorted order.
    """
    names = {"", "N", "CA", "C", "O"}
    for build_info in SC_BUILD_INFO.values():
        names.update(build_info["atom-names"])
    return {name: token for token, name in enumerate(sorted(names))}
def make_cloud_mask(aa):
    """Return a per-residue mask: 1 for occupied atom slots, 0 for the rest.

    ``aa`` is a one-letter residue code; the padding token "_" gives an
    all-zero mask.
    """
    mask = np.zeros(constants.NUM_COORDS_PER_RES)
    if aa != "_":
        # Four leading slots plus one slot per listed "atom-names" entry.
        extra_names = SC_BUILD_INFO[ ONE_TO_THREE_LETTER_MAP[aa] ]["atom-names"]
        mask[:4 + len(extra_names)] = 1
    return mask
def make_atom_id_embedds(aa, atom_ids):
    """Return the token of each atom slot for residue ``aa``.

    Args:
        aa: one-letter residue code; the padding token "_" gives all zeros.
        atom_ids: dict mapping atom name to integer token. If ``None``, the
            module-level ``ATOM_IDS`` table is used.

    Returns:
        Array of length ``constants.NUM_COORDS_PER_RES``; slots beyond the
        residue's atoms stay 0.
    """
    # The argument used to be ignored in favour of the global table.
    token_of = ATOM_IDS if atom_ids is None else atom_ids
    mask = np.zeros(constants.NUM_COORDS_PER_RES)
    # Padding token: nothing to fill in.
    if aa == "_":
        return mask
    # Four fixed names first, then the residue's listed "atom-names".
    atom_list = ["N", "CA", "C", "O"] + SC_BUILD_INFO[ ONE_TO_THREE_LETTER_MAP[aa] ]["atom-names"]
    for i, atom in enumerate(atom_list):
        mask[i] = token_of[atom]
    return mask
# 获取原子ID字典
ATOM_IDS = get_atom_ids_dict()
# Lookup table keyed by one-letter residue code (plus the padding token "_"),
# holding each residue's cloud mask and per-atom token array.
CUSTOM_INFO = {k: {"cloud_mask": make_cloud_mask(k),
                   "atom_id_embedd": make_atom_id_embedds(k, atom_ids=ATOM_IDS),
                  } for k in "ARNDCQEGHILKMFPSTWYV_"}
# 常用工具
# 从RCSB PDB下载PDB条目
def download_pdb(name, route):
    """ Downloads a PDB entry from the RCSB PDB.
        Inputs:
        * name: str. the PDB entry id. 4 characters, capitalized.
        * route: str. route of the destin file. usually ".pdb" extension
        Output: route of destin file
    """
    import shlex

    url = f"https://files.rcsb.org/download/{name}.pdb"
    # Quote both values: unquoted interpolation broke on paths containing
    # spaces and let shell metacharacters in the arguments be executed.
    os.system(f"curl {shlex.quote(url)} > {shlex.quote(route)}")
    return route
# 清理PDB结构,只保留重要部分
def clean_pdb(name, route=None, chain_num=None):
    """ Cleans the structure to only leave the important part.
        Inputs:
        * name: str. route of the input .pdb file
        * route: str. route of the output. will overwrite input if not provided
        * chain_num: int (or numeric str). index of chain to select.
        Output: route of destin file.

        NOTE(review): the original docstring calls chain_num "1-indexed as pdb
        files", but it is compared directly with mdtraj's chain.index -
        confirm which convention callers rely on.
    """
    import mdtraj
    destin = route if route is not None else name
    # A chain number parsed from an id string arrives as text; as a str it
    # never equalled the integer chain.index, so every chain was skipped.
    if chain_num is not None:
        chain_num = int(chain_num)
    # read input
    raw_prot = mdtraj.load_pdb(name)
    # collect the atom indices of the selected chain(s)
    idxs = []
    for chain in raw_prot.topology.chains:
        # if a chain was requested, skip all the others
        if chain_num is not None and chain_num != chain.index:
            continue
        chain_idxs = raw_prot.topology.select(f"chainid == {str(chain.index)}")
        idxs.extend( chain_idxs.tolist() )
    # sort: topology and xyz selections are ordered
    idxs = sorted(idxs)
    # build the new trajectory from the selected subset and save it
    prot = mdtraj.Trajectory(xyz=raw_prot.xyz[:, idxs],
                             topology=raw_prot.topology.subset(idxs))
    prot.save(destin)
    return destin
# 将自定义表示转换为.pdb文件
def custom2pdb(coords, proteinnet_id, route):
    """ Takes a custom representation and turns into a .pdb file.
        Inputs:
        * coords: array/tensor of shape (3 x N) or (N x 3). in Angstroms.
                  same order as in the proteinnnet is assumed (same as raw pdb file)
        * proteinnet_id: str. proteinnet id format (<class>#<pdb_id>_<chain_number>_<chain_id>)
                         see: https://github.com/aqlaboratory/proteinnet/
        * route: str. destin route.
        Output: tuple of routes: (original, generated) for the structures.
    """
    import mdtraj
    # convert to numpy
    if isinstance(coords, torch.Tensor):
        coords = coords.detach().cpu().numpy()
    # ensure (1, N, 3): only transpose when the last axis is not already xyz.
    # The previous check transposed exactly the inputs that were already (N, 3).
    if coords.shape[-1] != 3:
        coords = coords.T
    # np.newaxis is a constant, not a function; expand_dims adds the leading axis.
    coords = np.expand_dims(coords, axis=0)
    # get pdb id and chain num
    pdb_name, chain_num = proteinnet_id.split("#")[-1].split("_")[:-1]
    pdb_destin = "/".join(route.split("/")[:-1])+"/"+pdb_name+".pdb"
    # download pdb file and select the appropriate chain; the chain number is
    # parsed as text, so convert it before it is compared with integer indices
    download_pdb(pdb_name, pdb_destin)
    clean_pdb(pdb_destin, chain_num=int(chain_num))
    # load the template trajectory and replace its coords - assumes same order
    scaffold = mdtraj.load_pdb(pdb_destin)
    scaffold.xyz = coords
    scaffold.save(route)
    return pdb_destin, route
# 将坐标转换为PDB文件
def coords2pdb(seq, coords, cloud_mask, prefix="", name="af2_struct.pdb"):
    """ Turns coordinates into PDB files ready to be visualized.
        Inputs:
        * seq: (L,) tensor of ints (sidechainnet aa-key pairs)
        * coords: coords of the occupied atoms; assigned into the positions
                  selected by cloud_mask.
                  NOTE(review): the original docstring says (3, N), but the
                  masked assignment below looks like it needs (N, 3) - confirm.
        * cloud_mask: (L, C) boolean mask of occupied spaces in scn format
        * prefix: str. directory to save files.
        * name: str. name of destin file (ex: pred1.pdb)
    """
    # `scn` was never imported in this module; bring it in locally.
    import sidechainnet as scn

    # torch.zeros takes the sizes as separate ints (or one sequence), so the
    # mask shape must be unpacked before appending the xyz axis.
    scaffold = torch.zeros( *cloud_mask.shape, 3 )
    scaffold[cloud_mask] = coords.cpu().float()
    # build the structure and save it
    pred = scn.StructureBuilder( seq, crd=scaffold )
    pred.to_pdb(prefix+name)
# 定义函数,用于移除序列中的插入部分,以便在MSA中加载对齐的序列
def remove_insertions(sequence: str) -> str:
    """Strip insertions from an aligned sequence.

    Removes every lowercase ASCII letter plus the characters "." and "*".
    Needed to load aligned sequences in an MSA.
    """
    dropped_chars = string.ascii_lowercase + ".*"
    # The third argument of maketrans lists characters to delete.
    return sequence.translate(str.maketrans("", "", dropped_chars))
# 从MSA文件中读取前nseq个序列,自动移除插入部分
def read_msa(filename: str, nseq: int):
    """Read the first ``nseq`` records of a FASTA MSA file.

    Returns a list of ``(description, sequence)`` pairs with insertions
    removed from every sequence.
    """
    first_records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    entries = []
    for record in first_records:
        cleaned = remove_insertions(str(record.seq))
        entries.append((record.description, cleaned))
    return entries
# 将氨基酸id转换为用于计算ESM和MSA变换器嵌入的氨基酸字符串输入
def ids_to_embed_input(x):
    """Convert nested lists of amino-acid ids into embedding-model input.

    Inputs:
    * x: any deeply nested list of integers that correspond with amino acid id

    A list made only of ints becomes ``(None, "<residue string>")``; a list of
    lists is converted element by element.
    """
    assert isinstance(x, list), 'input must be a list'
    id2aa = VOCAB._int2char
    converted = []

    for item in x:
        if isinstance(item, list):
            converted.append(ids_to_embed_input(item))
            continue
        if not isinstance(item, int):
            raise TypeError('type must be either list or character')
        converted.append(id2aa[item])

    # Only characters collected: this level is a single sequence.
    if all(isinstance(entry, str) for entry in converted):
        return (None, ''.join(converted))

    return converted
# Convert amino-acid ids into the space-separated residue strings used as ProtTrans model input
def ids_to_prottran_input(x):
    """Convert lists of amino-acid ids into space-separated residue strings.

    Inputs:
    * x: list of sequences, each an iterable of amino acid ids

    Each residue letter is separated by a space, and the letters U, Z, O and B
    are replaced with X.
    """
    assert isinstance(x, list), 'input must be a list'
    id2aa = VOCAB._int2char
    return [
        re.sub(r"[UZOB]", "X", ' '.join(id2aa[i] for i in ids))
        for ids in x
    ]
# 获取ProtTrans嵌入
def get_prottran_embedd(seq, model, tokenizer, device = None):
    """Return embeddings for ``seq`` from a transformers feature-extraction pipeline.

    The pipeline runs on ``device.index`` when a device is given and on -1
    otherwise. Positions ``1 .. seq.shape[1]`` of the result are returned.
    """
    from transformers import pipeline

    device_index = device.index if exists(device) else -1
    extractor = pipeline('feature-extraction', model = model, tokenizer = tokenizer, device = device_index)

    max_seq_len = seq.shape[1]
    spaced_seqs = ids_to_prottran_input(seq.cpu().tolist())

    features = torch.tensor(extractor(spaced_seqs), device = device)
    # Position 0 is skipped.
    return features[:, 1:(max_seq_len + 1)]
# 获取MSA嵌入
def get_msa_embedd(msa, embedd_model, batch_converter, device = None):
    """ Returns the MSA_tr embeddings for a protein.
        Inputs:
        * msa: tensor of ints (in sidechainnet int-char convention);
               the last axis is the sequence length
        * embedd_model: MSA_tr model (see train_end2end.py for an example)
        * batch_converter: MSA_tr batch converter (see train_end2end.py for an example)
        * device: device to run the model on. defaults to the device of `msa`
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA
            * embedd_dim: number of embedding dimensions. 768 for MSA_Transformer
    """
    # use MSA transformer
    REPR_LAYER_NUM = 12
    # `seq` does not exist in this function; use the passed device or fall back to msa's.
    device = msa.device if device is None else device
    max_seq_len = msa.shape[-1]
    embedd_inputs = ids_to_embed_input(msa.cpu().tolist())

    msa_batch_labels, msa_batch_strs, msa_batch_tokens = batch_converter(embedd_inputs)
    with torch.no_grad():
        results = embedd_model(msa_batch_tokens.to(device), repr_layers=[REPR_LAYER_NUM], return_contacts=False)
    # index 0 is for start token. so take from 1 on
    token_reps = results["representations"][REPR_LAYER_NUM][..., 1:max_seq_len+1, :]
    return token_reps
# 获取ESM嵌入
def get_esm_embedd(seq, embedd_model, batch_converter, msa_data=None):
    """ Returns the ESM embeddings for a protein.
        Inputs:
        * seq: ( (b,) L,) tensor of ints (in sidechainnet int-char convention)
        * embedd_model: ESM model (see train_end2end.py for an example)
        * batch_converter: ESM batch converter (see train_end2end.py for an example)
        * msa_data: accepted but not used here
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA. 1 for ESM-1b
            * embedd_dim: number of embedding dimensions. 1280 for ESM-1b
    """
    # Layer whose representations are returned
    REPR_LAYER_NUM = 33
    device = seq.device
    max_seq_len = seq.shape[-1]

    _, _, batch_tokens = batch_converter(ids_to_embed_input(seq.cpu().tolist()))
    with torch.no_grad():
        results = embedd_model(batch_tokens.to(device), repr_layers=[REPR_LAYER_NUM], return_contacts=False)

    layer_reps = results["representations"][REPR_LAYER_NUM]
    # Skip the start token at index 0, then add the n_seqs axis
    return layer_reps[..., 1:max_seq_len+1, :].unsqueeze(dim=1)
# 返回给定蛋白质的ProtT5-XL-U50嵌入
def get_t5_embedd(seq, tokenizer, encoder, msa_data=None, device=None):
    """ Returns the ProtT5-XL-U50 embeddings for a protein.
        Inputs:
        * seq: ( (b,) L,) tensor of ints (in sidechainnet int-char convention)
        * tokenizer:  tokenizer model: T5Tokenizer
        * encoder: encoder model: T5EncoderModel
                 ex: from transformers import T5EncoderModel, T5Tokenizer
                     model_name = "Rostlab/prot_t5_xl_uniref50"
                     tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False )
                     model = T5EncoderModel.from_pretrained(model_name)
                     # prepare model
                     model = model.to(device)
                     model = model.eval()
                     if torch.cuda.is_available():
                         model = model.half()
        * msa_data: accepted but not used here
        * device: device to run on. defaults to the device of `seq`
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA. 1 for T5 models
            * embedd_dim: number of embedding dimensions. 1024 for T5 models
    """
    if device is None:
        device = seq.device
    spaced_seqs = ids_to_prottran_input(seq.cpu().tolist())

    # embedd - https://huggingface.co/Rostlab/prot_t5_xl_uniref50
    # Window kept along the token axis: everything but the final position.
    shift_left, shift_right = 0, -1
    encoded = tokenizer.batch_encode_plus(spaced_seqs, add_special_tokens=True,
                                          padding=True,
                                          return_tensors="pt")
    with torch.no_grad():
        input_ids = torch.tensor(encoded['input_ids']).to(device)
        attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
        embedding = encoder(input_ids=input_ids, attention_mask=attention_mask)

    # (batch, seq_len, embedd_dim), then padded with leading axes up to 4 dims
    token_reps = embedding.last_hidden_state[:, shift_left:shift_right].to(device)
    token_reps = expand_dims_to(token_reps, 4-len(token_reps.shape))
    return token_reps.float()
# 获取所有蛋白质的ID
def get_all_protein_ids(dataloader, verbose=False):
    """ Given a sidechainnet dataloader for a CASP version,
        Returns all the ids belonging to proteins.
        Inputs:
        * dataloader: mapping whose 'train' entry yields batches exposing
                      `int_seqs` (with a `.shape`) and `pids`
        * verbose: bool. print the ids that get skipped
        Outputs: a set containing the ids for all protein entries.
    """
    ids = set()
    # Iterate the argument itself: the body previously read an undefined
    # global `dataloaders` and wrapped it in `tqdm`, which is not imported.
    for batch in dataloader['train']:
        for idx in range(batch.int_seqs.shape[0]):
            pid = batch.pids[idx]
            # keep ids shaped like 4_LETTER_PDB + NUM + CHAIN:
            # shorter than 10 characters and every "_"-separated fragment <= 4 long
            max_len_10 = len(pid) < 10
            fragments = [len(x) <= 4 for x in pid.split("_")]
            fragments_under_4 = all(fragments)
            if max_len_10 and fragments_under_4:
                ids.add(pid)
            elif verbose:
                print("skip:", pid, "under 4", fragments)
    return ids
# 获取SCN序列的布尔掩码原子位置(不是所有氨基酸都具有相同的原子)
def scn_cloud_mask(scn_seq, boolean=True, coords=None):
    """ Gets the boolean mask atom positions (not all aas have same atoms).
        Inputs:
        * scn_seq: (batch, length) sequence as provided by Sidechainnet package
        * boolean: whether to return as array of idxs or boolean values
        * coords: optional .(batch, lc, 3). sidechainnet coords.
                  returns the true mask (solves potential atoms that might not be provided)
        Outputs: (batch, length, NUM_COORDS_PER_RES) boolean mask
    """
    def _as_requested(mask):
        # Boolean tensor, or the indices of its non-zero entries.
        return mask.bool() if boolean else mask.nonzero()

    scn_seq = expand_dims_to(scn_seq, 2 - len(scn_seq.shape))

    # With coords available the mask comes straight from them: a slot counts
    # as occupied unless all of its coordinates are exactly 0.
    if coords is not None:
        per_atom = rearrange(coords, '... (l c) d -> ... l c d', c=constants.NUM_COORDS_PER_RES)
        zero_count = (per_atom == 0).sum(dim=-1)
        return _as_requested(zero_count < coords.shape[-1])

    # Otherwise look the mask up residue by residue (loop runs on CPU lists).
    device = scn_seq.device
    per_protein = []
    for seq in scn_seq.cpu().tolist():
        residue_masks = [CUSTOM_INFO[VOCAB._int2char[aa]]['cloud_mask'] for aa in seq]
        per_protein.append(torch.tensor(residue_masks).bool().to(device))

    # Stack the proteins along a new leading batch axis.
    return _as_requested(torch.stack(per_protein, dim=0))
def scn_backbone_mask(scn_seq, boolean=True, n_aa=3):
    """ Gets the boolean mask for N, CA and C positions.
        Inputs:
        * scn_seq: sequence(s) as provided by Sidechainnet package (int tensor/s)
        * n_aa: number of atoms in a backbone. (may include cbeta as 4th pos)
        * boolean: whether to return as array of idxs or boolean values
        Outputs: (N_mask, CA_mask, C_mask)
    """
    backbone_labels = (1, 2, 3)
    # One slot per backbone atom of every residue; the first three slots are
    # labelled 1, 2 and 3.
    labels = torch.zeros(*scn_seq.shape, n_aa).to(scn_seq.device)
    for slot, label in enumerate(backbone_labels):
        labels[..., slot] = label
    # Flatten residues and slots into a single atom axis.
    labels = rearrange(labels, '... l c -> ... (l c)')

    N_mask, CA_mask, C_mask = (labels == label for label in backbone_labels)
    if not boolean:
        return torch.nonzero(N_mask), torch.nonzero(CA_mask), torch.nonzero(C_mask)
    return N_mask, CA_mask, C_mask
def scn_atom_embedd(scn_seq):
    """ Returns the token for each atom in the aa.
        Inputs:
        * scn_seq: sequence(s) as provided by Sidechainnet package (int tensor/s)
    """
    device = scn_seq.device
    # The lookup loop runs over plain Python lists on CPU.
    per_protein = [
        torch.tensor([CUSTOM_INFO[VOCAB.int2char(aa)]["atom_id_embedd"] for aa in seq])
        for seq in scn_seq.cpu().tolist()
    ]
    return torch.stack(per_protein, dim=0).long().to(device)
def mat_input_to_masked(x, x_mask=None, edges_mat=None, edges=None,
                        edge_mask=None, edge_attr_mat=None,
                        edge_attr=None):
    """ Turns the padded input and edges + mask into the
        non-padded inputs and edges.
        At least one of (edges_mat, edges) must be provided.
        The same format for edges and edge_attr must be provided
        (either adj matrix form or flattened form).
        Inputs:
        * x: ((batch), N, D) a tensor of N nodes and D dims for each one
        * x_mask: ((batch), N,) boolean mask for x. all nodes are kept if omitted
        * edges: (2, E) optional. indices of the corresponding adjancecy matrix.
        * edges_mat: ((batch), N, N) optional. adjacency matrix for x
        * edge_mask: optional. boolean mask of the same shape of either "edge_mat" or "edges".
        * edge_attr: (E, D_edge) optional. edge attributes of D_edge dims.
        * edge_attr_mat: ((batch), N, N) optional. adjacency matrix with features
        Outputs:
        * x: (N_, D) the masked node features
        * edge_index: (2, E_) the masked x-indices for the edges
        * edge_attr: (E_, D_edge) the masked edge attributes
        * batch: (N_,) the corresponding index in the batch for each node

        NOTE(review): edge_index is rebuilt from a dense matrix after node
        masking while edge_attr is only filtered by edge_mask, so the two can
        differ in length and order when nodes are removed - confirm callers.
    """
    # collapse batch dimension
    if len(x.shape) == 3:
        nodes_per_item = x.shape[1]
        # collapse for x and its mask
        x = rearrange(x, 'b n d ... -> (b n) d ...')
        if x_mask is not None:
            x_mask = rearrange(x_mask, 'b n ... -> (b n) ...')
        else:
            x_mask = torch.ones_like(x[..., 0]).bool()

        # collapse for edge indexes and attributes if needed
        if edges_mat is not None and edges is None:
            edges = torch.nonzero(edges_mat, as_tuple=False).t()
            # offset node indices by the position of their item in the batch
            edges = edges[1:] + edges[:1]*nodes_per_item
        # get the batch identifier for each node
        batch = (torch.arange(x.shape[0], device=x.device) // nodes_per_item)[x_mask]
    else:
        # a missing mask keeps every node (previously only handled for 3-D input)
        if x_mask is None:
            x_mask = torch.ones_like(x[..., 0]).bool()
        # edges to indices format
        if edges_mat is not None and edges is None:
            edges = torch.nonzero(edges_mat, as_tuple=False).t()
        # get the batch identifier for each kept node (single item -> all zeros)
        batch = torch.zeros(x.shape[0], device=x.device).to(x.device)[x_mask]

    # adapt edge attrs if provided: read them from the attribute matrix
    # (the previous code indexed `edge_attr`, which is None in this branch)
    if edge_attr_mat is not None and edge_attr is None:
        edge_attr = edge_attr_mat[edges_mat.bool()]
    # gen edge_mask if not provided
    if edge_mask is None:
        edge_mask = torch.ones_like(edges[-1]).bool()

    # begin applying masks
    x = x[x_mask]
    # process edge indexes: build a square matrix and drop the masked-out nodes.
    # Size it by the number of nodes, not by the highest edge index, so that
    # trailing nodes without edges do not cause a shape mismatch with x_mask.
    num_nodes = x_mask.shape[0]
    wrapper = torch.zeros(num_nodes, num_nodes).to(x.device)
    wrapper[edges[0][edge_mask], edges[1][edge_mask]] = 1
    wrapper = wrapper[x_mask, :][:, x_mask]
    edge_index = torch.nonzero(wrapper, as_tuple=False).t()
    # process edge attributes
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
    return x, edge_index, edge_attr, batch
def nth_deg_adjacency(adj_mat, n=1, sparse=False):
    """ Calculates the n-th degree adjacency matrix.
        Performs mm of adj_mat and adds the newly added.
        Default is dense. Mods for sparse version are done when needed.
        Inputs:
        * adj_mat: (N, N) adjacency tensor
        * n: int. degree of the output adjacency
        * sparse: bool. whether to use torch-sparse module
        Outputs:
        * new_adj_mat: in dense mode, the 0/1 pattern of adj_mat multiplied by
                       itself n-1 times. the sparse branch never updates it
        * attr_mat: degree of connectivity (1 for neighs, 2 for neighs^2, ... );
                    each pair keeps the lowest degree at which it was reached
    """
    adj_mat = adj_mat.float()
    attr_mat = torch.zeros_like(adj_mat)
    new_adj_mat = adj_mat.clone()

    for i in range(n):
        if i == 0:
            attr_mat += adj_mat
            continue

        degree = i + 1
        if sparse:
            if i == 1:
                # initialise the sparse representation on first use
                idxs = adj_mat.nonzero().t()
                vals = adj_mat[idxs[0], idxs[1]]
                new_idxs = idxs.clone()
                new_vals = vals.clone()
                # (m, n) * (n, k), but adj_mats are squared: m=n=k.
                # Kept in its own name so the `n` argument is not shadowed.
                size = adj_mat.shape[0]

            new_idxs, new_vals = torch_sparse.spspmm(new_idxs, new_vals, idxs, vals, m=size, k=size, n=size)
            new_vals = new_vals.bool().float()
            # fill by indexes bc it's faster in sparse mode. Entries that
            # already hold a degree keep it; only unset entries get `degree`
            # (the previous assignment reset already-labelled entries to 0).
            current = attr_mat[new_idxs[0], new_idxs[1]]
            unset = 1 - current.bool().float()
            attr_mat[new_idxs[0], new_idxs[1]] = current + unset * degree
        else:
            new_adj_mat = (new_adj_mat @ adj_mat).bool().float()
            # Label only pairs reached now that have no degree yet. The call
            # must be in-place: the out-of-place masked_fill result was
            # discarded, so degrees above 1 were never recorded.
            newly_reached = new_adj_mat.bool() & ~attr_mat.bool()
            attr_mat.masked_fill_(newly_reached, degree)

    return new_adj_mat, attr_mat
def prot_covalent_bond(seqs, adj_degree=1, cloud_mask=None, mat=True, sparse=False):
    """ Returns the idxs of covalent bonds for a protein.
        Inputs
        * seqs: (b, n) torch long.
        * adj_degree: int. adjacency degree
        * cloud_mask: mask selecting the present atoms. (not used by this implementation)
        * mat: whether to return as indexes of only atoms (PyG version)
               or matrices of masked atoms (for batched training).
               for indexes, only 1 seq is supported.
        * sparse: bool. whether to use torch_sparse for adj_mat calc
        Outputs: edge_idxs, edge_types (degree of adjacency).
    """
    device = seqs.device
    # set up container adj_mat: one square matrix per sequence with
    # NUM_COORDS_PER_RES slots per residue
    adj_mat = torch.zeros(seqs.shape[0], *[seqs.shape[1]*NUM_COORDS_PER_RES]*2)
    # not needed to device since it's only for indices
    seq_list = seqs.cpu().tolist()
    for s, seq in enumerate(seq_list):
        next_idx = 0
        for i, idx in enumerate(seq):
            aa_bonds = constants.AA_DATA[VOCAB._int2char[idx]]['bonds']
            # if no edges -> padding token -> finish bond creation for this seq
            if len(aa_bonds) == 0:
                break
            # offset of the next residue: second index of the bond with the largest index
            next_aa = max(aa_bonds, key=lambda x: max(x))[-1]
            # offset by pos in chain ( intra-aa bonds + with next aa )
            bonds = next_idx + torch.tensor( aa_bonds + [[2, next_aa]] ).t()
            next_idx += next_aa
            # delete link with next if final AA in seq
            if i == seqs.shape[1] - 1:
                bonds = bonds[:, :-1]
            # modify adj mat
            adj_mat[s, bonds[0], bonds[1]] = 1
        # convert to undirected
        adj_mat[s] = adj_mat[s] + adj_mat[s].t()

    # do N_th degree adjacency once, after every sequence has been filled in.
    # Running it inside the per-sequence loop would re-expand sequences that
    # were already processed and leave attr_mat undefined for an empty batch.
    adj_mat, attr_mat = nth_deg_adjacency(adj_mat, n=adj_degree, sparse=sparse)

    if mat:
        # return the full matrix/tensor
        return attr_mat.bool().to(seqs.device), attr_mat.to(device)
    else:
        edge_idxs = attr_mat[0].nonzero().t().long()
        edge_types = attr_mat[0, edge_idxs[0], edge_idxs[1]]
        return edge_idxs.to(seqs.device), edge_types.to(seqs.device)
def sidechain_container(seqs, backbones, atom_mask, cloud_mask=None, padding_tok=20):
    """ Gets a backbone of the protein, returns the whole coordinates
        with sidechains (same format as sidechainnet). Keeps differentiability.
        Inputs:
        * seqs: (batch, L) tensor, or an iterable of strings / int sequences
        * backbones: (batch, L*n_aa, 3): assume batch=1 (could be extended (?not tested)).
                     Coords for (N-term, C-alpha, C-term, (c_beta)) of every aa.
        * atom_mask: (14,). int or bool tensor specifying which atoms are passed.
        * cloud_mask: (batch, l, c). optional. cloud mask from `scn_cloud_mask`.
                      points outside of the mask are set to 0 when passed.
        * padding_tok: int. padding token. same as in sidechainnet: 20
        Outputs: whole coordinates of shape (batch, L, 14, 3)
    """
    # boolean mask on CPU, detached: it is only used for python-side indexing
    atom_mask = atom_mask.bool().cpu().detach()
    # running count of passed atoms: maps an atom slot to its position in `backbones`
    cum_atom_mask = atom_mask.cumsum(dim=-1).tolist()
    device = backbones.device
    batch, length = backbones.shape[0], backbones.shape[1] // cum_atom_mask[-1]
    predicted = rearrange(backbones, 'b (l back) d -> b l back d', l=length)
    # early check if whole chain is already pred
    if cum_atom_mask[-1] == 14:
        return predicted
    # build scaffold from (N, CA, C, CB) - do in cpu
    new_coords = torch.zeros(batch, length, constants.NUM_COORDS_PER_RES, 3)
    predicted = predicted.cpu() if predicted.is_cuda else predicted
    atom_flags = atom_mask.tolist()
    # fill atoms that were passed
    for i, atom in enumerate(atom_flags):
        if atom:
            new_coords[:, :, i] = predicted[:, :, cum_atom_mask[i]-1]
    # generate sidechains for the atoms that were not passed
    for s, seq in enumerate(seqs):
        # format seq accordingly
        if isinstance(seq, str):
            padding = 0
            seq_str = seq
        else:
            # tensors and plain int sequences share one path
            seq = torch.as_tensor(seq)
            padding = (seq == padding_tok).sum().item()
            seq_str = ''.join([VOCAB._int2char[aa] for aa in seq.cpu().numpy()[:-padding or None]])
        # get scaffolds
        scaffolds = mp_nerf.proteins.build_scaffolds_from_scn_angles(seq_str, angles=None, device="cpu")
        coords, _ = mp_nerf.proteins.sidechain_fold(wrapper = new_coords[s, :-padding or None].detach(),
                                                    **scaffolds, c_beta = cum_atom_mask[4]==5)
        # add detached scn - only into the sequence the fold was computed for
        for i, atom in enumerate(atom_flags):
            if not atom:
                new_coords[s, :-padding or None, i] = coords[:, i]
    new_coords = new_coords.to(device)
    if cloud_mask is not None:
        new_coords[torch.logical_not(cloud_mask)] = 0.
    # replace any nan-s with another atom of the same residue
    # NOTE(review): the replacement atom index is (atom_idx + 1) modulo the size
    # of the xyz axis (3), not of the atom axis - confirm this is intended.
    nan_mask = list(torch.nonzero(new_coords!=new_coords, as_tuple=True))
    new_coords[nan_mask[0], nan_mask[1], nan_mask[2]] = new_coords[nan_mask[0],
                                                                   nan_mask[1],
                                                                   (nan_mask[-2]+1) % new_coords.shape[-1]]
    return new_coords.to(device)
# 距离工具(距离直方图到距离矩阵 + 掩码)
def center_distogram_torch(distogram, bins=DISTANCE_THRESHOLDS, min_t=1., center="mean", wide="std"):
    """ Returns the central estimate of a distogram and a confidence weight.
        Inputs:
        * distogram: (batch, N, N, B) where B is the number of buckets.
        * bins: (B,) containing the cutoffs for the different buckets
        * min_t: float. lower bound for distances (currently not read).
        * center: "mean" or "median". how the central estimate is computed.
        * wide: "std" or "var". dispersion measure; any other value gives
                a dispersion of zero.
        Outputs:
        * central: (batch, N, N)
        * weights: (batch, N, N)
        Raises: ValueError if `center` is not "mean" or "median".
    """
    shape, device = distogram.shape, distogram.device
    # threshold to weights and find mean value of each bin
    n_bins = ( bins - 0.5 * (bins[2] - bins[1]) ).to(device)
    n_bins[0] = 1.5
    n_bins[-1] = 1.33*bins[-1] # above last threshold is ignored
    # highest valid bin index: searchsorted may return B (one past the end)
    max_bin_allowed = torch.tensor(n_bins.shape[0] - 1, device=device).long()
    # calculate measures of centrality and dispersion -
    magnitudes = distogram.sum(dim=-1)
    if center == "median":
        cum_dist = torch.cumsum(distogram, dim=-1)
        medium = 0.5 * cum_dist[..., -1:]
        central = torch.searchsorted(cum_dist, medium).squeeze()
        central = n_bins[torch.min(central, max_bin_allowed)]
    elif center == "mean":
        central = (distogram * n_bins).sum(dim=-1) / magnitudes
    else:
        raise ValueError('center must be "mean" or "median", got %r' % (center,))
    # create mask for last class - (IGNORE_INDEX)
    mask = (central <= bins[-2].item()).float()
    # mask diagonal to 0 dist - don't do masked filling to avoid inplace errors
    diag_idxs = np.arange(shape[-2])
    central = expand_dims_to(central, 3 - len(central.shape))
    central[:, diag_idxs, diag_idxs] *= 0.
    # provide weights
    if wide == "var":
        dispersion = (distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1) / magnitudes
    elif wide == "std":
        dispersion = ((distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1) / magnitudes).sqrt()
    else:
        dispersion = torch.zeros_like(central, device=device)
    # rescale to 0-1. lower std / var --> weight=1. set potential nan's to 0
    weights = mask / (1 + dispersion)
    weights[weights != weights] *= 0.
    weights[:, diag_idxs, diag_idxs] *= 0.
    return central, weights
# 将距离矩阵转换为三维坐标
def mds_torch(pre_dist_mat, weights=None, iters=10, tol=1e-5, eigen=False, verbose=2):
    """ Gets distance matrix. Outputs 3d. See below for wrapper.
        Assumes (for now) distogram is (N x N) and symmetric
        Inputs:
        * pre_dist_mat: (N x N) or (batch x N x N) target distances
        * weights: optional pairwise weights, same shape as the distances
        * iters: maximum number of Guttman-transform updates
        * tol: stop once the mean improvement of the normalised stress
               is <= tol
        * eigen: return the eigendecomposition initialisation directly
                 (only honoured when no weights are given)
        * verbose: 0 silent, 1 prints fallback / stop messages, >=2 also
                   prints the stress of every iteration
        Outs:
        * best_3d_coords: (batch x 3 x N)
        * historic_stresses: (steps x batch)
    """
    device = pre_dist_mat.device
    # ensure batched MDS
    pre_dist_mat = expand_dims_to(pre_dist_mat, length=(3 - len(pre_dist_mat.shape)))
    # start
    batch, N, _ = pre_dist_mat.shape
    diag_idxs = np.arange(N)
    his = [torch.tensor([np.inf]*batch, device=device)]
    # initialize by eigendecomposition: https://www.lptmc.jussieu.fr/user/lesne/bioinformatics.pdf
    # follow : https://www.biorxiv.org/content/10.1101/2020.11.27.401232v1.full.pdf
    D = pre_dist_mat**2
    M = 0.5 * (D[:, :1, :] + D[:, :, :1] - D)
    # do loop svd bc it's faster: (2-3x in CPU and 1-2x in GPU)
    # https://discuss.pytorch.org/t/batched-svd-lowrank-being-much-slower-than-loop-implementation-both-cpu-and-gpu/119336
    svds = [torch.svd_lowrank(mi) for mi in M]
    u = torch.stack([svd[0] for svd in svds], dim=0)
    s = torch.stack([svd[1] for svd in svds], dim=0)
    best_3d_coords = torch.bmm(u, torch.diag_embed(s).abs().sqrt())[..., :3]
    # only eigen - way faster but not weights
    if weights is None and eigen==True:
        return torch.transpose(best_3d_coords, -1, -2), torch.zeros_like(torch.stack(his, dim=0))
    elif eigen==True:
        if verbose:
            print("如果激活权重,则无法使用特征分解标志。回退到迭代方式")
    # continue the iterative way
    if weights is None:
        weights = torch.ones_like(pre_dist_mat)
    # iterative updates:
    for i in range(iters):
        # compute distance matrix of coords and stress
        best_3d_coords = best_3d_coords.contiguous()
        dist_mat = torch.cdist(best_3d_coords, best_3d_coords, p=2).clone()
        stress = (weights * (dist_mat - pre_dist_mat)**2).sum(dim=(-1,-2)) * 0.5
        # perturb - update X using the Guttman transform - sklearn-like
        dist_mat[dist_mat <= 0] += 1e-7
        ratio = weights * (pre_dist_mat / dist_mat)
        B = -ratio
        B[:, diag_idxs, diag_idxs] += ratio.sum(dim=-1)
        # update
        coords = (1. / N * torch.matmul(B, best_3d_coords))
        dis = torch.norm(coords, dim=(-1, -2))
        if verbose >= 2:
            # log text had undecodable bytes in place of one word; restored
            print('迭代次数:%d,应力 %s' % (i, stress))
        # update metrics if relative improvement above tolerance
        if (his[-1] - stress / dis).mean() <= tol:
            if verbose:
                print('在迭代 %d 中以应力 %s 结束' % (i, stress / dis))
            break
        best_3d_coords = coords
        his.append(stress / dis)
    return torch.transpose(best_3d_coords, -1, -2), torch.stack(his, dim=0)
def mds_numpy(pre_dist_mat, weights=None, iters=10, tol=1e-5, eigen=False, verbose=2):
    """ Gets distance matrix. Outputs 3d. See below for wrapper.
        Assumes (for now) distogram is (N x N) and symmetric
        Inputs:
        * pre_dist_mat: (N x N) or (batch x N x N) target distances
        * weights: optional pairwise weights (defaults to all ones)
        * iters: maximum number of Guttman-transform updates
        * tol: stop once the mean improvement of the normalised stress
               is <= tol
        * eigen: accepted for signature parity with `mds_torch`; not read here
        * verbose: 0 silent, 1 prints the stop message, >=2 also prints
                   the stress of every iteration
        Out:
        * best_3d_coords: (batch x 3 x N)
        * historic_stress: (steps x batch)
    """
    if weights is None:
        weights = np.ones_like(pre_dist_mat)
    # ensure batched MDS
    pre_dist_mat = expand_dims_to(pre_dist_mat, length=(3 - len(pre_dist_mat.shape)))
    # start
    batch, N, _ = pre_dist_mat.shape
    # seed the history with one value per batch entry: mixing a bare scalar
    # with (batch,) arrays makes np.array() fail on ragged input
    his = [np.full(batch, np.inf)]
    # init random coords
    best_stress = np.inf * np.ones(batch)
    best_3d_coords = 2*np.random.rand(batch, 3, N) - 1
    diag_idxs = np.arange(N)
    # iterative updates:
    for i in range(iters):
        # compute distance matrix of coords and stress
        dist_mat = np.linalg.norm(best_3d_coords[:, :, :, None] - best_3d_coords[:, :, None, :], axis=-3)
        # NOTE(review): weights are squared here, unlike mds_torch where only
        # the residual is squared - confirm which is intended.
        stress = (( weights * (dist_mat - pre_dist_mat) )**2).sum(axis=(-1, -2)) * 0.5
        # perturb - update X using the Guttman transform - sklearn-like
        dist_mat[dist_mat == 0] = 1e-7
        ratio = weights * (pre_dist_mat / dist_mat)
        B = -ratio
        B[:, diag_idxs, diag_idxs] += ratio.sum(axis=-1)
        # update - double transpose. TODO: consider fix
        coords = (1. / N * np.matmul(best_3d_coords, B))
        dis = np.linalg.norm(coords, axis=(-1, -2))
        if verbose >= 2:
            print('it: %d, stress %s' % (i, stress))
        # update metrics if relative improvement above tolerance
        if (best_stress - stress / dis).mean() <= tol:
            if verbose:
                print('breaking at iteration %d with stress %s' % (i,
                                                                   stress / dis))
            break
        best_3d_coords = coords
        best_stress = stress / dis
        his.append(best_stress)
    return best_3d_coords, np.array(his)
# 定义一个函数,用于计算四个坐标点之间的二面角(dihedral angle)并返回结果,使用 torch 库
def get_dihedral_torch(c1, c2, c3, c4):
    """ Returns the dihedral angle in radians.
        Will use atan2 formula from:
        https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
        Can't use torch.dot bc it does not broadcast
        Inputs:
        * c1: (batch, 3) or (3,)
        * c2: (batch, 3) or (3,)
        * c3: (batch, 3) or (3,)
        * c4: (batch, 3) or (3,)
    """
    # the three consecutive bond vectors
    b1 = c2 - c1
    b2 = c3 - c2
    b3 = c4 - c3
    # normals of the planes (b1, b2) and (b2, b3)
    n12 = torch.cross(b1, b2, dim=-1)
    n23 = torch.cross(b2, b3, dim=-1)
    b2_len = torch.norm(b2, dim=-1, keepdim=True)
    # dot products written as elementwise product + sum so they broadcast
    sin_term = ((b2_len * b1) * n23).sum(dim=-1)
    cos_term = (n12 * n23).sum(dim=-1)
    return torch.atan2(sin_term, cos_term)
# 定义一个函数,用于计算四个坐标点之间的二面角(dihedral angle)并返回结果,使用 numpy 库
def get_dihedral_numpy(c1, c2, c3, c4):
    """ Returns the dihedral angle in radians.
        Will use atan2 formula from:
        https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
        Inputs:
        * c1: (batch, 3) or (3,)
        * c2: (batch, 3) or (3,)
        * c3: (batch, 3) or (3,)
        * c4: (batch, 3) or (3,)
    """
    # the three consecutive bond vectors
    b1 = c2 - c1
    b2 = c3 - c2
    b3 = c4 - c3
    # normals of the planes (b1, b2) and (b2, b3)
    n12 = np.cross(b1, b2, axis=-1)
    n23 = np.cross(b2, b3, axis=-1)
    b2_len = np.linalg.norm(b2, axis=-1, keepdims=True)
    sin_term = ((b2_len * b1) * n23).sum(axis=-1)
    cos_term = (n12 * n23).sum(axis=-1)
    return np.arctan2(sin_term, cos_term)
# 定义一个函数,用于计算蛋白质的 phi 角度,选择具有最多负 phi 角度的镜像
def calc_phis_torch(pred_coords, N_mask, CA_mask, C_mask=None,
                    prop=True, verbose=0):
    """ Computes the backbone phi angles used to pick the right mirror image.
        Used as part of the MDScaling wrapper if arg is passed. See below.
        Each phi is the dihedral of (Cterm{i-1}, N{i}, Ca{i}, Cterm{i}).
        Inputs:
        * pred_coords: (batch, 3, N) predicted coordinates
        * N_mask: (batch, N) boolean mask for N-term positions
        * CA_mask: (batch, N) boolean mask for C-alpha positions
        * C_mask: (batch, N) or None. boolean mask for C-term positions;
                  when None, every position that is neither N nor C-alpha.
        * prop: bool. whether to return the proportion of negative phis.
        * verbose: verbosity level (not read here)
        Output: list of per-structure phi tensors, or (batch,) tensor with
                the proportions when prop is True.
        Note: uses the masks of the first batch element, since all prots in
              the batch are assumed to share the same backbone layout.
    """
    # angles are computed on a detached CPU copy laid out as (batch, N, 3)
    coords = torch.transpose(pred_coords.detach(), -1, -2).cpu()
    # make sure every mask carries a batch dimension
    N_mask = expand_dims_to(N_mask, 2 - len(N_mask.shape))
    CA_mask = expand_dims_to(CA_mask, 2 - len(CA_mask.shape))
    if C_mask is None:
        C_mask = torch.logical_not(torch.logical_or(N_mask, CA_mask))
    else:
        C_mask = expand_dims_to(C_mask, 2 - len(C_mask.shape))
    # select points
    n_terms = coords[:, N_mask[0].squeeze()]
    c_alphas = coords[:, CA_mask[0].squeeze()]
    c_terms = coords[:, C_mask[0].squeeze()]
    # one set of phi angles per structure in the batch
    phis = []
    for b in range(pred_coords.shape[0]):
        phis.append(get_dihedral_torch(c_terms[b, :-1],
                                       n_terms[b, 1:],
                                       c_alphas[b, 1:],
                                       c_terms[b, 1:]))
    if not prop:
        return phis
    # fraction of negative angles per structure
    return torch.stack([(x<0).float().mean() for x in phis], dim=0 )
def calc_phis_numpy(pred_coords, N_mask, CA_mask, C_mask=None,
                    prop=True, verbose=0):
    """ Computes the backbone phi angles used to pick the right mirror image.
        Used as part of the MDScaling wrapper if arg is passed. See below.
        Each phi is the dihedral of (Cterm{i-1}, N{i}, Ca{i}, Cterm{i}).
        Inputs:
        * pred_coords: (batch, 3, N) predicted coordinates
        * N_mask: (N, ) boolean mask for N-term positions
        * CA_mask: (N, ) boolean mask for C-alpha positions
        * C_mask: (N, ) or None. boolean mask for C-term positions;
                  when None, every position that is neither N nor C-alpha.
        * prop: bool. whether to return the proportion of negative phis.
        * verbose: verbosity level (not read here)
        Output: list of per-structure phi arrays, or (batch,) array with
                the proportions when prop is True.
    """
    # work on a (batch, N, 3) view of the coordinates
    pred_coords_ = np.transpose(pred_coords, (0, 2, 1))
    n_terms = pred_coords_[:, N_mask.squeeze()]
    c_alphas = pred_coords_[:, CA_mask.squeeze()]
    if C_mask is not None:
        c_terms = pred_coords_[:, C_mask]
    else:
        # select c_term auto if not passed. Arithmetic on boolean masks
        # (ones - N - CA) raises in numpy, so use logical ops like the
        # torch implementation does.
        auto_mask = np.logical_not(np.logical_or(N_mask, CA_mask))
        c_terms = pred_coords_[:, auto_mask.squeeze()]
    # compute phis for every protein in the batch
    phis = [get_dihedral_numpy(c_terms[i, :-1],
                               n_terms[i, 1:],
                               c_alphas[i, 1:],
                               c_terms[i, 1:]) for i in range(pred_coords.shape[0])]
    # return percentage of lower than 0
    if prop:
        return np.array( [(x<0).mean() for x in phis] )
    return phis
# alignment by centering + rotation to compute optimal RMSD
# adapted from : https://github.com/charnley/rmsd/
def kabsch_torch(X, Y, cpu=True):
    """ Kabsch alignment of X into Y.
        Assumes X,Y are both (Dims x N_points). See below for wrapper.
        Inputs:
        * X, Y: (Dims x N_points) tensors
        * cpu: bool. run the SVD on CPU (the rotation is moved back to
               the device of X afterwards)
        Outputs: (X centered and rotated onto Y, Y centered)
    """
    device = X.device
    # center X and Y to the origin
    X_ = X - X.mean(dim=-1, keepdim=True)
    Y_ = Y - Y.mean(dim=-1, keepdim=True)
    # calculate covariance matrix; detached, so no gradient flows through
    # the rotation itself
    C = torch.matmul(X_, Y_.t()).detach()
    if cpu:
        C = C.cpu()
    # Optimal rotation matrix via SVD. Detect the API instead of parsing the
    # minor version number, which misclassifies torch >= 2.0 as "older than 1.8".
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "svd"):
        V, _, W = torch.linalg.svd(C)
    else:
        # legacy torch.svd returns the right factor untransposed
        V, _, W = torch.svd(C)
        W = W.t()
    # determinant sign for direction correction (avoid a reflection)
    d = (torch.det(V) * torch.det(W)) < 0.0
    if d:
        V[:, -1] = V[:, -1] * (-1)
    # Create Rotation matrix U
    U = torch.matmul(V, W).to(device)
    # calculate rotations
    X_ = torch.matmul(X_.t(), U).t()
    # return centered and aligned
    return X_, Y_
def kabsch_numpy(X, Y):
    """ Kabsch alignment of X into Y.
        Assumes X,Y are both (Dims x N_points). See below for wrapper.
        Outputs: (X centered and rotated onto Y, Y centered)
    """
    # move both point clouds to the origin
    X_centered = X - X.mean(axis=-1, keepdims=True)
    Y_centered = Y - Y.mean(axis=-1, keepdims=True)
    # covariance between the two centered clouds
    covariance = np.dot(X_centered, Y_centered.transpose())
    # optimal rotation from the SVD factors
    left, sing, right = np.linalg.svd(covariance)
    # flip the last axis when the factors would describe a reflection
    is_reflection = (np.linalg.det(left) * np.linalg.det(right)) < 0.0
    if is_reflection:
        sing[-1] = sing[-1] * (-1)
        left[:, -1] = left[:, -1] * (-1)
    rotation = np.dot(left, right)
    # apply the rotation to the centered X
    aligned = np.dot(X_centered.T, rotation).T
    return aligned, Y_centered
# metrics - more formulas here: http://predictioncenter.org/casp12/doc/help.html
def distmat_loss_torch(X=None, Y=None, X_mat=None, Y_mat=None, p=2, q=2,
                       custom=None, distmat_mask=None, clamp=None):
    """ Calculates a loss on the distance matrix - no need to align structs.
        Inputs:
        * X: (N, d) tensor. the predicted structure. One of (X, X_mat) is needed.
        * X_mat: (N, N) tensor. the predicted distance matrix. Optional.
        * Y: (N, d) tensor. the true structure. One of (Y, Y_mat) is needed.
        * Y_mat: (N, N) tensor. the true distance matrix. Optional.
        * p: int. power for the distance calculation (2 for euclidean)
        * q: float. power for the scaling of the loss (2 for MSE, 1 for MAE, etc)
        * custom: func or None. custom loss over distance matrices.
                  ex: lambda x,y: 1 - 1/ (1 + ((x-y))**2) (1 is very bad. 0 is good)
                  When given, its mean is returned and `distmat_mask` is not applied.
        * distmat_mask: (N, N) boolean mask selecting the ij positions that
                        enter the mean. optional.
        * clamp: tuple of (min,max) values. Applied to the coordinates before
                 the distance matrix is computed (only when the matrix is not passed).
        Outputs: scalar tensor with the mean loss.
    """
    has_pred = X is not None or X_mat is not None
    has_true = Y is not None or Y_mat is not None
    assert has_pred and has_true, "The true and predicted coords or dist mats must be provided"
    # derive whichever distance matrix was not handed in
    if X_mat is None:
        X = X.squeeze()
        if clamp is not None:
            X = torch.clamp(X, *clamp)
        X_mat = torch.cdist(X, X, p=p)
    if Y_mat is None:
        Y = Y.squeeze()
        if clamp is not None:
            Y = torch.clamp(Y, *clamp)
        Y_mat = torch.cdist(Y, Y, p=p)
    # default: every position counts
    if distmat_mask is None:
        distmat_mask = torch.ones_like(Y_mat).bool()
    # a user-supplied expression takes precedence
    if custom is not None:
        return custom(X_mat.squeeze(), Y_mat.squeeze()).mean()
    # squaring keeps the error positive; rescale to the requested power after
    squared = ( X_mat - Y_mat )**2
    loss = squared if q == 2 else squared**(q/2)
    return loss[distmat_mask].mean()
def rmsd_torch(X, Y):
    """ Assumes x,y are both (B x D x N). See below for wrapper. """
    # mean squared deviation over the last two axes, then the root
    squared_dev = (X - Y)**2
    mean_sq = torch.mean(squared_dev, axis=(-1, -2))
    return torch.sqrt(mean_sq)
def rmsd_numpy(X, Y):
    """ Assumes x,y are both (B x D x N). See below for wrapper. """
    # mean squared deviation over the last two axes, then the root
    squared_dev = (X - Y)**2
    mean_sq = np.mean(squared_dev, axis=(-1, -2))
    return np.sqrt(mean_sq)
def gdt_torch(X, Y, cutoffs, weights=None):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        * cutoffs is a list of `K` thresholds
        * weights is a list of `K` weights (1 x each threshold)
        Outputs: (B,) tensor. per-structure mean over thresholds of the
                 weighted fraction of points within each cutoff.
    """
    device = X.device
    if weights is None:
        # created on the input device so CUDA inputs do not hit a device mismatch
        weights = torch.ones(1, len(cutoffs), device=device)
    else:
        weights = torch.tensor([weights]).to(device)
    # set zeros and fill with values
    GDT = torch.zeros(X.shape[0], len(cutoffs), device=device)
    # per-point euclidean distance: (B, N)
    dist = ((X - Y)**2).sum(dim=1).sqrt()
    # iterate over thresholds
    for i, cutoff in enumerate(cutoffs):
        GDT[:, i] = (dist <= cutoff).float().mean(dim=-1)
    # weighted mean
    return (GDT*weights).mean(-1)
def gdt_numpy(X, Y, cutoffs, weights=None):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        * cutoffs is a list of `K` thresholds
        * weights is a list of `K` weights (1 x each threshold)
        Outputs: (B,) array. per-structure mean over thresholds of the
                 weighted fraction of points within each cutoff.
    """
    n_cutoffs = len(cutoffs)
    # one weight per threshold, with a leading axis that broadcasts over the batch
    weights = np.ones( (1, n_cutoffs) ) if weights is None else np.array([weights])
    # per-point euclidean distance: (B, N)
    dist = np.sqrt( ((X - Y)**2).sum(axis=1) )
    # fraction of points within each threshold
    fractions = np.zeros( (X.shape[0], n_cutoffs) )
    for col, cutoff in enumerate(cutoffs):
        fractions[:, col] = (dist <= cutoff).mean(axis=-1)
    # weighted mean
    return (fractions*weights).mean(-1)
def tmscore_torch(X, Y):
    """ Assumes x,y are both (B x D x N). see below for wrapper. """
    # length used for the distance scale, floored at 15
    n_points = max(15, X.shape[-1])
    scale = 1.24 * (n_points - 15)**(1/3) - 1.8
    # per-point euclidean distance: (B, N)
    per_point = ((X - Y)**2).sum(dim=1).sqrt()
    # formula
    scores = 1 / (1 + (per_point/scale)**2)
    return scores.mean(dim=-1)
def tmscore_numpy(X, Y):
    """ Assumes x,y are both (B x D x N). see below for wrapper. """
    # length used for the distance scale, floored at 15
    n_points = max(15, X.shape[-1])
    scale = 1.24 * np.cbrt(n_points - 15) - 1.8
    # per-point euclidean distance: (B, N)
    per_point = np.sqrt( ((X - Y)**2).sum(axis=1) )
    # formula
    scores = 1 / (1 + (per_point/scale)**2)
    return scores.mean(axis=-1)
def mdscaling_torch(pre_dist_mat, weights=None, iters=10, tol=1e-5,
                    fix_mirror=True, N_mask=None, CA_mask=None, C_mask=None,
                    eigen=False, verbose=2):
    """ Handles the specifics of MDS for proteins (mirrors, ...)
        Runs `mds_torch` and, when `fix_mirror` is truthy, negates the last
        coordinate axis of every structure whose proportion of negative
        phi angles is below 0.5.
        Outputs: (preds, stresses) as returned by `mds_torch`.
    """
    preds, stresses = mds_torch(pre_dist_mat, weights=weights, iters=iters,
                                tol=tol, eigen=eigen, verbose=verbose)
    if fix_mirror:
        # proportion of negative phis per structure
        phi_ratios = calc_phis_torch(preds, N_mask, CA_mask, C_mask, prop=True)
        to_correct = torch.nonzero( (phi_ratios < 0.5)).view(-1)
        # fix mirrors by (-1)*Z for the selected structures
        preds[to_correct, -1] = (-1)*preds[to_correct, -1]
        if verbose == 2:
            print("Corrected mirror idxs:", to_correct)
    return preds, stresses
def mdscaling_numpy(pre_dist_mat, weights=None, iters=10, tol=1e-5,
                    fix_mirror=True, N_mask=None, CA_mask=None, C_mask=None, verbose=2):
    """ Handles the specifics of MDS for proteins (mirrors, ...)
        Runs `mds_numpy` and, when `fix_mirror` is truthy, negates the last
        coordinate axis of every structure whose proportion of negative
        phi angles is below 0.5.
        Outputs: (preds, stresses) as returned by `mds_numpy`.
    """
    preds, stresses = mds_numpy(pre_dist_mat, weights=weights, iters=iters,
                                tol=tol, verbose=verbose)
    if not fix_mirror:
        return preds, stresses
    # one proportion of negative phis per structure
    phi_ratios = calc_phis_numpy(preds, N_mask, CA_mask, C_mask, prop=True)
    for i, ratio in enumerate(phi_ratios):
        # fix mirrors by (-1)*Z if more (+) than (-) phi angles. The decision
        # is per structure: testing the whole array is ambiguous for batch > 1.
        if ratio < 0.5:
            preds[i, -1] = (-1)*preds[i, -1]
            if verbose == 2:
                print("Corrected mirror in struct no.", i)
    return preds, stresses
def lddt_ca_torch(true_coords, pred_coords, cloud_mask, r_0=15.):
    """ Computes the lddt score for each C_alpha.
        https://academic.oup.com/bioinformatics/article/29/21/2722/195896
        Inputs:
        * true_coords: (b, l, c, d) in sidechainnet format.
        * pred_coords: (b, l, c, d) in sidechainnet format.
        * cloud_mask : (b, l, c) adapted for scn format.
        * r_0: float. maximum inclusion radius in reference struct.
        Outputs:
        * (b, l) lddt for c_alpha scores (ranging between 0 and 1).
          Residues whose C_alpha is masked out keep a score of 0.
        See wrapper below.
    """
    device, dtype = true_coords.device, true_coords.type()
    thresholds = torch.tensor([0.5, 1, 2, 4], device=device).type(dtype)
    n_thresholds = thresholds.shape[0]
    # adapt masks - both on the device of the coordinates, so that combining
    # them (and indexing the coordinates) also works for CUDA inputs
    cloud_mask = cloud_mask.bool().to(device)
    c_alpha_mask = torch.zeros(cloud_mask.shape[1:], device=device).bool() # doesn't have batch dim
    c_alpha_mask[..., 1] = True  # atom slot 1 of every residue
    # container for c_alpha scores (between 0,1)
    wrapper = torch.zeros(true_coords.shape[:2], device=device).type(dtype)
    for bi in range(true_coords.shape[0]):
        # select atoms for study
        c_alphas = cloud_mask[bi]*c_alpha_mask # only pick c_alpha positions
        selected_pred = pred_coords[bi, c_alphas, :]
        selected_target = true_coords[bi, c_alphas, :]
        # get number under distance
        dist_mat_pred = torch.cdist(selected_pred, selected_pred, p=2)
        dist_mat_target = torch.cdist(selected_target, selected_target, p=2)
        under_r0_target = dist_mat_target < r_0
        compare_dists = torch.abs(dist_mat_pred - dist_mat_target)[under_r0_target]
        # measure diff below threshold
        score = torch.zeros_like(under_r0_target).float()
        max_score = torch.zeros_like(under_r0_target).float()
        max_score[under_r0_target] = float(n_thresholds)
        # measure under how many thresholds
        score[under_r0_target] = n_thresholds - \
            torch.bucketize( compare_dists, boundaries=thresholds ).float()
        # dont include diagonal: each atom matches itself under every threshold
        l_mask = c_alphas.float().sum(dim=-1).bool()
        wrapper[bi, l_mask] = ( score.sum(dim=-1) - n_thresholds ) / \
                              ( max_score.sum(dim=-1) - n_thresholds )
    return wrapper
################
### WRAPPERS ###
################
@set_backend_kwarg
@invoke_torch_or_numpy(mdscaling_torch, mdscaling_numpy)
def MDScaling(pre_dist_mat, **kwargs):
    """ Gets distance matrix (-ces). Outputs 3d.
        Assumes (for now) distogram is (N x N) and symmetric.
        For support of distograms: see `center_distogram_torch()`
        This function only normalises the input to 3 dims and hands
        (pre_dist_mat, kwargs) to the decorators, which presumably dispatch
        to `mdscaling_torch` / `mdscaling_numpy` - see their definitions.
        Inputs:
        * pre_dist_mat: (1, N, N) distance matrix.
        * weights: optional. (N x N) pairwise relative weights .
        * iters: number of iterations to run the algorithm on
        * tol: relative tolerance at which to stop the algorithm if no better
               improvement is achieved
        * backend: one of ["numpy", "torch", "auto"] for backend choice
        * fix_mirror: truthy to pick the mirror image with the highest
                      number of negative phis.
        * N_mask: indexing array/tensor for indices of backbone N.
                  Only used if fix_mirror is set.
        * CA_mask: indexing array/tensor for indices of backbone C_alpha.
                   Only used if fix_mirror is set.
        * verbose: whether to print logs
        Outputs:
        * best_3d_coords: (3 x N)
        * historic_stress: (timesteps, )
    """
    # number of leading axes missing to reach (batch, N, N)
    missing_dims = 3 - len(pre_dist_mat.shape)
    batched = expand_dims_to(pre_dist_mat, missing_dims)
    return batched, kwargs
@expand_arg_dims(dim_len = 2)
@set_backend_kwarg
@invoke_torch_or_numpy(kabsch_torch, kabsch_numpy)
def Kabsch(A, B):
    """ Returns Kabsch-rotated matrices resulting
        from aligning A into B.
        Adapted from: https://github.com/charnley/rmsd/
        * Inputs:
            * A,B are (3 x N)
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of shape (3 x N)
    """
    # the arguments are returned untouched; the decorators above are
    # expected to run the actual computation
    return A, B
# The three decorators below had been replaced by descriptive comments, which
# left RMSD returning its inputs unchanged: expand the argument dims, set the
# backend kwarg, then invoke rmsd_torch or rmsd_numpy.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(rmsd_torch, rmsd_numpy)
def RMSD(A, B):
    """ Returns RMSD score as defined here (lower is better):
        https://en.wikipedia.org/wiki/
        Root-mean-square_deviation_of_atomic_positions
        * Inputs:
            * A,B are (B x 3 x N) or (3 x N)
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of size (B,)
    """
    # the arguments are returned untouched; the decorators run the computation
    return A, B
# The three decorators below had been replaced by descriptive comments, which
# left GDT returning its inputs unchanged: expand the argument dims, set the
# backend kwarg, then invoke gdt_torch or gdt_numpy.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(gdt_torch, gdt_numpy)
def GDT(A, B, *, mode="TS", cutoffs=[1,2,4,8], weights=None):
    """ Returns GDT score as defined here (higher is better):
        Supports both TS and HA
        http://predictioncenter.org/casp12/doc/help.html
        * Inputs:
            * A,B are (B x 3 x N) (np.array or torch.tensor)
            * mode: "TS" or "HA" ("ha" also accepted). selects the thresholds:
                    [0.5,1,2,4] for HA, [1,2,4,8] otherwise.
            * cutoffs: defines thresholds for gdt.
                       NOTE(review): currently overridden by `mode`, so a
                       caller-supplied value has no effect - confirm intent.
            * weights: list containing the weights
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of size (B,)
    """
    # define cutoffs for each type of gdt
    cutoffs = [0.5,1,2,4] if mode in ["HA", "ha"] else [1,2,4,8]
    # positional args first, keyword args as the trailing dict
    return A, B, cutoffs, {'weights': weights}
# The three decorators below had been replaced by descriptive comments, which
# left TMscore returning its inputs unchanged: expand the argument dims, set
# the backend kwarg, then invoke tmscore_torch or tmscore_numpy.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(tmscore_torch, tmscore_numpy)
def TMscore(A, B):
    """ Returns TMscore as defined here (higher is better):
        >0.5 (likely) >0.6 (highly likely) same folding.
        = 0.2. https://en.wikipedia.org/wiki/Template_modeling_score
        Warning! It's not exactly the code in:
        https://zhanglab.ccmb.med.umich.edu/TM-score/TMscore.cpp
        but will suffice for now.
        Inputs:
        * A,B are (B x 3 x N) (np.array or torch.tensor)
        * backend: one of ["numpy", "torch", "auto"] for backend choice
        Outputs: tensor/array of size (B,)
    """
    # the arguments are returned untouched; the decorators run the computation
    return A, B
.\lucidrains\alphafold2\alphafold2_pytorch\__init__.py
# 从 alphafold2_pytorch.alphafold2 模块中导入 Alphafold2 和 Evoformer 类
from alphafold2_pytorch.alphafold2 import Alphafold2, Evoformer
Alphafold2 - Pytorch (wip)
To eventually become an unofficial working Pytorch implementation of Alphafold2, the breathtaking attention network that solved CASP14. Will be gradually implemented as more details of the architecture are released.
Once this is replicated, I intend to fold all available amino acid sequences out there in-silico and release it as an academic torrent, to further science. If you are interested in replication efforts, please drop by #alphafold at this Discord channel
Update: Deepmind has open sourced the official code in Jax, along with the weights 🙏! This repository will now be geared towards a straight pytorch translation with some improvements on positional encoding
Install
$ pip install alphafold2-pytorch
Status
lhatsk has reported training a modified trunk of this repository, using the same setup as trRosetta, with competitive results
blue used the trRosetta input (MSA -> potts -> axial attention), green used the ESM embedding (only sequence) -> tiling -> axial attention
- lhatsk
Usage
Predicting distogram, like Alphafold-1, but with attention
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
reversible = False # set this to True for fully reversible self / cross attention for the trunk
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda() # AA length of 128
msa = torch.randint(0, 21, (1, 5, 120)).cuda() # MSA doesn't have to be the same length as primary sequence
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
You can also turn on prediction for the angles, by passing `predict_angles = True` on init. The below example would be equivalent to trRosetta but with self / cross attention.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_angles = True # set this to True
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram, theta, phi, omega = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
)
# distogram - (1, 128, 128, 37),
# theta - (1, 128, 128, 25),
# phi - (1, 128, 128, 13),
# omega - (1, 128, 128, 25)
Predicting Coordinates
Fabian's recent paper suggests iteratively feeding the coordinates back into SE3 Transformer, weight shared, may work. I have decided to execute based on this idea, even though it is still up in the air how it actually works.
You can also use E(n)-Transformer or EGNN for structural refinement.
Update: Baker's lab have shown that an end-to-end architecture from sequence and MSA embeddings to SE3 Transformers can best trRosetta and close the gap to Alphafold2. We will be using the Graph Transformer, which acts on the trunk embeddings, to generate the initial set of coordinates to be sent to the equivariant network. (This is further corroborated by Costa et al in their work teasing out 3d coordinates from MSA Transformer embeddings in a paper predating Baker lab's)
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
structure_module_type = 'se3', # use SE3 Transformer - if set to False, will use E(n)-Transformer, Victor and Max Welling's new paper
structure_module_dim = 4, # se3 transformer dimension
structure_module_depth = 1, # depth
structure_module_heads = 1, # heads
structure_module_dim_head = 16, # dimension of heads
structure_module_refinement_iters = 2, # number of equivariant coordinate refinement iterations
structure_num_global_nodes = 1 # number of global nodes for the structure module, only works with SE3 transformer
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 3, 3) <-- 3 atoms per residue
Atoms
The underlying assumption is that the trunk works on the residue level, and its output is then expanded to the atomic level for the structure module, whether it be SE3 Transformers, E(n)-Transformer, or EGNN doing the refinement. This library defaults to the 3 backbone atoms (C, Ca, N), but you can configure it to include any other atom you like, including Cb and the sidechains.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
atoms = 'backbone-with-cbeta'
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 4, 3) <-- 4 atoms per residue (C, Ca, N, Cb)
Valid choices for `atoms` include:
- `backbone` - 3 backbone atoms (C, Ca, N) [default]
- `backbone-with-cbeta` - 3 backbone atoms and C beta
- `backbone-with-oxygen` - 3 backbone atoms and oxygen from carboxyl
- `backbone-with-cbeta-and-oxygen` - 3 backbone atoms with C beta and oxygen
- `all` - backbone and all other atoms from sidechain
You can also pass in a tensor of shape (14,) defining which atoms you would like to include
ex.
atoms = torch.tensor([1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
MSA, ESM, or ProtTrans Embeddings
This repository offers you an easy way to supplement the network with pre-trained embeddings from Facebook AI. It contains wrappers for the pre-trained ESM, MSA Transformers or Protein Transformer.
There are some prerequisites. You will need to make sure that you have Nvidia's apex library installed, as the pretrained transformers make use of some fused operations.
Or you can try running the script below
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Next, you will simply have to import and wrap your `Alphafold2` instance with an `ESMEmbedWrapper`, `MSAEmbedWrapper`, or `ProtTranEmbedWrapper` and it will take care of embedding both the sequence and the multiple-sequence alignments for you (and projecting it to the dimensions as specified on your model). Nothing needs to be changed save for adding the wrapper.
import torch
from alphafold2_pytorch import Alphafold2
from alphafold2_pytorch.embeds import MSAEmbedWrapper
alphafold2 = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64
)
model = MSAEmbedWrapper(
alphafold2 = alphafold2
).cuda()
seq = torch.randint(0, 21, (2, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (2, 5, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
)
By default, even if the wrapper supplies the trunk with the sequence and MSA embeddings, they would be summed with the usual token embeddings. If you want to train Alphafold2 without token embeddings (only rely on pretrained embeddings), you would need to set `disable_token_embed` to `True` on `Alphafold2` init.
alphafold2 = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
disable_token_embed = True
)
Real-Value Distance Prediction
A paper by Jinbo Xu suggests that one doesn't need to bin the distances, and can instead predict the mean and standard deviation directly. You can use this by turning on one flag, `predict_real_value_distances`, in which case the distance prediction returned will have a dimension of `2` for the mean and standard deviation respectively.
If `predict_coords` is also turned on, then the MDS will accept the mean and standard deviation predictions directly without having to calculate that from the distogram bins.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
predict_real_value_distances = True, # set this to True
structure_module_type = 'se3',
structure_module_dim = 4,
structure_module_depth = 1,
structure_module_heads = 1,
structure_module_dim_head = 16,
structure_module_refinement_iters = 2
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 3, 3) <-- 3 atoms per residue
Convolutions
You can add convolutional blocks, for both the primary sequence as well as the MSA, by simply setting one extra keyword argument use_conv = True
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True # set this to True
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
The convolutional kernels follow the lead of this paper, combining 1d and 2d kernels in one resnet-like block. You can fully customize the kernels as such.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True, # set this to True
conv_seq_kernels = ((9, 1), (1, 9), (3, 3)), # kernels for N x N primary sequence
conv_msa_kernels = ((1, 9), (3, 3)), # kernels for {num MSAs} x N MSAs
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
You can also do cycle dilation with one extra keyword argument. Default dilation is 1
for all layers.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True, # set this to True
dilations = (1, 3, 5) # cycle between dilations of 1, 3, 5
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
Finally, instead of following the pattern of convolutions, self-attention, cross-attention per depth repeating, you can customize any order you wish with the custom_block_types
keyword
e.g. a network where you do predominantly convolutions first, followed by self-attention + cross-attention blocks
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
heads = 8,
dim_head = 64,
custom_block_types = (
*(('conv',) * 6),
*(('self', 'cross') * 6)
)
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
Sparse Attention
You can train with Microsoft Deepspeed's Sparse Attention, but you will have to endure the installation process. It takes two steps.
First, you need to install Deepspeed with Sparse Attention
$ sh install_deepspeed.sh
Next, you need to install the pip package triton
$ pip install triton
If both of the above succeeded, now you can train with Sparse Attention!
Sadly, the sparse attention is only supported for self attention, and not cross attention. I will bring in a different solution for making cross attention performant.
model = Alphafold2(
dim = 256,
depth = 12,
heads = 8,
dim_head = 64,
max_seq_len = 2048, # the maximum sequence length, this is required for sparse attention. the input cannot exceed what is set here
sparse_self_attn = (True, False) * 6 # interleave sparse and full attention for all 12 layers
).cuda()
Linear Attention
I have also added one of the best linear attention variants, in the hope of lessening the burden of cross attending. I personally have not found Performer to work that well, but since in the paper they reported some ok numbers for protein benchmarks, I thought I'd include it and allow others to experiment.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
cross_attn_linear = True # simply set this to True to use Performer for all cross attention
).cuda()
You can also specify the exact layers in which you wish to use linear attention by passing in a tuple of the same length as the depth
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 6,
heads = 8,
dim_head = 64,
cross_attn_linear = (True, False) * 3 # interleave linear and full attention
).cuda()
Kronecker Attention for Cross Attention
This paper suggests that if you have queries or contexts that have defined axials (say an image), you can reduce the amount of attention needed by averaging across those axials (height and width) and concatenating the averaged axials into one sequence. You can turn this on as a memory saving technique for the cross attention, specifically for the primary sequence.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 6,
heads = 8,
dim_head = 64,
cross_attn_kron_primary = True # make sure primary sequence undergoes the kronecker operator during cross attention
).cuda()
You can also apply the same operator to the MSAs during cross attention with the cross_attn_kron_msa
flag, if your MSAs are aligned and of the same width.
Todo
Memory Compressed Attention
To save on memory for cross attention, you can set a compression ratio for the key / values, following the scheme laid out in this paper. A compression ratio of 2-4 is usually acceptable.
model = Alphafold2(
dim = 256,
depth = 12,
heads = 8,
dim_head = 64,
cross_attn_compress_ratio = 3
).cuda()
MSA processing in Trunk
A new paper by Roshan Rao proposes using axial attention for pretraining on MSAs. Given the strong results, this repository will use the same scheme in the trunk, specifically for the MSA self-attention.
You can also tie the row attentions of the MSA with the msa_tie_row_attn = True
setting on initialization of Alphafold2
. However, in order to use this, you must make sure that, if you have an uneven number of MSAs per primary sequence, the MSA mask is properly set to False
for the rows not in use.
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
msa_tie_row_attn = True # just set this to true
)
Template processing in Trunk
Template processing is also largely done with axial attention, with cross attention done along the number of templates dimension. This largely follows the same scheme as in the recent all-attention approach to video classification as shown here.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 5,
heads = 8,
dim_head = 64,
reversible = True,
sparse_self_attn = False,
max_seq_len = 256,
cross_attn_compress_ratio = 3
).cuda()
seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randint(0, 37, (1, 2, 16, 3)).cuda()
templates_mask = torch.ones_like(templates_seq).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask,
templates_seq = templates_seq,
templates_coors = templates_coors,
templates_mask = templates_mask
)
If sidechain information is also present, in the form of the unit vector between the C and C-alpha coordinates of each residue, you can also pass it in as follows.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 5,
heads = 8,
dim_head = 64,
reversible = True,
sparse_self_attn = False,
max_seq_len = 256,
cross_attn_compress_ratio = 3
).cuda()
seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randn(1, 2, 16, 3).cuda()
templates_mask = torch.ones_like(templates_seq).bool().cuda()
templates_sidechains = torch.randn(1, 2, 16, 3).cuda() # unit vectors of difference of C and C-alpha coordinates
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask,
templates_seq = templates_seq,
templates_mask = templates_mask,
templates_coors = templates_coors,
templates_sidechains = templates_sidechains
)
Equivariant Attention
I have prepared a reimplementation of SE3 Transformer, as explained by Fabian Fuchs in a speculatory blogpost.
In addition, a new paper from Victor and Welling uses invariant features for E(n) equivariance, reaching SOTA and outperforming SE3 Transformer at a number of benchmarks, while being much faster. I have taken the main ideas from this paper and modified it to become a transformer (added attention to both features and coordinate updates).
All three of the equivariant networks above have been integrated and are available for use in the repository for atomic coordinate refinement by simply setting one hyperparameter structure_module_type
.
- se3 - SE3 Transformer
- egnn - EGNN
Of interest to readers, each of the three frameworks has also been validated by researchers on related problems.
Testing
$ python setup.py test
Data
This library will use the awesome work by Jonathan King at this repository. Thank you Jonathan 🙏!
We also have the MSA data, all ~3.5 TB worth, downloaded and hosted by Archivist, who owns The-Eye project. (They also host the data and models for Eleuther AI.) Please consider a donation if you find them helpful.
$ curl -s https://the-eye.eu/eleuther_staging/globus_stuffs/tree.txt
Speculation
https://xukui.cn/alphafold2.html
Recent works by competing labs
https://www.biorxiv.org/content/10.1101/2020.12.10.419994v1.full.pdf
https://pubmed.ncbi.nlm.nih.gov/33637700/
tFold presentation, from Tencent AI labs
External packages
- Final step - Fast Relax - Installation Instructions:
- Download the pyrosetta wheel from: http://www.pyrosetta.org/dow (select appropriate version) - beware the file is heavy (approx 1.2 GB)
- The download should be free for anyone with an academic email
- Bash >
cd downloads_folder
>pip install pyrosetta_wheel_filename.whl
- Download the pyrosetta wheel from: http://www.pyrosetta.org/dow (select appropriate version) - beware the file is heavy (approx 1.2 GB)
Citations
@misc{unpublished2021alphafold2,
title = {Alphafold2},
author = {John Jumper},
year = {2020},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@article{Rao2021.02.12.430858,
author = {Rao, Roshan and Liu, Jason and Verkuil, Robert and Meier, Joshua and Canny, John F. and Abbeel, Pieter and Sercu, Tom and Rives, Alexander},
title = {MSA Transformer},
year = {2021},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/02/13/2021.02.12.430858},
journal = {bioRxiv}
}
@article {Rives622803,
author = {Rives, Alexander and Goyal, Siddharth and Meier, Joshua and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
title = {Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
year = {2019},
doi = {10.1101/622803},
publisher = {Cold Spring Harbor Laboratory},
journal = {bioRxiv}
}
@article {Elnaggar2020.07.12.199554,
author = {Elnaggar, Ahmed and Heinzinger, Michael and Dallago, Christian and Rehawi, Ghalia and Wang, Yu and Jones, Llion and Gibbs, Tom and Feher, Tamas and Angerer, Christoph and Steinegger, Martin and BHOWMIK, DEBSINDHU and Rost, Burkhard},
title = {ProtTrans: Towards Cracking the Language of Life{\textquoteright}s Code Through Self-Supervised Deep Learning and High Performance Computing},
elocation-id = {2020.07.12.199554},
year = {2021},
doi = {10.1101/2020.07.12.199554},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554},
eprint = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554.full.pdf},
journal = {bioRxiv}
}
@misc{king2020sidechainnet,
title = {SidechainNet: An All-Atom Protein Structure Dataset for Machine Learning},
author = {Jonathan E. King and David Ryan Koes},
year = {2020},
eprint = {2010.08162},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@misc{alquraishi2019proteinnet,
title = {ProteinNet: a standardized data set for machine learning of protein structure},
author = {Mohammed AlQuraishi},
year = {2019},
eprint = {1902.00249},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@misc{gomez2017reversible,
title = {The Reversible Residual Network: Backpropagation Without Storing Activations},
author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
year = {2017},
eprint = {1707.04585},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{fuchs2021iterative,
title = {Iterative SE(3)-Transformers},
author = {Fabian B. Fuchs and Edward Wagstaff and Justas Dauparas and Ingmar Posner},
year = {2021},
eprint = {2102.13419},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@article{Gao_2020,
title = {Kronecker Attention Networks},
ISBN = {9781450379984},
url = {http://dx.doi.org/10.1145/3394486.3403065},
DOI = {10.1145/3394486.3403065},
journal = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
publisher = {ACM},
author = {Gao, Hongyang and Wang, Zhengyang and Ji, Shuiwang},
year = {2020},
month = {Jul}
}
@article {Si2021.05.10.443415,
author = {Si, Yunda and Yan, Chengfei},
title = {Improved protein contact prediction using dimensional hybrid residual networks and singularity enhanced loss function},
elocation-id = {2021.05.10.443415},
year = {2021},
doi = {10.1101/2021.05.10.443415},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415},
eprint = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415.full.pdf},
journal = {bioRxiv}
}
@article {Costa2021.06.02.446809,
author = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
title = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
year = {2021},
doi = {10.1101/2021.06.02.446809},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
eprint = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
journal = {bioRxiv}
}
@article {Baek2021.06.14.448402,
author = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
title = {Accurate prediction of protein structures and interactions using a 3-track network},
year = {2021},
doi = {10.1101/2021.06.14.448402},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
eprint = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
journal = {bioRxiv}
}