Lucidrains-系列项目源码解析-四-
Lucidrains 系列项目源码解析(四)
.\lucidrains\big-sleep\big_sleep\cli.py
# 导入 fire 模块,用于命令行接口
import fire
# 导入 random 模块并重命名为 rnd
import random as rnd
# 从 big_sleep 模块中导入 Imagine 类和 version 变量
from big_sleep import Imagine, version
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从当前目录下的 version 模块中导入 __version__ 变量
from .version import __version__;
# 定义 train 函数,接受多个参数
def train(
text=None,
img=None,
text_min="",
lr = .07,
image_size = 512,
gradient_accumulate_every = 1,
epochs = 20,
iterations = 1050,
save_every = 50,
overwrite = False,
save_progress = False,
save_date_time = False,
bilinear = False,
open_folder = True,
seed = 0,
append_seed = False,
random = False,
torch_deterministic = False,
max_classes = None,
class_temperature = 2.,
save_best = False,
experimental_resample = False,
ema_decay = 0.5,
num_cutouts = 128,
center_bias = False,
larger_model = False
):
# 打印版本信息
print(f'Starting up... v{__version__}')
# 如果 random 为 True,则生成一个随机种子
if random:
seed = rnd.randint(0, 1e6)
# 创建 Imagine 对象,传入各种参数
imagine = Imagine(
text=text,
img=img,
text_min=text_min,
lr = lr,
image_size = image_size,
gradient_accumulate_every = gradient_accumulate_every,
epochs = epochs,
iterations = iterations,
save_every = save_every,
save_progress = save_progress,
bilinear = bilinear,
seed = seed,
append_seed = append_seed,
torch_deterministic = torch_deterministic,
open_folder = open_folder,
max_classes = max_classes,
class_temperature = class_temperature,
save_date_time = save_date_time,
save_best = save_best,
experimental_resample = experimental_resample,
ema_decay = ema_decay,
num_cutouts = num_cutouts,
center_bias = center_bias,
larger_clip = larger_model
)
# 如果不覆盖且文件已存在,则询问是否覆盖
if not overwrite and imagine.filename.exists():
answer = input('Imagined image already exists, do you want to overwrite? (y/n) ').lower()
if answer not in ('yes', 'y'):
exit()
# 调用 Imagine 对象的方法开始训练
imagine()
# 定义主函数
def main():
# 使用 fire 模块创建命令行接口,传入 train 函数
fire.Fire(train)
.\lucidrains\big-sleep\big_sleep\ema.py
# 导入必要的库
from copy import deepcopy
import torch
from torch import nn
# 定义指数移动平均类
class EMA(nn.Module):
# 初始化函数,接受模型和衰减率作为参数
def __init__(self, model, decay):
super().__init__()
self.model = model
self.decay = decay
# 注册缓冲区
self.register_buffer('accum', torch.tensor(1.))
self._biased = deepcopy(self.model)
self.average = deepcopy(self.model)
# 将偏置参数和平均参数初始化为零
for param in self._biased.parameters():
param.detach_().zero_()
for param in self.average.parameters():
param.detach_().zero_()
# 更新参数
self.update()
# 更新函数,用于更新指数移动平均
@torch.no_grad()
def update(self):
assert self.training, 'Update should only be called during training'
# 更新累积值
self.accum *= self.decay
# 获取模型参数、偏置参数和平均参数
model_params = dict(self.model.named_parameters())
biased_params = dict(self._biased.named_parameters())
average_params = dict(self.average.named_parameters())
assert model_params.keys() == biased_params.keys() == average_params.keys(), f'Model parameter keys incompatible with EMA stored parameter keys'
# 更新参数
for name, param in model_params.items():
biased_params[name].mul_(self.decay)
biased_params[name].add_((1 - self.decay) * param)
average_params[name].copy_(biased_params[name])
average_params[name].div_(1 - self.accum)
# 获取模型缓冲区、偏置缓冲区和平均缓冲区
model_buffers = dict(self.model.named_buffers())
biased_buffers = dict(self._biased.named_buffers())
average_buffers = dict(self.average.named_buffers())
assert model_buffers.keys() == biased_buffers.keys() == average_buffers.keys()
# 更新缓冲区
for name, buffer in model_buffers.items():
biased_buffers[name].copy_(buffer)
average_buffers[name].copy_(buffer)
# 前向传播函数,根据是否处于训练状态返回模型或平均模型的输出
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
return self.average(*args, **kwargs)
.\lucidrains\big-sleep\big_sleep\resample.py
"""Good differentiable image resampling for PyTorch."""
# 导入所需的库
from functools import update_wrapper
import math
import torch
from torch.nn import functional as F
# 定义 sinc 函数
def sinc(x):
return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
# 定义 lanczos 函数
def lanczos(x, a):
cond = torch.logical_and(-a < x, x < a)
out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
return out / out.sum()
# 定义 ramp 函数
def ramp(ratio, width):
n = math.ceil(width / ratio + 1)
out = torch.empty([n])
cur = 0
for i in range(out.shape[0]):
out[i] = cur
cur += ratio
return torch.cat([-out[1:].flip([0]), out])[1:-1]
# 定义 odd 函数
def odd(fn):
return update_wrapper(lambda x: torch.sign(x) * fn(abs(x)), fn)
# 定义将输入转换为线性 sRGB 的函数
def _to_linear_srgb(input):
cond = input <= 0.04045
a = input / 12.92
b = ((input + 0.055) / 1.055)**2.4
return torch.where(cond, a, b)
# 定义将输入转换为非线性 sRGB 的函数
def _to_nonlinear_srgb(input):
cond = input <= 0.0031308
a = 12.92 * input
b = 1.055 * input**(1/2.4) - 0.055
return torch.where(cond, a, b)
# 使用 odd 函数包装 _to_linear_srgb 函数和 _to_nonlinear_srgb 函数
to_linear_srgb = odd(_to_linear_srgb)
to_nonlinear_srgb = odd(_to_nonlinear_srgb)
# 定义 resample 函数
def resample(input, size, align_corners=True, is_srgb=False):
n, c, h, w = input.shape
dh, dw = size
# 如果 is_srgb 为 True,则将输入转换为线性 sRGB
if is_srgb:
input = to_linear_srgb(input)
input = input.view([n * c, 1, h, w])
# 如果目标高度小于原始高度
if dh < h:
kernel_h = lanczos(ramp(dh / h, 3), 3).to(input.device, input.dtype)
pad_h = (kernel_h.shape[0] - 1) // 2
input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
input = F.conv2d(input, kernel_h[None, None, :, None])
# 如果目标宽度小于原始宽度
if dw < w:
kernel_w = lanczos(ramp(dw / w, 3), 3).to(input.device, input.dtype)
pad_w = (kernel_w.shape[0] - 1) // 2
input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
input = F.conv2d(input, kernel_w[None, None, None, :])
input = input.view([n, c, h, w])
input = F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
# 如果 is_srgb 为 True,则将输出转换为非线性 sRGB
if is_srgb:
input = to_nonlinear_srgb(input)
return input
.\lucidrains\big-sleep\big_sleep\version.py
# 定义变量 __version__,赋值为字符串 '0.9.1'
__version__ = '0.9.1'
.\lucidrains\big-sleep\big_sleep\__init__.py
# 从 big_sleep.big_sleep 模块中导入 BigSleep 和 Imagine 类
from big_sleep.big_sleep import BigSleep, Imagine

artificial intelligence

cosmic love and attention

fire in the sky

a pyramid made of ice

a lonely house in the woods

marriage in the mountains

lantern dangling from a tree in a foggy graveyard

a vivid dream

balloons over the ruins of a city

the death of the lonesome astronomer - by moirage

the tragic intimacy of the eternal conversation with oneself - by moirage

demon fire - by WiseNat
Big Sleep
Ryan Murdock has done it again, combining OpenAI's CLIP and the generator from a BigGAN! This repository wraps up his work so it is easily accessible to anyone who owns a GPU.
You will be able to have the GAN dream up images using natural language with a one-line command in the terminal.
User-made notebook with bugfixes and added features, like google drive integration
Install
$ pip install big-sleep
Usage
$ dream "a pyramid made of ice"
Images will be saved to wherever the command is invoked
Advanced
You can invoke this in code with
from big_sleep import Imagine
dream = Imagine(
text = "fire in the sky",
lr = 5e-2,
save_every = 25,
save_progress = True
)
dream()
You can now train more than one phrase using the delimiter "|"
Train on Multiple Phrases
In this example we train on three phrases:
an armchair in the form of pikachuan armchair imitating pikachuabstract
from big_sleep import Imagine
dream = Imagine(
text = "an armchair in the form of pikachu|an armchair imitating pikachu|abstract",
lr = 5e-2,
save_every = 25,
save_progress = True
)
dream()
Penalize certain prompts as well!
In this example we train on the three phrases from before,
and penalize the phrases:
blurzoom
from big_sleep import Imagine
dream = Imagine(
text = "an armchair in the form of pikachu|an armchair imitating pikachu|abstract",
text_min = "blur|zoom",
)
dream()
You can also set a new text by using the .set_text(<str>) command
dream.set_text("a quiet pond underneath the midnight moon")
And reset the latents with .reset()
dream.reset()
To save the progression of images during training, you simply have to supply the --save-progress flag
$ dream "a bowl of apples next to the fireplace" --save-progress --save-every 100
Due to the class conditioned nature of the GAN, Big Sleep often steers off the manifold into noise. You can use a flag to save the best high scoring image (per CLIP critic) to {filepath}.best.png in your folder.
$ dream "a room with a view of the ocean" --save-best
Larger model
If you have enough memory, you can also try using a bigger vision model released by OpenAI for improved generations.
$ dream "storm clouds rolling in over a white barnyard" --larger-model
Experimentation
You can set the number of classes that you wish to restrict Big Sleep to use for the Big GAN with the --max-classes flag as follows (ex. 15 classes). This may lead to extra stability during training, at the cost of lost expressivity.
$ dream 'a single flower in a withered field' --max-classes 15
Alternatives
Deep Daze - CLIP and a deep SIREN network
Citations
@misc{unpublished2021clip,
title = {CLIP: Connecting Text and Images},
author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal},
year = {2021}
}
@misc{brock2019large,
title = {Large Scale GAN Training for High Fidelity Natural Image Synthesis},
author = {Andrew Brock and Jeff Donahue and Karen Simonyan},
year = {2019},
eprint = {1809.11096},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\big-sleep\setup.py
# 导入 sys 模块
import sys
# 从 setuptools 模块中导入 setup 和 find_packages 函数
from setuptools import setup, find_packages
# 将 'big_sleep' 目录添加到 sys.path 的最前面
sys.path[0:0] = ['big_sleep']
# 从 version 模块中导入 __version__ 变量
from version import __version__
# 设置包的元数据
setup(
# 包的名称
name = 'big-sleep',
# 查找并包含所有包
packages = find_packages(),
# 包含所有数据文件
include_package_data = True,
# 设置入口点,命令行脚本为 'dream'
entry_points={
'console_scripts': [
'dream = big_sleep.cli:main',
],
},
# 版本号
version = __version__,
# 许可证
license='MIT',
# 描述
description = 'Big Sleep',
# 作者
author = 'Ryan Murdock, Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 项目链接
url = 'https://github.com/lucidrains/big-sleep',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'transformers',
'text to image',
'generative adversarial networks'
],
# 安装依赖
install_requires=[
'torch>=1.7.1',
'einops>=0.3',
'fire',
'ftfy',
'pytorch-pretrained-biggan',
'regex',
'torchvision>=0.8.2',
'tqdm'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\big-sleep\test\multi_prompt_minmax.py
# 导入所需的库
import time
import shutil
import torch
from big_sleep import Imagine
# 初始化终止标志
terminate = False
# 信号处理函数,设置终止标志为True
def signal_handling(signum,frame):
global terminate
terminate = True
# 设定尝试次数
num_attempts = 4
# 循环尝试生成图像
for attempt in range(num_attempts):
# 创建Imagine对象,用于生成图像
dream = Imagine(
text = "an armchair in the form of pikachu\\an armchair imitating pikachu\\abstract",
text_min = "blur\\zoom",
lr = 7e-2,
image_size = 512,
gradient_accumulate_every = 1,
save_every = 50,
epochs = 5,
iterations = 50,
save_progress = False,
bilinear = False,
open_folder = False,
seed = None,
torch_deterministic = False,
max_classes = 20,
class_temperature = 2.,
save_date_time = False,
save_best = True,
experimental_resample = True,
ema_decay = 0.99
)
# 生成图像
dream()
# 复制生成的最佳图像
shutil.copy(dream.textpath + ".best.png", f"{attempt}.png")
try:
# 等待2秒
time.sleep(2)
# 删除dream对象
del dream
# 再次等待2秒
time.sleep(2)
# 清空GPU缓存
torch.cuda.empty_cache()
except Exception:
# 出现异常时,仅清空GPU缓存
torch.cuda.empty_cache()
.\lucidrains\bit-diffusion\bit_diffusion\bit_diffusion.py
# 导入所需的库
import math
from pathlib import Path
from functools import partial
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA
from accelerate import Accelerator
# 常量定义
BITS = 8
# 辅助函数
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def cycle(dl):
while True:
for data in dl:
yield data
def has_int_squareroot(num):
return (math.sqrt(num) ** 2) == num
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
def convert_image_to(pil_img_type, image):
if image.mode != pil_img_type:
return image.convert(pil_img_type)
return image
# 小型辅助模块
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim, dim_out = None):
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
)
def Downsample(dim, dim_out = None):
return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + eps).rsqrt() * self.g
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
# 位置嵌入
class LearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with learned sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
# 构建块模块
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
# 初始化函数,定义神经网络结构
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
# 调用父类的初始化函数
super().__init__()
# 如果存在时间嵌入维度,则创建包含激活函数和线性层的序列模块
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
# 创建第一个块
self.block1 = Block(dim, dim_out, groups = groups)
# 创建第二个块
self.block2 = Block(dim_out, dim_out, groups = groups)
# 如果输入维度和输出维度不相等,则使用卷积层进行维度转换,否则使用恒等映射
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
# 前向传播函数
def forward(self, x, time_emb = None):
scale_shift = None
# 如果存在时间嵌入模块和时间嵌入向量,则进行处理
if exists(self.mlp) and exists(time_emb):
# 对时间嵌入向量进行处理
time_emb = self.mlp(time_emb)
# 重新排列时间嵌入向量的维度
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
# 将时间嵌入向量分成两部分,用于缩放和平移
scale_shift = time_emb.chunk(2, dim = 1)
# 使用第一个块处理输入数据
h = self.block1(x, scale_shift = scale_shift)
# 使用第二个块处理第一个块的输出
h = self.block2(h)
# 返回块处理后的结果与输入数据经过维度转换后的结果的和
return h + self.res_conv(x)
# 定义一个线性注意力模块,继承自 nn.Module 类
class LinearAttention(nn.Module):
# 初始化函数,接受维度 dim、头数 heads 和头维度 dim_head 作为参数
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
# 缩放因子为头维度的倒数
self.scale = dim_head ** -0.5
# 头数
self.heads = heads
# 隐藏维度为头维度乘以头数
hidden_dim = dim_head * heads
# 将输入转换为查询、键、值的形式
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
# 输出转换层,包含一个卷积层和一个 LayerNorm 层
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
LayerNorm(dim)
)
# 前向传播函数
def forward(self, x):
# 获取输入张量的形状信息
b, c, h, w = x.shape
# 将输入通过查询、键、值转换层,并按维度 1 切分为三部分
qkv = self.to_qkv(x).chunk(3, dim = 1)
# 将查询、键、值按照指定维度重排
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
# 对查询和键进行 softmax 操作
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
# 对查询进行缩放
q = q * self.scale
# 对值进行归一化
v = v / (h * w)
# 计算上下文信息
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
# 计算输出
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
# 重排输出张量的维度
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out)
# 定义一个注意力模块,继承自 nn.Module 类
class Attention(nn.Module):
# 初始化函数,接受维度 dim、头数 heads 和头维度 dim_head 作为参数
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
# 缩放因子为头维度的倒数
self.scale = dim_head ** -0.5
# 头数
self.heads = heads
# 隐藏维度为头维度乘以头数
hidden_dim = dim_head * heads
# 将输入转换为查询、键、值的形式
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
# 输出转换层,包含一个卷积层
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
# 前向传播函数
def forward(self, x):
# 获取输入张量的形状信息
b, c, h, w = x.shape
# 将输入通过查询、键、值转换层,并按维度 1 切分为三部分
qkv = self.to_qkv(x).chunk(3, dim = 1)
# 将查询、键、值按照指定维度重排
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
# 对查询进行缩放
q = q * self.scale
# 计算相似度
sim = einsum('b h d i, b h d j -> b h i j', q, k)
# 对相似度进行 softmax 操作
attn = sim.softmax(dim = -1)
# 计算输出
out = einsum('b h i j, b h d j -> b h i d', attn, v)
# 重排输出张量的维度
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
# 定义一个 Unet 模型,继承自 nn.Module 类
class Unet(nn.Module):
# 初始化函数,接受维度 dim、初始维度 init_dim、维度倍增 dim_mults、通道数 channels、位数 bits、ResNet 块组数 resnet_block_groups 和学习的正弦维度 learned_sinusoidal_dim 作为参数
def __init__(
self,
dim,
init_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
bits = BITS,
resnet_block_groups = 8,
learned_sinusoidal_dim = 16
):
# 调用父类的构造函数
super().__init__()
# 确定维度
channels *= bits
self.channels = channels
input_channels = channels * 2
# 初始化维度
init_dim = default(init_dim, dim)
# 创建一个卷积层,输入通道数为input_channels,输出通道数为init_dim,卷积核大小为7,填充为3
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
# 计算不同层次的维度
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# 使用ResnetBlock类创建一个部分函数block_klass,其中groups参数为resnet_block_groups
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
# 时间嵌入
time_dim = dim * 4
# 创建一个LearnedSinusoidalPosEmb对象sinu_pos_emb
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
fourier_dim = learned_sinusoidal_dim + 1
# 创建一个包含线性层和激活函数的神经网络模块time_mlp
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# 层
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
# 遍历不同层次的维度
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
# 向downs列表中添加模块列表
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))
mid_dim = dims[-1]
# 创建一个ResnetBlock对象mid_block1
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
# 创建一个包含注意力机制的Residual对象mid_attn
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
# 创建一个ResnetBlock对象mid_block2
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
# 反向遍历不同层次的维度
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
# ���ups列表中添加模块列表
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
# 创建一个ResnetBlock对象final_res_block
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
# 创建一个卷积层final_conv,输入通道数为dim,输出通道数为channels,卷积核大小为1
self.final_conv = nn.Conv2d(dim, channels, 1)
def forward(self, x, time, x_self_cond = None):
# 如果x_self_cond为None,则创建一个与x相同形状的全零张量
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
# 在通道维度上拼接x_self_cond和x
x = torch.cat((x_self_cond, x), dim = 1)
# 将输入数据x通过init_conv卷积层
x = self.init_conv(x)
r = x.clone()
# 通过时间嵌入网络计算时间信息t
t = self.time_mlp(time)
h = []
# 遍历downs列表中的模块列表
for block1, block2, attn, downsample in self.downs:
# 通过block1进行处理
x = block1(x, t)
h.append(x)
# 通过block2进行处理
x = block2(x, t)
# 通过attn进行处理
x = attn(x)
h.append(x)
# 通过downsample进行处理
x = downsample(x)
# 通过mid_block1进行处理
x = self.mid_block1(x, t)
# 通过mid_attn进行处理
x = self.mid_attn(x)
# 通过mid_block2进行处理
x = self.mid_block2(x, t)
# 遍历ups列表中的模块列表
for block1, block2, attn, upsample in self.ups:
# 在通道维度上拼接x和h中的张量
x = torch.cat((x, h.pop()), dim = 1)
# 通过block1进行处理
x = block1(x, t)
# 在通道维度上拼接x和h中的张量
x = torch.cat((x, h.pop()), dim = 1)
# 通过block2进行处理
x = block2(x, t)
# 通过attn进行处理
x = attn(x)
# 通过upsample进行处理
x = upsample(x)
# 在通道维度上拼接x和r
x = torch.cat((x, r), dim = 1)
# 通过final_res_block进行处理
x = self.final_res_block(x, t)
return self.final_conv(x)
# 将十进制数转换为位表示,并反向转换
def decimal_to_bits(x, bits = BITS):
"""将范围在0到1之间的图像张量转换为范围在-1到1之间的位张量"""
device = x.device
# 将图像张量乘以255并取整,限制在0到255之间
x = (x * 255).int().clamp(0, 255)
# 创建位掩码
mask = 2 ** torch.arange(bits - 1, -1, -1, device = device)
mask = rearrange(mask, 'd -> d 1 1')
x = rearrange(x, 'b c h w -> b c 1 h w')
# 将图像张量转换为位张量
bits = ((x & mask) != 0).float()
bits = rearrange(bits, 'b c d h w -> b (c d) h w')
bits = bits * 2 - 1
return bits
def bits_to_decimal(x, bits = BITS):
"""将范围在-1到1之间的位转换为范围在0到1之间的图像张量"""
device = x.device
# 将位张量转换为整数张量
x = (x > 0).int()
mask = 2 ** torch.arange(bits - 1, -1, -1, device = device, dtype = torch.int32)
mask = rearrange(mask, 'd -> d 1 1')
x = rearrange(x, 'b (c d) h w -> b c d h w', d = bits)
dec = reduce(x * mask, 'b c d h w -> b c h w', 'sum')
return (dec / 255).clamp(0., 1.)
# 位扩散类
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def right_pad_dims_to(x, t):
padding_dims = x.ndim - t.ndim
if padding_dims <= 0:
return t
return t.view(*t.shape, *((1,) * padding_dims))
def beta_linear_log_snr(t):
return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
def alpha_cosine_log_snr(t, s: float = 0.008):
return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # 不确定这是否考虑了在离散版本中将beta剪切为0.999
def log_snr_to_alpha_sigma(log_snr):
return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
class BitDiffusion(nn.Module):
def __init__(
self,
model,
*,
image_size,
timesteps = 1000,
use_ddim = False,
noise_schedule = 'cosine',
time_difference = 0.,
bit_scale = 1.
):
super().__init__()
self.model = model
self.channels = self.model.channels
self.image_size = image_size
if noise_schedule == "linear":
self.log_snr = beta_linear_log_snr
elif noise_schedule == "cosine":
self.log_snr = alpha_cosine_log_snr
else:
raise ValueError(f'invalid noise schedule {noise_schedule}')
self.bit_scale = bit_scale
self.timesteps = timesteps
self.use_ddim = use_ddim
# 在论文中提出���与time_next相加,作为修复自我条件不足和在采样时间步数小于400时降低FID的方法
self.time_difference = time_difference
@property
def device(self):
return next(self.model.parameters()).device
def get_sampling_timesteps(self, batch, *, device):
times = torch.linspace(1., 0., self.timesteps + 1, device = device)
times = repeat(times, 't -> b t', b = batch)
times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
times = times.unbind(dim = -1)
return times
@torch.no_grad()
# 从 DDPM 模型中采样生成图像
def ddpm_sample(self, shape, time_difference = None):
# 获取批次大小和设备信息
batch, device = shape[0], self.device
# 设置时间差,默认为 self.time_difference
time_difference = default(time_difference, self.time_difference)
# 获取采样时间步骤对
time_pairs = self.get_sampling_timesteps(batch, device = device)
# 生成随机噪声图像
img = torch.randn(shape, device=device)
x_start = None
# 遍历时间步骤对
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):
# 添加时间延迟
time_next = (time_next - self.time_difference).clamp(min = 0.)
# 获取噪声条件
noise_cond = self.log_snr(time)
# 获取预测的 x0
x_start = self.model(img, noise_cond, x_start)
# 限制 x0 的范围
x_start.clamp_(-self.bit_scale, self.bit_scale)
# 获取 log(snr)
log_snr = self.log_snr(time)
log_snr_next = self.log_snr(time_next)
log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
# 获取时间和下一个时间的 alpha 和 sigma
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
# 推导后验均值和方差
c = -expm1(log_snr - log_snr_next)
mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
variance = (sigma_next ** 2) * c
log_variance = log(variance)
# 获取噪声
noise = torch.where(
rearrange(time_next > 0, 'b -> b 1 1 1'),
torch.randn_like(img),
torch.zeros_like(img)
)
img = mean + (0.5 * log_variance).exp() * noise
return bits_to_decimal(img)
# 无梯度计算的 DDIM 模型采样函数
@torch.no_grad()
def ddim_sample(self, shape, time_difference = None):
# 获取批次大小和设备信息
batch, device = shape[0], self.device
# 设置时间差,默认为 self.time_difference
time_difference = default(time_difference, self.time_difference)
# 获取采样时间步骤对
time_pairs = self.get_sampling_timesteps(batch, device = device)
# 生成随机噪声图像
img = torch.randn(shape, device = device)
x_start = None
# 遍历时间步骤对
for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):
# 添加时间延迟
times_next = (times_next - time_difference).clamp(min = 0.)
# 获取时间和噪声水平
log_snr = self.log_snr(times)
log_snr_next = self.log_snr(times_next)
padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)
# 预测 x0
x_start = self.model(img, log_snr, x_start)
# 限制 x0 的范围
x_start.clamp_(-self.bit_scale, self.bit_scale)
# 获取预测的噪声
pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
# 计算下一个 x
img = x_start * alpha_next + pred_noise * sigma_next
return bits_to_decimal(img)
# 采样函数,根据是否使用 DDIM 选择不同的采样方法
@torch.no_grad()
def sample(self, batch_size = 16):
image_size, channels = self.image_size, self.channels
sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size))
# 定义前向传播函数,接受图像和其他参数
def forward(self, img, *args, **kwargs):
# 解包图像的形状和设备信息
batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
# 断言图像的高度和宽度必须为指定的图像大小
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
# 生成随机采样时间
times = torch.zeros((batch,), device=device).float().uniform_(0, 1.)
# 将图像转换为比特表示
img = decimal_to_bits(img) * self.bit_scale
# 生成噪声样本
noise = torch.randn_like(img)
# 计算噪声水平
noise_level = self.log_snr(times)
# 将噪声水平填充到与图像相同的维度
padded_noise_level = right_pad_dims_to(img, noise_level)
# 将噪声水平转换为 alpha 和 sigma
alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)
# 添加噪声到图像
noised_img = alpha * img + sigma * noise
# 如果进行自条件训练,50%的概率从当前时间预测 x_start,并使用 unet 进行条件
# 这种技术会使训练速度减慢 25%,但似乎显著降低 FID
self_cond = None
if torch.rand((1)) < 0.5:
with torch.no_grad():
# 使用模型预测 x_start,并分离计算图
self_cond = self.model(noised_img, noise_level).detach_()
# 预测并进行梯度下降步骤
pred = self.model(noised_img, noise_level, self_cond)
# 返回预测值和真实值的均方误差损失
return F.mse_loss(pred, img)
# dataset classes
# 定义 Dataset 类,继承自 torch.utils.data.Dataset
class Dataset(Dataset):
# 初始化函数
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png', 'tiff'],
augment_horizontal_flip = False,
pil_img_type = None
):
# 调用父类的初始化函数
super().__init__()
# 设置属性
self.folder = folder
self.image_size = image_size
# 获取指定扩展名的所有文件路径
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
# 部分转换函数
maybe_convert_fn = partial(convert_image_to, pil_img_type) if exists(pil_img_type) else nn.Identity()
# 数据转换操作
self.transform = T.Compose([
T.Lambda(maybe_convert_fn),
T.Resize(image_size),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.CenterCrop(image_size),
T.ToTensor()
])
# 返回数据集的长度
def __len__(self):
return len(self.paths)
# 获取指定索引的数据
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# trainer class
# 定义 Trainer 类
class Trainer(object):
# 初始化函数
def __init__(
self,
diffusion_model,
folder,
*,
train_batch_size = 16,
gradient_accumulate_every = 1,
augment_horizontal_flip = True,
train_lr = 1e-4,
train_num_steps = 100000,
ema_update_every = 10,
ema_decay = 0.995,
adam_betas = (0.9, 0.99),
save_and_sample_every = 1000,
num_samples = 25,
results_folder = './results',
amp = False,
mixed_precision_type = 'fp16',
split_batches = True,
pil_img_type = None
):
# 调用父类的初始化函数
super().__init__()
# 初始化加速器
self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = mixed_precision_type if amp else 'no'
)
# 设置扩散模型
self.model = diffusion_model
# 检查样本数量是否有整数平方根
assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
self.num_samples = num_samples
self.save_and_sample_every = save_and_sample_every
self.batch_size = train_batch_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
self.image_size = diffusion_model.image_size
# dataset and dataloader
# 创建数据集
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, pil_img_type = pil_img_type)
# 创建数据加载器
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
# 准备数据加载器
dl = self.accelerator.prepare(dl)
self.dl = cycle(dl)
# optimizer
# 创建优化器
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
# for logging results in a folder periodically
# 如果是主进程
if self.accelerator.is_main_process:
# 创建指数移动平均模型
self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
# 设置结果文件夹路径
self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok = True)
# step counter state
# 步数计数器
self.step = 0
# prepare model, dataloader, optimizer with accelerator
# 使用加速器准备模型、数据加载器和优化器
self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
# 保存模型
def save(self, milestone):
# 如果不是本地主进程,则返回
if not self.accelerator.is_local_main_process:
return
# 保存模型相关数据
data = {
'step': self.step,
'model': self.accelerator.get_state_dict(self.model),
'opt': self.opt.state_dict(),
'ema': self.ema.state_dict(),
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
}
# 将数据保存到文件
torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
# 加载指定里程碑的模型数据
def load(self, milestone):
# 从文件中加载模型数据
data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))
# 获取未包装的模型对象
model = self.accelerator.unwrap_model(self.model)
# 加载模型的状态字典
model.load_state_dict(data['model'])
# 设置当前步数为加载的数据中的步数
self.step = data['step']
# 加载优化器的状态字典
self.opt.load_state_dict(data['opt'])
# 加载指数移动平均模型的状态字典
self.ema.load_state_dict(data['ema'])
# 如果加速器的缩放器和加载的数据中的缩放器都存在,则加载缩放器的状态字典
if exists(self.accelerator.scaler) and exists(data['scaler']):
self.accelerator.scaler.load_state_dict(data['scaler'])
# 训练模型
def train(self):
# 获取加速器和设备
accelerator = self.accelerator
device = accelerator.device
# 使用 tqdm 显示训练进度条
with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
# 在未达到训练步数之前循环训练
while self.step < self.train_num_steps:
total_loss = 0.
# 根据梯度累积的次数循环
for _ in range(self.gradient_accumulate_every):
# 从数据加载器中获取数据并移动到设备上
data = next(self.dl).to(device)
# 使用自动混合精度计算模型的损失
with self.accelerator.autocast():
loss = self.model(data)
loss = loss / self.gradient_accumulate_every
total_loss += loss.item()
# 反向传播计算梯度
self.accelerator.backward(loss)
# 更新进度条显示损失值
pbar.set_description(f'loss: {total_loss:.4f}')
# 等待所有进程完成当前步骤
accelerator.wait_for_everyone()
# 更新优化器参数
self.opt.step()
self.opt.zero_grad()
# 等待所有进程完成当前步骤
accelerator.wait_for_everyone()
# 如果是主进程
if accelerator.is_main_process:
# 将指数移动平均模型移动到设备上并更新
self.ema.to(device)
self.ema.update()
# 如果步数不为0���可以保存和采样
if self.step != 0 and self.step % self.save_and_sample_every == 0:
# 将指数移动平均模型设置为评估模式
self.ema.ema_model.eval()
# 使用无梯度计算生成样本图像
with torch.no_grad():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(self.num_samples, self.batch_size)
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
# 拼接所有生成的图像并保存
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow=int(math.sqrt(self.num_samples)))
self.save(milestone)
# 更新步数并进度条
self.step += 1
pbar.update(1)
# 打印训练完成信息
accelerator.print('training complete')
.\lucidrains\bit-diffusion\bit_diffusion\__init__.py
# 从 bit_diffusion.bit_diffusion 模块中导入 Unet, BitDiffusion, Trainer 类
from bit_diffusion.bit_diffusion import Unet, BitDiffusion, Trainer

Bit Diffusion - Pytorch
Implementation of Bit Diffusion, Hinton's group's attempt at discrete denoising diffusion, in Pytorch
It seems like they missed the mark for text, but the research direction still seems promising. I think a clean repository will do the research community a lot of benefits for those branching off from here.
Install
$ pip install bit-diffusion
Usage
from bit_diffusion import Unet, Trainer, BitDiffusion
model = Unet(
dim = 32,
channels = 3,
dim_mults = (1, 2, 4, 8),
).cuda()
bit_diffusion = BitDiffusion(
model,
image_size = 128,
timesteps = 100,
time_difference = 0.1, # they found in the paper that at lower number of timesteps, a time difference during sampling of greater than 0 helps FID. as timesteps increases, this time difference can be set to 0 as it does not help
use_ddim = True # use ddim
).cuda()
trainer = Trainer(
bit_diffusion,
'/path/to/your/data', # path to your folder of images
results_folder = './results', # where to save results
num_samples = 16, # number of samples
train_batch_size = 4, # training batch size
gradient_accumulate_every = 4, # gradient accumulation
train_lr = 1e-4, # learning rate
save_and_sample_every = 1000, # how often to save and sample
train_num_steps = 700000, # total training steps
ema_decay = 0.995, # exponential moving average decay
)
trainer.train()
Results will be saved periodically to the ./results folder
If you would like to experiment with the Unet and BitDiffusion class outside the Trainer
import torch
from bit_diffusion import Unet, BitDiffusion
model = Unet(
dim = 64,
dim_mults = (1, 2, 4, 8)
)
bit_diffusion = BitDiffusion(
model,
image_size = 128,
timesteps = 1000
)
training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1
loss = bit_diffusion(training_images)
loss.backward()
# after a lot of training
sampled_images = bit_diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
Citations
@article{Chen2022AnalogBG,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.04202}
}
.\lucidrains\bit-diffusion\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的信息
setup(
# 包的名称
name = 'bit-diffusion',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.1.4',
# 许可证
license='MIT',
# 描述
description = 'Bit Diffusion - Pytorch',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/bit-diffusion',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'denoising diffusion'
],
# 安装依赖
install_requires=[
'accelerate',
'einops',
'ema-pytorch',
'pillow',
'torch>=1.12.0',
'torchvision',
'tqdm'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\block-recurrent-transformer-pytorch\block_recurrent_transformer_pytorch\block_recurrent_transformer_pytorch.py
# 导入数学库
import math
# 从 random 模块中导入 random 函数
from random import random
# 从 functools 模块中导入 wraps 和 partial 函数
from functools import wraps, partial
# 从 itertools 模块中导入 zip_longest 函数
from itertools import zip_longest
# 从 collections 模块中导入 namedtuple 和 defaultdict 类
from collections import namedtuple, defaultdict
# 从 packaging 模块中导入 version 类
from packaging import version
# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange、repeat、pack、unpack 函数和 Rearrange 类
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# 从 beartype 库中导入 beartype 函数
from beartype import beartype
# 从 beartype.door 模块中导入 is_bearable 函数
from beartype.door import is_bearable
# 从 beartype.typing 模块中导入 Optional、List、Tuple 类
from beartype.typing import Optional, List, Tuple
# helpers
# 判断值是否存在
def exists(val):
return val is not None
# 返回默认值
def default(val, d):
return val if exists(val) else d
# 判断张量是否为空
def is_empty(t: torch.Tensor):
return t.numel() == 0
# 将输入转换为元组
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)
# 判断列表中的元素是否唯一
def all_unique(arr):
return len(arr) == len(set(arr))
# 评估装饰器
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
# 仅执行一次的装饰器
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
# 仅打印一次的装饰器
print_once = once(print)
# 过滤掉空值
def compact(arr):
return [*filter(exists, arr)]
# 对张量列表进行按位与操作
def and_reduce(arr: List[torch.Tensor]):
if len(arr) == 0:
return None
head, *rest = arr
for t in rest:
head = head & t
return head
# 安全拼接张量
def safe_cat(*args, dim = 1):
args = compact(args)
if len(args) == 0:
return None
return torch.cat(args, dim = dim)
# 判断是否可以整除
def divisible_by(numer, denom):
return (numer % denom) == 0
# 计算张量的 L2 范数
def l2norm(t):
return F.normalize(t, dim = -1)
# 将张量打包成指定模式
def pack_one(t, pattern):
return pack([t], pattern)
# 将打包的张量解包成指定模式
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# 在指定维度上填充张量
def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# 无偏置的 Layernorm 类
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# 采样辅助函数
# 计算张量的对数
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
# 生成 Gumbel 噪声
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
# Gumbel 采样
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
# Top-k 采样
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# 旋转位置嵌入类
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
width,
scale_base = 512,
theta = 10000
):
super().__init__()
self.width = width
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent = False)
self.scale_base = scale_base
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer('scale', scale, persistent = False)
self.register_buffer('cached_freqs', None, persistent = False)
self.register_buffer('cached_scales', None, persistent = False)
@property
def device(self):
return next(self.buffers()).device
# 定义一个方法用于前向传播,self代表类的实例
def forward(self):
# 获取设备和序列长度
device, seq_len = self.device, self.width
# 如果已经存在缓存的频率信息
if exists(self.cached_freqs):
# 获取缓存的序列长度
cached_seq_len = self.cached_freqs.shape[-2]
# 如果缓存的序列长度大于等于当前序列长度,则直接返回缓存的频率和尺度
if cached_seq_len >= seq_len:
return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]
# 生成一个序列t,长度为seq_len,设备为device,数据类型与self.inv_freq相同
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
# 计算频率信息,使用torch.einsum进行张量乘法
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
# 将频率信息复制一份,拼接在一起,维度为-1
freqs = torch.cat((freqs, freqs), dim=-1)
# 计算尺度信息,根据公式计算得到
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
# 将尺度信息复制一份,拼接在一起,维度为-1
scale = torch.cat((scale, scale), dim=-1)
# 将频率信息和尺度信息注册为缓存,persistent=False表示不持久化
self.register_buffer('cached_freqs', freqs, persistent=False)
self.register_buffer('cached_scales', scale, persistent=False)
# 返回频率信息和尺度信息
return freqs, scale
# 将输入张量 x 沿着最后一个维度分成两部分 x1 和 x2
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
# 返回将 x2 逆时针旋转 180 度后与 x1 拼接的结果
return torch.cat((-x2, x1), dim=-1)
# 对输入张量 t 应用旋转位置编码 pos,并乘以缩放因子 scale
def apply_rotary_pos_emb(t, pos, scale = 1.):
# 如果未提供缩放因子,则默认为 1
scale = default(scale, 1.)
# 获取序列长度
seq_len = t.shape[-2]
# 断言位置编码的长度大于等于序列长度
assert pos.shape[-2] >= seq_len
# 截取位置编码,保留与序列长度相同的部分
pos = pos[-seq_len:]
# 如果缩放因子是张量,则断言其长度大于等于序列长度,并截取与序列长度相同的部分
if isinstance(scale, torch.Tensor):
assert scale.shape[-2] >= seq_len
scale = scale[-seq_len:]
# 返回应用旋转位置编码后的结果
return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
# 内存管理类
class MemoryManager(nn.Module):
def __init__(
self,
dim,
*,
layers = 1,
mem_lengths = 512,
compress_factors = 1
):
super().__init__()
# 将内存长度和压缩因子转换为元组形式
mem_lengths = cast_tuple(mem_lengths)
compress_factors = cast_tuple(compress_factors)
# 断言所有内存长度大于 0
assert all([mem_length > 0 for mem_length in mem_lengths])
# 断言内存长度和压缩因子长度相同
assert len(mem_lengths) == len(compress_factors)
# 断言层数大于等于 1
assert layers >= 1
self.mem_lengths = mem_lengths
self.compress_factors = compress_factors
# 初始化层列表
self.layers = nn.ModuleList([])
# 遍历层数
for _ in range(layers):
compress_fns = nn.ModuleList([])
# 遍历压缩因子
for compress_factor in compress_factors:
compress_fn = nn.Identity()
# 如果压缩因子大于 1,则使用卷积进行压缩
if compress_factor > 1:
compress_fn = nn.Sequential(
Rearrange('b n d -> b d n'),
nn.Conv1d(
dim * 2,
dim * 2,
compress_factor,
stride = compress_factor,
groups = 2
),
Rearrange('b d n -> b n d'),
)
compress_fns.append(compress_fn)
self.layers.append(compress_fns)
def forward(
self,
past_memories: List[torch.Tensor],
new_memories: List[torch.Tensor]
):
# 初始化一个空列表,用于存储下一个时间步的记忆
next_memories = []
# 遍历过去记忆、新记忆和压缩函数的组合
for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):
# 处理当过去记忆和新记忆都不存在的情况
if not (exists(past_memory) or exists(new_memory)):
next_memories.append(None)
continue
next_memory = None
# 遍历记忆长度、压缩因子和压缩函数的组合
for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):
# 获取给定压缩因子下的记忆 "current_memory"
current_memory = None
if exists(past_memory):
past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]
# 基于初始化设置的压缩因子,压缩新进来的记忆
if (not is_empty(new_memory)) and compress_factor > 1:
# 确保记忆长度可以被压缩因子整除
new_mem_length = new_memory.shape[-2]
curtailed_length = (new_mem_length // compress_factor) * compress_factor
curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)
new_memory = new_memory[..., curtailed_slice, :]
# 压缩推送到下一阶段的记忆
if new_memory.shape[-2] > 0:
new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')
new_memory = compress_fn(new_memory)
new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)
# FIFO 记忆队列
# 将新记忆添加到右侧
current_memory = safe_cat(current_memory, new_memory, dim = -2)
# "new" 记忆是相对于下一个压缩段的新记忆
new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]
# 将新记忆连接到过去记忆的左侧
next_memory = safe_cat(current_memory, next_memory, dim = -2)
next_memories.append(next_memory)
return next_memories
# maybe flash attention, if using pytorch 2.0
# 定义一个命名元组 Config,包含三个布尔类型的配置参数
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# state container
# 定义状态容器类 StateContainer,继承自 nn.Module
class StateContainer(nn.Module):
def __init__(
self,
dim,
*,
num_state_vectors,
dim_head = 64,
heads = 8,
qk_rmsnorm = False,
qk_rmsnorm_scale = 8,
use_flash_attn = False
):
super().__init__()
assert num_state_vectors > 0
self.heads = heads
inner_dim = dim_head * heads
# 对状态进行归一化
self.state_norm = LayerNorm(dim)
# 定义线性层,用于将输入转换为查询向量
self.q_to_state = nn.Linear(dim, inner_dim, bias = False)
self.q_from_state = nn.Linear(dim, inner_dim, bias = False)
# 定义线性层,用于将状态转换为查询向量和键值对
self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)
# 初始化状态和位置编码
self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
# 定义线性层,用于将输出转换为状态
self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)
# 定义注意力机制,用于状态之间的交叉注意力和自注意力
self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
# gating related parameters - using the fixed simple config
# 定义线性层和参数,用于门控机制
self.state_out_to_gate = nn.Linear(dim, dim)
self.learned_ema_beta = nn.Parameter(torch.randn(dim))
# since each read should be followed by a write, just store cache in the container
# 初始化缓存和下一个读取状态
self.cache = None
self.next_read_state = None
# 设置下一个读取状态
def set_next_read_state(
self,
states
):
if not exists(states):
states = self.init_state
self.next_read_state = (states,)
# 读取状态
def read(self, x):
assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'
states, = self.next_read_state
self.next_read_state = None
# 对状态进行注意力前的归一化
normed_states = self.state_norm(states)
# 添加位置编码
normed_states = normed_states + self.state_pos_ids
# 获取查询向量用于交叉注意力
q_to_state = self.q_to_state(x)
q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)
# 状态的自注意力机制
state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)
# 交叉注意力
to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)
to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')
# 缓存下一个写入状态
self.cache = (states, normed_states, state_k, state_v)
return to_state_out
# 写入状态
def write(
self,
*,
memories
):
# 断言缓存存在
assert exists(self.cache)
# 解包记忆
k, v = memories
batch = k.shape[0]
# 从先前读取的缓存中获取缓存的值
states, normed_states, state_k, state_v = self.cache
self.cache = None
# 推导查询
q_from_state = self.q_from_state(normed_states)
q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)
state_q = self.state_to_q(normed_states)
state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'
state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)
# 状态也必须经过自注意力
if q_from_state.ndim == 3:
q_from_state = repeat(q_from_state, '... -> b ...', b = batch)
state_out = self.state_self_attn(state_q, state_k, state_v)
from_state_out = self.from_state_cross_attn(q_from_state, k, v)
state_out = torch.cat((state_out, from_state_out), dim = -1)
state_out = rearrange(state_out, 'b h n d -> b n (h d)')
state_out = self.to_state_out(state_out)
# 使用表现最佳的配置
# 固定简单门 - 仅仅是一个学习的EMA,与高速公路网络有些相似
z = self.state_out_to_gate(state_out)
learned_ema_decay = self.learned_ema_beta.sigmoid()
# 使用学习的EMA门设置新状态
return learned_ema_decay * z + (1 - learned_ema_decay) * states
def forward(self, x):
raise NotImplementedError
# 主类
class Attend(nn.Module):
# 初始化函数
def __init__(
self,
causal = False,
use_flash_attn = False
):
# 调用父类的初始化函数
super().__init__()
# 初始化 causal 和 use_flash_attn 属性
self.causal = causal
self.register_buffer("mask", None, persistent=False)
self.use_flash_attn = use_flash_attn
# 检查是否满足使用 flash attention 的条件
assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# 确定 CUDA 和 CPU 的高效注意力配置
self.cpu_config = Config(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not use_flash_attn:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = Config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = Config(False, True, True)
# 获取 mask
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
# flash attention 函数
def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
# 推荐使用 Tri Dao 的多查询单键值注意力
if k.ndim == 3:
k = repeat(k, 'b ... -> b h ...', h = q.shape[1])
if v.ndim == 3:
v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
# 检查 mask 是否存在并扩展到兼容的形状
masks = []
if self.causal:
i, j = q_len, k_len
causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
masks.append(~causal_mask)
if exists(mask):
if mask.ndim != 2:
mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])
masks.append(mask)
attn_mask = and_reduce(masks)
# 检查是否有兼容的设备用于 flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# 使用 torch.backends.cuda.sdp_kernel 函数进行 flash attention
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask
)
return out
# 实现 Transformer 模型的前向传播函数,接受查询(q)、键(k)、值(v)和掩码(mask),以及是否使用 Flash Attention(use_flash_attn)的参数
def forward(self, q, k, v, mask = None, use_flash_attn = None):
# 如果未提供 use_flash_attn 参数,则使用默认值 self.use_flash_attn
use_flash_attn = default(use_flash_attn, self.use_flash_attn)
# 获取查询张量的形状信息
b, n, device = q.shape[0], q.shape[-2], q.device
# 将查询(q)、键(k)、值(v)打包成特定形状
q, ps = pack_one(q, '* h n d')
k, _ = pack_one(k, '* n d')
v, _ = pack_one(v, '* n d')
# 如果使用 Flash Attention,则调用 flash_attn 函数进行注意力计算
if use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
return unpack_one(out, ps, '* h n d')
# 计算缩放因子
scale = q.shape[-1] ** -0.5
# 根据键(k)的维度确定 einsum 中的字符串
k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'
# 计算相似度矩阵
sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale
# 处理键的填充掩码
if exists(mask):
if mask.ndim != 2:
mask = repeat(mask, 'w ... -> (b w) ...', b = b)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# 处理因果掩码
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# 计算注意力权重
attn = sim.softmax(dim=-1)
# 聚合值
out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)
return unpack_one(out, ps, '* h n d')
# 定义 GEGLU 类,用于实现 GEGLU 激活函数
class GEGLU(nn.Module):
# GEGLU 类的前向传播函数
def forward(self, x):
# 将输入张量 x 按照最后一个维度分成两部分,分别赋值给 x 和 gate
x, gate = x.chunk(2, dim=-1)
# 返回 GEGLU 激活函数的计算结果
return F.gelu(gate) * x
# 定义 FeedForward 函数,用于创建前馈神经网络
def FeedForward(dim, mult=4):
# 计算内部维度
inner_dim = int(dim * mult * 2 / 3)
# 返回一个包含多个层的神经网络模型
return nn.Sequential(
LayerNorm(dim), # 对输入进行层归一化
nn.Linear(dim, inner_dim * 2, bias=False), # 线性变换层
GEGLU(), # GEGLU 激活函数
nn.Linear(inner_dim, dim, bias=False) # 线性变换层
)
# 定义 Attention 类,用于实现注意力机制
class Attention(nn.Module):
# Attention 类的初始化函数
def __init__(
self,
dim_head,
causal=False,
qk_rmsnorm=False,
qk_rmsnorm_scale=8,
use_flash_attn=False
):
super().__init__()
self.causal = causal # 是否使用因果注意力机制
self.qk_rmsnorm = qk_rmsnorm # 是否进行 RMS 归一化
self.qk_rmsnorm_scale = qk_rmsnorm_scale # RMS 归一化的缩放因子
self.attend = Attend(causal=causal, use_flash_attn=use_flash_attn) # 创建 Attend 对象
if qk_rmsnorm:
self.q_scale = nn.Parameter(torch.ones(dim_head)) # 创建可学习参数 q_scale
self.k_scale = nn.Parameter(torch.ones(dim_head)) # 创建可学习参数 k_scale
# Attention 类的前向传播函数
def forward(
self,
q, k, v,
mask=None,
rotary_pos_emb=None,
xpos_scale=None
):
scale = q.shape[-1] ** -0.5 # 缩放因子
if self.qk_rmsnorm:
q, k = map(l2norm, (q, k)) # 对 q 和 k 进行 L2 归一化
scale = self.qk_rmsnorm_scale # 更新缩放因子
if self.qk_rmsnorm:
q = q * self.q_scale # 对 q 进行缩放
k = k * self.k_scale # 对 k 进行缩放
# 使用旋转位置嵌入进行位置编码
if exists(rotary_pos_emb):
q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)
# 注意力计算
out = self.attend(q, k, v, mask=mask)
return out
# 定义 AttentionBlock 类,用于实现注意力块
class AttentionBlock(nn.Module):
# AttentionBlock 类的初始化函数
def __init__(
self,
dim,
block_width,
dim_head=64,
heads=8,
qk_rmsnorm=False,
qk_rmsnorm_scale=8,
use_flash_attn=False,
num_state_vectors=0,
num_external_state_reads=0,
state_read_before_write=True
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = LayerNorm(dim) # 对输入进行层归一化
self.to_q = nn.Linear(dim, inner_dim, bias=False) # 线性变换层
self.to_kv = nn.Linear(dim, dim_head * 2, bias=False) # 线性变换层
self.attn = Attention(dim_head, qk_rmsnorm=qk_rmsnorm, qk_rmsnorm_scale=qk_rmsnorm_scale, use_flash_attn=use_flash_attn) # 创建 Attention 对象
self.block_width = block_width
self.is_recurrent_layer = num_state_vectors > 0 # 是否为循环层
num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads # 确定从状态容器中读取的状态数量
self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias=False) # 线性变换层
if not self.is_recurrent_layer:
return
self.state_read_before_write = state_read_before_write
self.state_container = StateContainer(
dim,
dim_head=dim_head,
heads=heads,
num_state_vectors=num_state_vectors,
qk_rmsnorm=qk_rmsnorm,
qk_rmsnorm_scale=qk_rmsnorm_scale,
use_flash_attn=use_flash_attn
)
@property
def device(self):
return next(self.parameters()).device
# AttentionBlock 类的前向传播函数
def forward(
self,
x,
rotary_pos_emb=None,
xpos_scale=None,
attn_mask=None,
xl_memories: Optional[torch.Tensor] = None,
read_from_state_containers: List[StateContainer] = []
):
# 解构输入张量 x 的形状,获取 batch, seq_len, _, width, device
batch, seq_len, _, width, device = *x.shape, self.block_width, self.device
# 预归一化处理
x = self.norm(x)
# 分别提取 queries, keys, values,并拆分出多头
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
split_head = partial(rearrange, pattern='b n (h d) -> b h n d', h=self.heads)
q = split_head(q)
# 将最后的 key / values 作为记忆以供后续使用
memories = torch.stack((k, v))
mem_len = 0
if exists(xl_memories):
# 如果传入了过去的记忆,将其连接为第一个 bucket
mem_len = xl_memories.shape[-2]
past_k, past_v = xl_memories
k = torch.cat((past_k, k), dim=1)
v = torch.cat((past_v, v), dim=1)
# 处理注意力掩码和位置嵌入的裁剪
if exists(attn_mask):
attn_mask = attn_mask[:seq_len, :seq_len]
attn_mask = F.pad(attn_mask, (mem_len, 0), value=True)
# 进行注意力计算
out = self.attn(
q, k, v,
rotary_pos_emb=rotary_pos_emb,
xpos_scale=xpos_scale,
mask=attn_mask
)
# 合并多头
out = rearrange(out, 'b h n d -> b n (h d)')
# 如果不是循环层且没有从状态容器中读取数据,则直接返回结果
if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
return self.to_out(out), memories, None
# 是否从自身状态容器中读取数据,默认为是,但也可以传入更多
if self.is_recurrent_layer and self.state_read_before_write:
read_from_state_containers = [self.state_container, *read_from_state_containers]
for read_state_container in read_from_state_containers:
# 从状态容器中读取数据
to_state_out = read_state_container.read(x)
# 将读取的数据连接到自注意力输出中
out = torch.cat((out, to_state_out), dim=-1)
new_states = None
if self.is_recurrent_layer:
# 如果是循环层,则将记忆写入状态容器
new_states = self.state_container.write(memories=memories)
return self.to_out(out), memories, new_states
# 定义一个装饰器函数 @beartype,用于类型检查
# 定义一个类 BlockRecurrentTransformer,继承自 nn.Module
class BlockRecurrentTransformer(nn.Module):
# 初始化函数
def __init__(
self,
*,
num_tokens, # 输入参数:标记的数量
dim, # 输入参数:维度
depth, # 输入参数:深度
dim_head = 64, # 输入参数:头的维度,默认为64
heads = 8, # 输入参数:头的数量,默认为8
all_layers_qk_rmsnorm = False, # 输入参数:是否对所有层的查询和键进行均方根归一化,默认为False
ff_mult = 4, # 输入参数:前馈网络的倍数,默认为4
max_seq_len = 1024, # 输入参数:最大序列长度,默认为1024
block_width = 512, # 输入参数:块的宽度,默认为512
recurrent_layers: Optional[Tuple[int, ...]] = None, # 输入参数:循环层的索引元组,默认为None
read_recurrent_layers: Optional[Tuple[int, ...]] = None, # 输入参数:读取循环层的索引元组,默认为None
num_state_vectors = None, # 输入参数:状态向量的数量,默认为None
ignore_index = -100, # 输入参数:忽略的索引,默认为-100
use_flash_attn = False, # 输入参数:是否使用快闪注意力,默认为False
use_compressed_mem = False, # 输入参数:是否使用压缩内存,默认为False
compressed_mem_factor = 4 # 输入参数:压缩内存因子,默认为4
):
# 调用父类的构造函数
super().__init__()
# 设置状态向量的数量,默认为块宽度
num_state_vectors = default(num_state_vectors, block_width)
# 设置循环层
# 默认为网络中间的一个循环层
recurrent_layers = default(recurrent_layers, (depth // 2,))
# 断言循环层的范围在1到深度之间
assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
# 断言循环层是唯一的,没有重复的层
assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'
self.recurrent_layers = recurrent_layers
# 设置读取循环层
read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)
# 断言读取循环层小于等于写入循环层
assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'
assert all([0 < layer <= depth for layer in read_recurrent_layers])
assert len(read_recurrent_layers) == len(recurrent_layers)
self.read_recurrent_layers = read_recurrent_layers
# 令牌嵌入
self.token_emb = nn.Embedding(num_tokens, dim)
self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)
self.layers = nn.ModuleList([])
self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}
self.read_state_router = defaultdict(list)
for layer in range(1, depth + 1):
is_recurrent_layer = layer in self.recurrent_layers
layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0
num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])
# 只有具有xl记忆的层或在水平方向上具有循环的层使用qk rmsnorm
qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer
attn_block = AttentionBlock(
dim,
block_width = block_width,
dim_head = dim_head,
heads = heads,
qk_rmsnorm = qk_rmsnorm,
num_state_vectors = layer_num_state_vectors,
use_flash_attn = use_flash_attn,
num_external_state_reads = num_external_state_reads,
state_read_before_write = False,
)
ff_block = FeedForward(dim, mult = ff_mult)
if is_recurrent_layer:
read_layer = self.write_to_read_map[layer]
self.read_state_router[read_layer].append(attn_block.state_container)
self.layers.append(nn.ModuleList([
attn_block,
ff_block
]))
# (compressed) memory management
self.mem_manager = MemoryManager(
dim = dim_head,
layers = depth,
mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),
compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)
)
# 转换为logits
self.to_logits = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, num_tokens, bias = False)
)
self.max_seq_len = max_seq_len
self.block_width = block_width
# 断言最大序列长度能被块宽度整除
assert divisible_by(max_seq_len, block_width)
self.ignore_index = ignore_index
self.register_buffer('cached_causal_attn_mask', None, persistent = False)
@property
def device(self):
# 返回参数的设备
return next(self.parameters()).device
# 获取因果注意力掩码
def get_causal_attn_mask(self, width):
# 如果缓存中存在因果注意力掩码
if exists(self.cached_causal_attn_mask):
# 获取缓存中的掩码
cached_mask = self.cached_causal_attn_mask
# 获取缓存掩码的宽度
cached_width = cached_mask.shape[-2]
# 计算填充量
padding = (width - cached_width) // 2
# 创建切片对象
j_slice = Ellipsis if padding == 0 else slice(padding, -padding)
# 返回缓存中的掩码
return cached_mask[:cached_width, j_slice]
# 获取设备信息
device = self.device
# 创建全为1的因果掩码
causal_mask = torch.ones((width, width), device=device, dtype=torch.bool).triu(1)
# 返回取反的因果掩码
return ~causal_mask
# 生成序列
@torch.no_grad()
@eval_decorator
def generate(
self,
prime,
length=None,
xl_memories: List[torch.Tensor] = [],
states: List[torch.Tensor] = [],
temperature=1.,
filter_thres=0.9,
return_memories_and_states=False
):
# 设置生成序列的长度
length = default(length, self.max_seq_len + 1)
# 获取起始序列的长度
start_len = prime.shape[-1]
# 断言起始序列长度小于最大序列长度
assert start_len < self.max_seq_len
# 断言生成序列长度不超过最大序列长度
assert length <= (self.max_seq_len + 1)
# 断言起始序列长度小于生成序列长度
assert start_len < length
# 初始化输出为起始序列
output = prime
# 初始化记忆
memories = []
# 循环生成序列
for ind in range(length - start_len):
# 前向传播
logits, next_memories, next_states = self.forward(
output,
xl_memories=xl_memories,
states=states
)
# 获取最后一个位置的logits
logits = logits[:, -1]
# 过滤logits
filtered_logits = top_k(logits, thres=filter_thres)
# 采样
sampled = gumbel_sample(filtered_logits, temperature=temperature)
sampled = rearrange(sampled, 'b -> b 1')
# 拼接采样结果到输出序列
output = torch.cat((output, sampled), dim=-1)
# 如果当前窗口的最后一个token采样完成,更新记忆和状态
if divisible_by(output.shape[-1] - 1, self.max_seq_len):
memories = next_memories
states = next_states
# 去除起始序列部分,得到最终生成序列
output = output[:, start_len:]
# 如果需要返回记忆和状态信息
if return_memories_and_states:
return output, memories, states
return output
# 前向传播
def forward(
self,
x,
return_loss=False,
xl_memories: List[torch.Tensor] = [],
states: List[torch.Tensor] = [],
return_memories_and_states=None # 可以强制返回记忆和状态,或者不返回。默认只有在token数量等于最大序列长度时才返回
):
# 获取输入张量的设备信息
device = x.device
if return_loss:
# 如果需要返回损失,则将输入张量切片,去掉最后一个元素作为标签
x, labels = x[:, :-1], x[:, 1:]
# 获取动态位置偏置的序列长度 i 和 j
assert x.shape[-1] <= self.max_seq_len
w = self.block_width
# 令牌嵌入
x = self.token_emb(x)
# 动态位置偏置
attn_mask = self.get_causal_attn_mask(w)
rotary_pos_emb, xpos_scale = self.rotary_pos_emb()
# 只有在完整的块宽度时才返回记忆和状态,但可以被覆盖
return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])
# 准备输出张量,以便按块连接
batch, _, dim = x.shape
out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)
# 将输入分割成宽度为 w 的块
input_blocks = x.split(w, dim = -2)
# 逐个处理每个块
for input_block in input_blocks:
input_block_length = input_block.shape[-2]
# 准备 xl 记忆和状态
iter_xl_memories = iter(xl_memories)
iter_states = iter(states)
next_xl_memories = []
next_states = []
# 在适当的状态容器上设置状态
for attn, _ in self.layers:
if not attn.is_recurrent_layer:
continue
attn.state_container.set_next_read_state(next(iter_states, None))
# 遍历层
for ind, (attn, ff) in enumerate(self.layers):
# 确定层是否需要 transformer xl 记忆
layer = ind + 1
# 是否传入 xl 记忆
attn_kwargs = dict(
rotary_pos_emb = rotary_pos_emb,
xpos_scale = xpos_scale,
attn_mask = attn_mask,
xl_memories = next(iter_xl_memories, None),
read_from_state_containers = self.read_state_router[layer]
)
# 注意力层
residual = input_block
attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)
if exists(layer_xl_memories):
next_xl_memories.append(layer_xl_memories)
if exists(layer_next_states):
next_states.append(layer_next_states)
input_block = attn_branch_out + residual
# 前馈层
input_block = ff(input_block) + input_block
# 连接到输出
out = torch.cat((out, input_block), dim = -2)
# 设置新的 xl 记忆和状态
states = next_states
if input_block_length == w:
xl_memories = self.mem_manager(xl_memories, next_xl_memories)
# 投影到对数
logits = self.to_logits(out)
# 分离状态和记忆
returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None
returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None
# 是否返回对数
if not return_loss:
return logits, returned_next_xl_memories, returned_next_states
# 交叉熵损失
logits = rearrange(logits, 'b n c -> b c n')
loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)
return loss, returned_next_xl_memories, returned_next_states
# recurrent trainer wrapper
# 定义一个装饰器,用于验证输入参数类型
@beartype
# 定义一个类,继承自 nn.Module
class RecurrentTrainerWrapper(nn.Module):
# 初始化方法
def __init__(
self,
transformer: BlockRecurrentTransformer,
xl_memories_dropout = 0.,
state_dropout = 0.
):
super().__init__()
self.transformer = transformer
self.seq_len = transformer.max_seq_len
self.xl_memories_dropout = xl_memories_dropout
self.state_dropout = state_dropout
# 生成方法,用于生成序列
@eval_decorator
@torch.no_grad()
def generate(
self,
prime,
length,
**kwargs
):
seq_len = self.seq_len
start_len = prime.shape[-1]
assert start_len < length
output = prime
current_len = start_len
memories = []
states = []
# 确定长度
has_remainder = not divisible_by(length, seq_len)
remainder_amount = length % seq_len
total_segments = math.ceil(length / seq_len)
if not has_remainder:
lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)
elif remainder_amount == 1:
lengths = (seq_len + 1,) * (total_segments - 1)
else:
lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)
# 循环遍历长度
for next_length in lengths:
segment_output, memories, states = self.transformer.generate(
output[:, -current_len:],
length = next_length,
xl_memories = memories,
states = states,
return_memories_and_states = True,
**kwargs
)
output = torch.cat((output, segment_output), dim = -1)
current_len = 1
return output[:, start_len:]
# 前向传播方法
def forward(
self,
x,
return_memories_and_states = False
):
total_seq_len, seq_len = x.shape[1], self.seq_len
assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'
segments = total_seq_len // seq_len
total_loss = 0.
memories = []
states = []
for ind in range(segments):
start = ind * seq_len
end = start + seq_len + 1
if self.training and random() < self.xl_memories_dropout:
memories.clear()
if self.training and random() < self.state_dropout:
states.clear()
loss, memories, states = self.transformer(
x[:, start:end],
xl_memories = memories,
states = states,
return_loss = True
)
total_loss = total_loss + (loss / segments)
if return_memories_and_states:
return total_loss, memories, states
return total_loss
.\lucidrains\block-recurrent-transformer-pytorch\block_recurrent_transformer_pytorch\__init__.py
# 导入 torch 库
import torch
# 从 packaging 库中导入 version 模块
from packaging import version
# 检查 torch 版本是否大于等于 '2.0.0',如果是则执行以下代码
if version.parse(torch.__version__) >= version.parse('2.0.0'):
# 从 einops._torch_specific 模块中导入 allow_ops_in_compiled_graph 函数
from einops._torch_specific import allow_ops_in_compiled_graph
# 调用 allow_ops_in_compiled_graph 函数
allow_ops_in_compiled_graph()
# 从 block_recurrent_transformer_pytorch 包中导入 BlockRecurrentTransformer 和 RecurrentTrainerWrapper 类
from block_recurrent_transformer_pytorch.block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Block Recurrent Transformer - Pytorch
Implementation of Block Recurrent Transformer - Pytorch. The highlight of the paper is its reported ability to remember something up to 60k tokens ago.
This design is SOTA for recurrent transformers line of research, afaict.
It will also include flash attention as well as routed memories of up to 250k tokens using ideas from this paper
Appreciation
- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
Install
$ pip install block-recurrent-transformer-pytorch
Usage
import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer
model = BlockRecurrentTransformer(
num_tokens = 20000, # vocab size
dim = 512, # model dimensions
depth = 6, # depth
dim_head = 64, # attention head dimensions
heads = 8, # number of attention heads
max_seq_len = 1024, # the total receptive field of the transformer, in the paper this was 2 * block size
block_width = 512, # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forwards becomes the new cached xl memories, which is a block size of 1 (please open an issue if i am wrong)
num_state_vectors = 512, # number of state vectors, i believe this was a single block size in the paper, but can be any amount
recurrent_layers = (4,), # where to place the recurrent layer(s) for states with fixed simple gating
use_compressed_mem = False, # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
compressed_mem_factor = 4, # compression factor of compressed memories
use_flash_attn = True # use flash attention, if on pytorch 2.0
)
seq = torch.randint(0, 2000, (1, 1024))
out, mems1, states1 = model(seq)
out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
Test on Enwik8
First pip install -r requirements.txt, then
$ python train.py
Todo
Citations
@article{Hutchins2022BlockRecurrentT,
title = {Block-Recurrent Transformers},
author = {DeLesley S. Hutchins and Imanol Schlag and Yuhuai Wu and Ethan Dyer and Behnam Neyshabur},
journal = {ArXiv},
year = {2022},
volume = {abs/2203.07852}
}
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam M. Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
@inproceedings{Sun2022ALT,
title = {A Length-Extrapolatable Transformer},
author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
year = {2022}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@inproceedings{Ainslie2023CoLT5FL,
title = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
author = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
year = {2023}
}
Memory is Attention through Time - Alex Graves
.\lucidrains\block-recurrent-transformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'block-recurrent-transformer-pytorch', # 包的名称
packages = find_packages(exclude=[]), # 查找并包含所有包
version = '0.4.3', # 版本号
license='MIT', # 许可证
description = 'Block Recurrent Transformer - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
long_description_content_type = 'text/markdown', # 长描述内容类型
url = 'https://github.com/lucidrains/block-recurrent-transformer-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'deep learning',
'transformers',
'attention mechanism',
'recurrence'
],
install_requires=[ # 安装依赖
'beartype',
'einops>=0.6.1',
'memorizing-transformers-pytorch>=0.4.0',
'torch>=1.6',
],
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\block-recurrent-transformer-pytorch\train.py
# 导入所需的库
import gzip
import random
import tqdm
import numpy as np
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入加速库
from accelerate import Accelerator
# 导入自定义的模型和训练器
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 250
GENERATE_LENGTH = 2048
SEQ_LEN = 2048
# 定义辅助函数
# 将 token 解码为字符
def decode_token(token):
return str(chr(max(32, token)))
# 将 tokens 解码为字符串
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# 初始化加速器
accelerator = Accelerator()
# 获取设备和打印函数
device = accelerator.device
acc_print = accelerator.print
# 实例化模型
model = BlockRecurrentTransformer(
num_tokens = 256,
dim = 512,
depth = 6,
dim_head = 64,
heads = 8,
max_seq_len = 1024,
block_width = 512,
num_state_vectors = 512,
recurrent_layers = (4,),
use_flash_attn = True
)
# 实例化训练器
train_wrapper = RecurrentTrainerWrapper(
model,
xl_memories_dropout = 0.1,
state_dropout = 0.1,
)
# 将模型移动到设备上
model.to(device)
# 准备 enwik8 数据
with gzip.open("./data/enwik8.gz") as file:
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
np_train, np_valid = np.split(data, [int(90e6)])
data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
# 定义数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
return full_seq.to(device)
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
# 定义优化器
optim = Adam(model.parameters(), lr = LEARNING_RATE)
# 准备模型、优化器和数据加载器
model, optim, train_loader, val_loader = accelerator.prepare(
model, optim, train_loader, val_loader
)
# 训练过程
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
model.train()
for _ in range(GRADIENT_ACCUMULATE_EVERY):
loss = train_wrapper(next(train_loader))
accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
acc_print(f"training loss: {loss.item()}")
accelerator.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = train_wrapper(next(val_loader))
acc_print(f"validation loss: {loss.item()}")
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:PRIME_LENGTH]
prime = decode_tokens(inp)
acc_print(f"%s \n\n %s", (prime, "*" * 100))
sample = train_wrapper.generate(inp[None, ...], length = GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
acc_print(output_str, "\n")
.\lucidrains\bottleneck-transformer-pytorch\bottleneck_transformer_pytorch\bottleneck_transformer_pytorch.py
# 导入 math 和 torch 模块
import math
import torch
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 einops 模块中导入 rearrange 函数
# 从 tensorflow 代码翻译而来
# https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
# 位置嵌入的辅助函数
# 如果 x 不是元组,则返回 (x, x),否则返回 x
def pair(x):
return (x, x) if not isinstance(x, tuple) else x
# 在指定维度 dim 上扩展张量 t 的维度为 k
def expand_dim(t, dim, k):
t = t.unsqueeze(dim = dim)
expand_shape = [-1] * len(t.shape)
expand_shape[dim] = k
return t.expand(*expand_shape)
# 将相对位置编码转换为绝对位置编码
def rel_to_abs(x):
b, h, l, _, device, dtype = *x.shape, x.device, x.dtype
dd = {'device': device, 'dtype': dtype}
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((x, col_pad), dim = 3)
flat_x = rearrange(x, 'b h l c -> b h (l c)')
flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim = 2)
final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l-1):]
return final_x
# 计算相对位置编码的一维相对注意力权重
def relative_logits_1d(q, rel_k):
b, heads, h, w, dim = q.shape
logits = einsum('b h x y d, r d -> b h x y r', q, rel_k)
logits = rearrange(logits, 'b h x y r -> b (h x) y r')
logits = rel_to_abs(logits)
logits = logits.reshape(b, heads, h, w, w)
logits = expand_dim(logits, dim = 3, k = h)
return logits
# 位置嵌入
# 绝对位置嵌入类
class AbsPosEmb(nn.Module):
def __init__(
self,
fmap_size,
dim_head
):
super().__init__()
height, width = pair(fmap_size)
scale = dim_head ** -0.5
self.height = nn.Parameter(torch.randn(height, dim_head) * scale)
self.width = nn.Parameter(torch.randn(width, dim_head) * scale)
def forward(self, q):
emb = rearrange(self.height, 'h d -> h () d') + rearrange(self.width, 'w d -> () w d')
emb = rearrange(emb, ' h w d -> (h w) d')
logits = einsum('b h i d, j d -> b h i j', q, emb)
return logits
# 相对位置嵌入类
class RelPosEmb(nn.Module):
def __init__(
self,
fmap_size,
dim_head
):
super().__init__()
height, width = pair(fmap_size)
scale = dim_head ** -0.5
self.fmap_size = fmap_size
self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
def forward(self, q):
h, w = self.fmap_size
q = rearrange(q, 'b h (x y) d -> b h x y d', x = h, y = w)
rel_logits_w = relative_logits_1d(q, self.rel_width)
rel_logits_w = rearrange(rel_logits_w, 'b h x i y j-> b h (x y) (i j)')
q = rearrange(q, 'b h x y d -> b h y x d')
rel_logits_h = relative_logits_1d(q, self.rel_height)
rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
return rel_logits_w + rel_logits_h
# 注意力机制类
class Attention(nn.Module):
def __init__(
self,
*,
dim,
fmap_size,
heads = 4,
dim_head = 128,
rel_pos_emb = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
rel_pos_class = AbsPosEmb if not rel_pos_emb else RelPosEmb
self.pos_emb = rel_pos_class(fmap_size, dim_head)
def forward(self, fmap):
heads, b, c, h, w = self.heads, *fmap.shape
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
sim = sim + self.pos_emb(q)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return out
class BottleBlock(nn.Module):
# 初始化函数,设置网络的参数和结构
def __init__(
self,
*,
dim,
fmap_size,
dim_out,
proj_factor,
downsample,
heads = 4,
dim_head = 128,
rel_pos_emb = False,
activation = nn.ReLU()
):
# 调用父类的初始化函数
super().__init__()
# shortcut
# 如果输入维度不等于输出维度或者需要下采样
if dim != dim_out or downsample:
# 根据是否下采样设置卷积核大小、步长和填充
kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0)
# 创建 shortcut 层,包括卷积、批归一化和激活函数
self.shortcut = nn.Sequential(
nn.Conv2d(dim, dim_out, kernel_size, stride = stride, padding = padding, bias = False),
nn.BatchNorm2d(dim_out),
activation
)
else:
# 如果不需要下采样,则使用恒等映射
self.shortcut = nn.Identity()
# contraction and expansion
# 计算注意力机制输入维度和输出维度
attn_dim_in = dim_out // proj_factor
attn_dim_out = heads * dim_head
# 创建网络结构,包括卷积、批归一化、激活函数、注意力机制、平均池化、卷积和批归一化
self.net = nn.Sequential(
nn.Conv2d(dim, attn_dim_in, 1, bias = False),
nn.BatchNorm2d(attn_dim_in),
activation,
Attention(
dim = attn_dim_in,
fmap_size = fmap_size,
heads = heads,
dim_head = dim_head,
rel_pos_emb = rel_pos_emb
),
nn.AvgPool2d((2, 2)) if downsample else nn.Identity(),
nn.BatchNorm2d(attn_dim_out),
activation,
nn.Conv2d(attn_dim_out, dim_out, 1, bias = False),
nn.BatchNorm2d(dim_out)
)
# 初始化最后一个批归一化层的权重为零
nn.init.zeros_(self.net[-1].weight)
# final activation
# 设置最终的激活函数
self.activation = activation
# 前向传播函数
def forward(self, x):
# 计算 shortcut
shortcut = self.shortcut(x)
# 经过网络结构
x = self.net(x)
# 将 shortcut 加到输出上
x = x + shortcut
# 返回激活后的输出
return self.activation(x)
# 定义一个名为BottleStack的类,继承自nn.Module
class BottleStack(nn.Module):
# 初始化函数,接受一系列参数
def __init__(
self,
*,
dim, # 特征维度
fmap_size, # 特征图大小
dim_out = 2048, # 输出维度,默认为2048
proj_factor = 4, # 投影因子,默认为4
num_layers = 3, # 层数,默认为3
heads = 4, # 多头注意力机制中的头数,默认为4
dim_head = 128, # 多头注意力机制中每个头的维度,默认为128
downsample = True, # 是否下采样,默认为True
rel_pos_emb = False, # 是否使用相对位置编码,默认为False
activation = nn.ReLU() # 激活函数,默认为ReLU
):
super().__init__() # 调用父类的初始化函数
fmap_size = pair(fmap_size) # 将特征图大小转换为元组形式
self.dim = dim # 初始化特征维度
self.fmap_size = fmap_size # 初始化特征图大小
layers = [] # 初始化一个空列表用于存放层
# 循环创建num_layers个BottleBlock层
for i in range(num_layers):
is_first = i == 0 # 判断是否是第一层
dim = (dim if is_first else dim_out) # 如果是第一层,则维度为dim,否则为dim_out
layer_downsample = is_first and downsample # 如果是第一层且需要下采样,则为True
fmap_divisor = (2 if downsample and not is_first else 1) # 计算特征图大小的除数
layer_fmap_size = tuple(map(lambda t: t // fmap_divisor, fmap_size)) # 计算当前层的特征图大小
# 创建一个BottleBlock层,并添加到layers列表中
layers.append(BottleBlock(
dim = dim,
fmap_size = layer_fmap_size,
dim_out = dim_out,
proj_factor = proj_factor,
heads = heads,
dim_head = dim_head,
downsample = layer_downsample,
rel_pos_emb = rel_pos_emb,
activation = activation
))
# 将所有层组合成一个神经网络
self.net = nn.Sequential(*layers)
# 前向传播函数,接受输入x,返回网络输出
def forward(self, x):
_, c, h, w = x.shape # 获取输入x的形状信息
assert c == self.dim, f'channels of feature map {c} must match channels given at init {self.dim}' # 断言通道数与初始化时给定的特征维度相匹配
assert h == self.fmap_size[0] and w == self.fmap_size[1], f'height and width ({h} {w}) of feature map must match the fmap_size given at init {self.fmap_size}' # 断言特征图的高度和宽度与初始化时给定的特征图大小相匹配
return self.net(x) # 返回网络的输出
.\lucidrains\bottleneck-transformer-pytorch\bottleneck_transformer_pytorch\__init__.py
# 从bottleneck_transformer_pytorch包中导入BottleStack和BottleBlock类
from bottleneck_transformer_pytorch.bottleneck_transformer_pytorch import BottleStack, BottleBlock


Bottleneck Transformer - Pytorch
Implementation of Bottleneck Transformer, SotA visual recognition model with convolution + attention that outperforms EfficientNet and DeiT in terms of performance-computes trade-off, in Pytorch
Install
$ pip install bottleneck-transformer-pytorch
Usage
import torch
from torch import nn
from bottleneck_transformer_pytorch import BottleStack
layer = BottleStack(
dim = 256, # channels in
fmap_size = 64, # feature map size
dim_out = 2048, # channels out
proj_factor = 4, # projection factor
downsample = True, # downsample on first layer or not
heads = 4, # number of heads
dim_head = 128, # dimension per head, defaults to 128
rel_pos_emb = False, # use relative positional embedding - uses absolute if False
activation = nn.ReLU() # activation throughout the network
)
fmap = torch.randn(2, 256, 64, 64) # feature map from previous resnet block(s)
layer(fmap) # (2, 2048, 32, 32)
BotNet
With some simple model surgery off a resnet, you can have the 'BotNet' (what a weird name) for training.
import torch
from torch import nn
from torchvision.models import resnet50
from bottleneck_transformer_pytorch import BottleStack
layer = BottleStack(
dim = 256,
fmap_size = 56, # set specifically for imagenet's 224 x 224
dim_out = 2048,
proj_factor = 4,
downsample = True,
heads = 4,
dim_head = 128,
rel_pos_emb = True,
activation = nn.ReLU()
)
resnet = resnet50()
# model surgery
backbone = list(resnet.children())
model = nn.Sequential(
*backbone[:5],
layer,
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(1),
nn.Linear(2048, 1000)
)
# use the 'BotNet'
img = torch.randn(2, 3, 224, 224)
preds = model(img) # (2, 1000)
Citations
@misc{srinivas2021bottleneck,
title = {Bottleneck Transformers for Visual Recognition},
author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
year = {2021},
eprint = {2101.11605},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\bottleneck-transformer-pytorch\setup.py
# 导入设置工具和查找包的模块
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'bottleneck-transformer-pytorch', # 包的名称
packages = find_packages(), # 查找并包含所有包
version = '0.1.4', # 版本号
license='MIT', # 许可证类型
description = 'Bottleneck Transformer - Pytorch', # 包的描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者的电子邮件
url = 'https://github.com/lucidrains/bottleneck-transformer-pytorch', # 包的URL
keywords = [ # 关键字列表
'artificial intelligence',
'attention mechanism',
'transformers',
'image classification',
'vision'
],
install_requires=[ # 安装所需的依赖项
'einops>=0.3',
'torch>=1.6'
],
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\BS-RoFormer\bs_roformer\attend.py
# 导入必要的库
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# 定义一个命名元组 FlashAttentionConfig,包含三个布尔类型的参数
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# 定义一些辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值
def default(v, d):
return v if exists(v) else d
# 保证函数只被调用一次的装饰器
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
# 打印信息,只打印一次
print_once = once(print)
# 主要类
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False,
scale = None
):
super().__init__()
self.scale = scale
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# 确定在 cuda 和 cpu 上的高效注意力配置
self.cpu_config = FlashAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(False, True, True)
# Flash Attention 方法
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
# 检查是否有兼容的设备用于 Flash Attention
config = self.cuda_config if is_cuda else self.cpu_config
# 使用 torch.backends.cuda.sdp_kernel() 方法进行 Flash Attention 计算
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)
return out
# 前向传播方法
def forward(self, q, k, v):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)
if self.flash:
return self.flash_attn(q, k, v)
# 相似度计算
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
# 注意力计算
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# 聚合数值
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
return out
.\lucidrains\BS-RoFormer\bs_roformer\bs_roformer.py
# 导入所需的库
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
# 导入自定义的模块
from bs_roformer.attend import Attend
# 导入类型提示相关的库
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
# 导入旋转嵌入相关的库
from rotary_embedding_torch import RotaryEmbedding
# 导入 einops 库
from einops import rearrange, pack, unpack
# 辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值
def default(v, d):
return v if exists(v) else d
# 将单个张量按照指定模式打包
def pack_one(t, pattern):
return pack([t], pattern)
# 将单个张量按照指定模式解包
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# 归一化模块
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim = -1) * self.scale * self.gamma
# 注意力模块
# 前馈神经网络模块
class FeedForward(Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 注意力模块
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
rotary_embed = None,
flash = True
):
super().__init__()
self.heads = heads
self.scale = dim_head **-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash = flash, dropout = dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# Transformer 模块
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
attn_dropout = 0.,
ff_dropout = 0.,
ff_mult = 4,
norm_output = True,
rotary_embed = None,
flash_attn = True
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# bandsplit 模块
class BandSplit(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
self.to_features.append(net)
# 定义一个前向传播函数,接受输入 x
def forward(self, x):
# 将输入 x 沿着指定维度进行分割
x = x.split(self.dim_inputs, dim = -1)
# 初始化一个空列表用于存储输出结果
outs = []
# 遍历分割后的输入和对应的特征函数
for split_input, to_feature in zip(x, self.to_features):
# 对每个分割后的输入应用对应的特征函数,得到分割后的输出
split_output = to_feature(split_input)
# 将分割后的输出添加到输出列表中
outs.append(split_output)
# 将所有输出结果堆叠在一起,沿着指定维度
return torch.stack(outs, dim = -2)
# 定义一个多层感知机(MLP)模型
def MLP(
dim_in,
dim_out,
dim_hidden = None,
depth = 1,
activation = nn.Tanh
):
# 如果未指定隐藏层维度,则默认为输入维度
dim_hidden = default(dim_hidden, dim_in)
net = []
# 构建每一层的维度信息
dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
# 遍历每一层,构建网络结构
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
# 添加线性层
net.append(nn.Linear(layer_dim_in, layer_dim_out))
# 如果是最后一层,则跳过激活函数
if is_last:
continue
# 添加激活函数
net.append(activation())
return nn.Sequential(*net)
# 定义一个MaskEstimator类,继承自Module
class MaskEstimator(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor = 4
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
# 遍历输入维度,构建MLP网络
for dim_in in dim_inputs:
net = []
# 构建MLP网络
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
nn.GLU(dim = -1)
)
self.to_freqs.append(mlp)
# 前向传播函数
def forward(self, x):
# 沿着倒数第二维度拆分输入
x = x.unbind(dim = -2)
outs = []
# 遍历每个频段特征和对应的MLP网络
for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)
return torch.cat(outs, dim = -1)
# 主类
# 默认频率带数目
DEFAULT_FREQS_PER_BANDS = (
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
12, 12, 12, 12, 12, 12, 12, 12,
24, 24, 24, 24, 24, 24, 24, 24,
48, 48, 48, 48, 48, 48, 48, 48,
128, 129,
)
# 定义BSRoformer类,继承自Module
class BSRoformer(Module):
@beartype
def __init__(
self,
dim,
*,
depth,
stereo = False,
num_stems = 1,
time_transformer_depth = 2,
freq_transformer_depth = 2,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, # 在论文中,它们将其分成约60个频带,测试时先用1
dim_head = 64,
heads = 8,
attn_dropout = 0.,
ff_dropout = 0.,
flash_attn = True,
dim_freqs_in = 1025,
stft_n_fft = 2048,
stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth = 2,
multi_stft_resolution_loss_weight = 1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size = 147,
multi_stft_normalized = False,
multi_stft_window_fn: Callable = torch.hann_window
):
# 调用父类的构造函数
super().__init__()
# 设置音频是否为立体声
self.stereo = stereo
# 根据音频是否为立体声确定音频通道数
self.audio_channels = 2 if stereo else 1
# 设置音频分离的声音轨道数
self.num_stems = num_stems
# 初始化神经网络层列表
self.layers = ModuleList([])
# 设置变压器的参数
transformer_kwargs = dict(
dim = dim,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
flash_attn = flash_attn,
norm_output = False
)
# 创建时间旋转嵌入和频率旋转嵌入
time_rotary_embed = RotaryEmbedding(dim = dim_head)
freq_rotary_embed = RotaryEmbedding(dim = dim_head)
# 根据深度循环创建变压器层
for _ in range(depth):
self.layers.append(nn.ModuleList([
Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, **transformer_kwargs),
Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, **transformer_kwargs)
]))
# 初始化最终的归一化层
self.final_norm = RMSNorm(dim)
# 设置短时傅里叶变换的参数
self.stft_kwargs = dict(
n_fft = stft_n_fft,
hop_length = stft_hop_length,
win_length = stft_win_length,
normalized = stft_normalized
)
# 设置短时傅里叶变换的窗口函数
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
# 计算频率数量
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1]
# 断言频率段数大于1
assert len(freqs_per_bands) > 1
# 断言频率段数之和等于总频率数
assert sum(freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
# 计算每个频率段的复数频率数量
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
# 初始化频率段分割层
self.band_split = BandSplit(
dim = dim,
dim_inputs = freqs_per_bands_with_complex
)
# 初始化掩蔽估计器列表
self.mask_estimators = nn.ModuleList([])
# ��据声音轨道数循环创建掩蔽估计器
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim = dim,
dim_inputs = freqs_per_bands_with_complex,
depth = mask_estimator_depth
)
self.mask_estimators.append(mask_estimator)
# 设置多分辨率短时傅里叶变换损失的权重和参数
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length = multi_stft_hop_size,
normalized = multi_stft_normalized
)
# 前向传播函数
def forward(
self,
raw_audio,
target = None,
return_loss_breakdown = False
.\lucidrains\BS-RoFormer\bs_roformer\mel_band_roformer.py
# 导入所需的库
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
# 导入自定义的模块
from bs_roformer.attend import Attend
# 导入类型提示相关的库
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
# 导入旋转嵌入相关的库
from rotary_embedding_torch import RotaryEmbedding
# 导入 einops 库中的函数和层
from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange
# 导入 librosa 库中的滤波器
from librosa import filters
# 定义一些辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值
def default(v, d):
return v if exists(v) else d
# 将张量打包成指定模式的形状
def pack_one(t, pattern):
return pack([t], pattern)
# 将打包后的张量解包成原始形状
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# 在指定维度上进行填充
def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# 对张量进行 L2 归一化
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
# 定义 RMS 归一化层
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim = -1) * self.scale * self.gamma
# 定义前馈神经网络层
class FeedForward(Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 定义注意力机制层
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
rotary_embed = None,
flash = True
):
super().__init__()
self.heads = heads
self.scale = dim_head **-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash = flash, dropout = dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 定义线性注意力机制层
class LinearAttention(Module):
"""
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
"""
@beartype
def __init__(
self,
*,
dim,
dim_head = 32,
heads = 8,
scale = 8,
flash = False,
dropout = 0.
):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias = False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale = scale,
dropout = dropout,
flash = flash
)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias = False)
)
def forward(
self,
x
):
# 对输入进行归一化处理
x = self.norm(x)
# 将输入转换为查询、键、值
q, k, v = self.to_qkv(x)
# 对查询、键进行 L2 归一化
q, k = map(l2norm, (q, k))
# 对查询进行温度调节
q = q * self.temperature.exp()
# 进行注意力计算
out = self.attend(q, k, v)
# 将输出转换为最终输出
return self.to_out(out)
# 定义一个名为 Transformer 的类,继承自 Module 类
class Transformer(Module):
# 初始化函数,接收多个参数
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
attn_dropout = 0.,
ff_dropout = 0.,
ff_mult = 4,
norm_output = True,
rotary_embed = None,
flash_attn = True,
linear_attn = False
):
# 调用父类的初始化函数
super().__init__()
# 初始化 layers 属性为一个空的 ModuleList
self.layers = ModuleList([])
# 循环 depth 次
for _ in range(depth):
# 根据 linear_attn 参数选择不同的注意力机制
if linear_attn:
attn = LinearAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn)
else:
attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn)
# 将注意力机制和前馈网络添加到 layers 中
self.layers.append(ModuleList([
attn,
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
# 根据 norm_output 参数选择是否使用 RMSNorm 或者 nn.Identity
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
# 前向传播函数
def forward(self, x):
# 遍历 layers 中的每个注意力机制和前馈网络
for attn, ff in self.layers:
# 执行注意力机制并将结果与输入相加
x = attn(x) + x
# 执行前馈网络并将结果与输入相加
x = ff(x) + x
# 对结果进行归一化处理
return self.norm(x)
# 定义一个名为 BandSplit 的类,继承自 Module 类
class BandSplit(Module):
# 初始化函数,接收多个参数
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
# 调用父类的初始化函数
super().__init__()
# 初始化 dim_inputs 属性
self.dim_inputs = dim_inputs
# 初始化 to_features 属性为一个空的 ModuleList
self.to_features = ModuleList([])
# 遍历 dim_inputs 中的每个维度
for dim_in in dim_inputs:
# 创建一个包含 RMSNorm 和 Linear 层的网络
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
# 将网络添加到 to_features 中
self.to_features.append(net)
# 前向传播函数
def forward(self, x):
# 将输入 x 按照 dim_inputs 进行分割
x = x.split(self.dim_inputs, dim = -1)
outs = []
# 遍历分割后的输入和对应的网络
for split_input, to_feature in zip(x, self.to_features):
# 对分割后的输入进行处理并添加到 outs 中
split_output = to_feature(split_input)
outs.append(split_output)
# 在指定维度上将结果拼接起来
return torch.stack(outs, dim = -2)
# 定义一个名为 MLP 的函数
def MLP(
dim_in,
dim_out,
dim_hidden = None,
depth = 1,
activation = nn.Tanh
):
# 如果未指定隐藏层维度,则设置为输入维度
dim_hidden = default(dim_hidden, dim_in)
# 初始化网络列表
net = []
dims = (dim_in, *((dim_hidden,) * depth), dim_out)
# 遍历每一层的输入和输出维度
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
# 添加线性层
net.append(nn.Linear(layer_dim_in, layer_dim_out))
# 如果不是最后一层,则添加激活函数
if is_last:
continue
net.append(activation())
# 返回一个包含所有层的序列网络
return nn.Sequential(*net)
# 定义一个名为 MaskEstimator 的类,继承自 Module 类
class MaskEstimator(Module):
# 初始化函数,接收多个参数
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor = 4
):
# 调用父类的初始化函数
super().__init__()
# 初始化 dim_inputs 属性
self.dim_inputs = dim_inputs
# 初始化 to_freqs 属性为一个空的 ModuleList
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
# 遍历 dim_inputs 中的每个维度
for dim_in in dim_inputs:
net = []
# 创建一个包含 MLP 和 GLU 层的网络
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
nn.GLU(dim = -1)
)
# 将网络添加到 to_freqs 中
self.to_freqs.append(mlp)
# 前向传播函数
def forward(self, x):
# 将输入 x 按照指定维度解绑
x = x.unbind(dim = -2)
outs = []
# 遍历解绑后的输入和对应的网络
for band_features, mlp in zip(x, self.to_freqs):
# 对输入进行处理并添加到 outs 中
freq_out = mlp(band_features)
outs.append(freq_out)
# 在指定维度上将结果拼接起来
return torch.cat(outs, dim = -1)
# 定义一个名为 MelBandRoformer 的类,继承自 Module 类
class MelBandRoformer(Module):
# 初始化函数
@beartype
# 初始化函数,设置模型参数
def __init__(
self,
dim,
*,
depth,
stereo = False,
num_stems = 1,
time_transformer_depth = 2,
freq_transformer_depth = 2,
linear_transformer_depth = 1,
num_bands = 60,
dim_head = 64,
heads = 8,
attn_dropout = 0.1,
ff_dropout = 0.1,
flash_attn = True,
dim_freqs_in = 1025,
sample_rate = 44100, # needed for mel filter bank from librosa
stft_n_fft = 2048,
stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth = 1,
multi_stft_resolution_loss_weight = 1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size = 147,
multi_stft_normalized = False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length = False, # if True, pad output tensor to match length of input tensor
# 前向传播函数,接收原始音频数据和目标数据,返回损失细分结果
def forward(
self,
raw_audio,
target = None,
return_loss_breakdown = False
.\lucidrains\BS-RoFormer\bs_roformer\__init__.py
# 从 bs_roformer 模块中导入 BSRoformer 类
from bs_roformer.bs_roformer import BSRoformer
# 从 bs_roformer 模块中导入 MelBandRoformer 类
from bs_roformer.mel_band_roformer import MelBandRoformer

BS-RoFormer
Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs. They beat the previous first place by a large margin. The technique uses axial attention across frequency (hence multi-band) and time. They also have experiments to show that rotary positional encoding led to a huge improvement over learned absolute positions.
It also includes support for stereo training and outputting multiple stems.
Please join if you are interested in replicating a SOTA music source separator out in the open
Appreciation
-
StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
-
Roee and Fabian-Robert for sharing their audio expertise and fixing audio hyperparameters
-
@chenht2010 and Roman for working out the default band splitting hyperparameter!
-
Max Prod for reporting a big bug with Mel-Band Roformer with stereo training!
-
Roman for successfully training the model and open sourcing his training code and weights at this repository!
-
Christopher for fixing an issue with multiple stems in Mel-Band Roformer
-
Iver Jordal for identifying that the default stft window function is not correct
Install
$ pip install BS-RoFormer
Usage
import torch
from bs_roformer import BSRoformer
model = BSRoformer(
dim = 512,
depth = 12,
time_transformer_depth = 1,
freq_transformer_depth = 1
)
x = torch.randn(2, 352800)
target = torch.randn(2, 352800)
loss = model(x, target = target)
loss.backward()
# after much training
out = model(x)
To use the Mel-Band Roformer proposed in a recent follow up paper, simply import MelBandRoformer instead
import torch
from bs_roformer import MelBandRoformer
model = MelBandRoformer(
dim = 32,
depth = 1,
time_transformer_depth = 1,
freq_transformer_depth = 1
)
x = torch.randn(2, 352800)
target = torch.randn(2, 352800)
loss = model(x, target = target)
loss.backward()
# after much training
out = model(x)
Todo
Citations
@inproceedings{Lu2023MusicSS,
title = {Music Source Separation with Band-Split RoPE Transformer},
author = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:261556702}
}
@inproceedings{Wang2023MelBandRF,
title = {Mel-Band RoFormer for Music Source Separation},
author = {Ju-Chiang Wang and Wei-Tsung Lu and Minz Won},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:263608675}
}
@misc{ho2019axial,
title = {Axial Attention in Multidimensional Transformers},
author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
year = {2019},
archivePrefix = {arXiv}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
.\lucidrains\BS-RoFormer\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# 设置包的信息
setup(
# 包名
name = 'BS-RoFormer',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.4.0',
# 许可证
license='MIT',
# 描述
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/BS-RoFormer',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'transformers',
'attention mechanism',
'music source separation'
],
# 安装依赖
install_requires=[
'beartype',
'einops>=0.6.1',
'librosa',
'rotary-embedding-torch>=0.3.6',
'torch>=2.0',
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\byol-pytorch\byol_pytorch\byol_pytorch.py
# 导入必要的库
import copy
import random
from functools import wraps
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import transforms as T
# 辅助函数
# 如果值为 None,则返回默认值
def default(val, def_val):
return def_val if val is None else val
# 将张量展平为二维张量
def flatten(t):
return t.reshape(t.shape[0], -1)
# 单例装饰器,用于缓存结果
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
# 获取模块所在设备
def get_module_device(module):
return next(module.parameters()).device
# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# 根据是否分布式训练返回不同的批归一化层
def MaybeSyncBatchnorm(is_distributed = None):
is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
# 损失函数
# 计算余弦相似度损失
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
# 数据增强工具
# 随机应用函数 fn 进行数据增强
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# 指数移动平均
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
# 更新移动平均值
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# 用于投影器和预测器的 MLP 类
# 创建多层感知机
def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
return nn.Sequential(
nn.Linear(dim, hidden_size),
MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
# 创建 SimSiam 模型的多层感知机
def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
return nn.Sequential(
nn.Linear(dim, hidden_size, bias=False),
MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, hidden_size, bias=False),
MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size, bias=False),
MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
)
# 用于基础神经网络的包装类
# 管理隐藏层输出并将其传递到投影器和预测器网络中
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.use_simsiam_mlp = use_simsiam_mlp
self.sync_batchnorm = sync_batchnorm
self.hidden = {}
self.hook_registered = False
# 查找指定层
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
# 在 forward 方法中的 hook 函数,用于获取隐藏层输出并保存到 self.hidden 字典中
def _hook(self, _, input, output):
# 获取输入数据的设备信息
device = input[0].device
# 将输出数据扁平化后保存到 self.hidden 字典中
self.hidden[device] = flatten(output)
# 注册 hook 函数到指定的隐藏层
def _register_hook(self):
# 查找指定的隐藏层
layer = self._find_layer()
# 断言找到隐藏层
assert layer is not None, f'hidden layer ({self.layer}) not found'
# 注册 forward hook 函数到隐藏层
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
# 获取投影器对象
@singleton('projector')
def _get_projector(self, hidden):
# 获取隐藏层的维度信息
_, dim = hidden.shape
# 根据是否使用 SimSiamMLP 创建 MLP 或 SimSiamMLP 对象
create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
# 创建投影器对象
projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
# 将投影器对象移动到隐藏层所在设备
return projector.to(hidden)
# 获取输入数据的表示
def get_representation(self, x):
# 如果指定的隐藏层为最后一层,则直接返回网络输出
if self.layer == -1:
return self.net(x)
# 如果 hook 函数未注册,则注册 hook 函数
if not self.hook_registered:
self._register_hook()
# 清空 self.hidden 字典
self.hidden.clear()
# 前向传播输入数据,并获取隐藏层输出
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
# 断言隐藏层输出不为空
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
# 前向传播方法
def forward(self, x, return_projection = True):
# 获取输入数据的表示
representation = self.get_representation(x)
# 如果不需要返回投影结果,则直接返回表示
if not return_projection:
return representation
# 获取投影器对象
projector = self._get_projector(representation)
# 对表示进行投影
projection = projector(representation)
return projection, representation
# 主类 BYOL,继承自 nn.Module
class BYOL(nn.Module):
# 初始化函数
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True,
sync_batchnorm = None
):
super().__init__()
self.net = net
# 默认的 SimCLR 数据增强
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
# 设置数据增强函数
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
# 在线编码器
self.online_encoder = NetWrapper(
net,
projection_size,
projection_hidden_size,
layer = hidden_layer,
use_simsiam_mlp = not use_momentum,
sync_batchnorm = sync_batchnorm
)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
# 在线预测器
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# 获取网络设备并将 wrapper 设置为相同设备
device = get_module_device(net)
self.to(device)
# 发送一个模拟图像张量以实例化单例参数
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
# 获取目标编码器的单例函数
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
# 重置移动平均
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
# 更新移动平均
def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
# 前向传播函数
def forward(
self,
x,
return_embedding = False,
return_projection = True
):
assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'
if return_embedding:
return self.online_encoder(x, return_projection = return_projection)
# 获取两个增强后的图像
image_one, image_two = self.augment1(x), self.augment2(x)
# 拼接两个图像
images = torch.cat((image_one, image_two), dim = 0)
# 获取在线编码器的投影和预测
online_projections, _ = self.online_encoder(images)
online_predictions = self.online_predictor(online_projections)
online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)
with torch.no_grad():
# 获取目标编码器
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_projections, _ = target_encoder(images)
target_projections = target_projections.detach()
target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)
# 计算损失
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()


浙公网安备 33010602011771号