Lucidrains 系列项目源码解析(四十七)
.\lucidrains\vit-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
# 包的名称
name = 'vit-pytorch',
# 查找除了 'examples' 文件夹之外的所有包
packages = find_packages(exclude=['examples']),
# 版本号
version = '1.6.5',
# 许可证类型
license='MIT',
# 描述
description = 'Vision Transformer (ViT) - Pytorch',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 项目链接
url = 'https://github.com/lucidrains/vit-pytorch',
# 关键词
keywords = [
'artificial intelligence',
'attention mechanism',
'image recognition'
],
# 安装依赖
install_requires=[
'einops>=0.7.0',
'torch>=1.10',
'torchvision'
],
# 设置需要的依赖
setup_requires=[
'pytest-runner',
],
# 测试需要的依赖
tests_require=[
'pytest',
'torch==1.12.1',
'torchvision==0.13.1'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\vit-pytorch\tests\test.py
# 导入 torch 库
import torch
# 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch import ViT
# 定义测试函数
def test():
# 创建 ViT 模型对象,设置参数:图像大小为 256,patch 大小为 32,类别数为 1000,特征维度为 1024,深度为 6,注意力头数为 16,MLP 隐藏层维度为 2048,dropout 概率为 0.1,嵌入层 dropout 概率为 0.1
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# 生成一个形状为 (1, 3, 256, 256) 的随机张量作为输入图像
img = torch.randn(1, 3, 256, 256)
# 将输入图像传入 ViT 模型进行预测
preds = v(img)
# 断言预测结果的形状为 (1, 1000),如果不符合则抛出异常信息 'correct logits outputted'
assert preds.shape == (1, 1000), 'correct logits outputted'
.\lucidrains\vit-pytorch\vit_pytorch\ats_vit.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.nn.utils.rnn 中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat 函数和 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# 辅助函数
# 判断值是否存在
def exists(val):
return val is not None
# 将输入转换为元组
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 自适应令牌采样函数和类
# 计算输入张量的自然对数,避免输入为 0 时出现错误
def log(t, eps = 1e-6):
return torch.log(t + eps)
# 生成服从 Gumbel 分布的随机数
def sample_gumbel(shape, device, dtype, eps = 1e-6):
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
return -log(-log(u, eps), eps)
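这里的 log 与 sample_gumbel 共同实现了 Gumbel-Max 采样技巧。下面是一个最小示意(数值为假设,仅作演示):对归一化概率取对数得到伪 logits,加上 Gumbel 噪声后取 argmax,等价于按该概率分布做一次类别采样。
probs = torch.tensor([0.1, 0.2, 0.7])                          # 假设的归一化得分
pseudo_logits = log(probs)                                      # 反转 softmax,得到伪 logits
noise = sample_gumbel(probs.shape, probs.device, probs.dtype)   # Gumbel 噪声
sample_index = (pseudo_logits + noise).argmax(dim = -1)         # 按 probs 分布采样出一个索引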
# 在指定维度上对输入张量进行批量索引选择
def batched_index_select(values, indices, dim = 1):
# 获取值张量和索引张量的维度信息
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
# 将索引张量扩展到与值张量相同的维度
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
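下面用一个形状很小的示例说明 batched_index_select 的行为(形状与数值均为假设):
values = torch.randn(2, 5, 8)                            # (batch, seq, dim)
indices = torch.tensor([[0, 2], [1, 4]])                 # 每个样本各自要保留的 token 索引
picked = batched_index_select(values, indices, dim = 1)  # 结果形状为 (2, 2, 8)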
# 自适应令牌采样类
class AdaptiveTokenSampling(nn.Module):
def __init__(self, output_num_tokens, eps = 1e-6):
super().__init__()
self.eps = eps
self.output_num_tokens = output_num_tokens
# 定义一个前向传播函数,接收注意力值、数值、掩码作为输入
def forward(self, attn, value, mask):
# 获取注意力值的头数、输出的标记数、eps值、设备和数据类型
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
# 获取CLS标记到所有其他标记的注意力值
cls_attn = attn[..., 0, 1:]
# 计算数值的范数,用于加权得分,如论文中所述
value_norms = value[..., 1:, :].norm(dim=-1)
# 通过数值的范数加权注意力得分,对所有头求和
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
# 归一化为1
normed_cls_attn = cls_attn / (cls_attn.sum(dim=-1, keepdim=True) + eps)
# 不使用逆变换采样,而是反转softmax并使用gumbel-max采样
pseudo_logits = log(normed_cls_attn)
# 为gumbel-max采样屏蔽伪对数
mask_without_cls = mask[:, 1:]
mask_value = -torch.finfo(attn.dtype).max / 2
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
# 扩展k次,k为自适应采样数
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k=output_num_tokens)
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device=device, dtype=dtype)
# gumbel-max采样并加一以保留0用于填充/掩码
sampled_token_ids = pseudo_logits.argmax(dim=-1) + 1
# 使用torch.unique计算唯一值,然后从右侧填充序列
unique_sampled_token_ids_list = [torch.unique(t, sorted=True) for t in torch.unbind(sampled_token_ids)]
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first=True)
# 基于填充计算新的掩码
new_mask = unique_sampled_token_ids != 0
# CLS标记永远不会被屏蔽(得到True值)
new_mask = F.pad(new_mask, (1, 0), value=True)
# 在前面添加一个0标记ID以保留CLS注意力得分
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value=0)
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h=heads)
# 收集新的注意力得分
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim=2)
# 返回采样的注意力得分、新掩码(表示填充)以及采样的标记索引(用于残差)
return new_attn, new_mask, unique_sampled_token_ids
# 定义前馈神经网络类
class FeedForward(nn.Module):
# 初始化函数,定义网络结构
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
# 使用 nn.Sequential 定义网络层次结构
self.net = nn.Sequential(
nn.LayerNorm(dim), # Layer normalization
nn.Linear(dim, hidden_dim), # 线性变换
nn.GELU(), # GELU 激活函数
nn.Dropout(dropout), # Dropout 正则化
nn.Linear(hidden_dim, dim), # 线性变换
nn.Dropout(dropout) # Dropout 正则化
)
# 前向传播函数
def forward(self, x):
return self.net(x)
# 定义注意力机制类
class Attention(nn.Module):
# 初始化函数,定义注意力机制结构
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim) # Layer normalization
self.attend = nn.Softmax(dim = -1) # Softmax 注意力权重计算
self.dropout = nn.Dropout(dropout) # Dropout 正则化
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) # 线性变换
self.output_num_tokens = output_num_tokens
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim), # 线性变换
nn.Dropout(dropout) # Dropout 正则化
)
# 前向传播函数
def forward(self, x, *, mask):
num_tokens = x.shape[1]
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(mask):
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~dots_mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
sampled_token_ids = None
# 如果启用了自适应令牌采样,并且令牌数量大于输出令牌数量
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), mask, sampled_token_ids
# 定义 Transformer 类
class Transformer(nn.Module):
# 初始化函数,定义 Transformer 结构
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
# 定义前向传播函数,接受输入张量 x
def forward(self, x):
# 获取输入张量 x 的形状的前两个维度大小和设备信息
b, n, device = *x.shape[:2], x.device
# 使用掩码来跟踪填充位置,以便在采样标记时移除重复项
mask = torch.ones((b, n), device=device, dtype=torch.bool)
# 创建一个包含从 0 到 n-1 的张量,设备信息与输入张量 x 一致
token_ids = torch.arange(n, device=device)
token_ids = repeat(token_ids, 'n -> b n', b=b)
# 遍历每个注意力层和前馈层
for attn, ff in self.layers:
# 调用注意力层的前向传播函数,获取注意力输出、更新后的掩码和采样的标记
attn_out, mask, sampled_token_ids = attn(x, mask=mask)
# 当进行标记采样时,需要使用采样的标记 id 从输入张量中选择对应的标记
if exists(sampled_token_ids):
x = batched_index_select(x, sampled_token_ids, dim=1)
token_ids = batched_index_select(token_ids, sampled_token_ids, dim=1)
# 更新输入张量,加上注意力输出
x = x + attn_out
# 经过前馈层处理后再加上原始输入,得到最终输出
x = ff(x) + x
# 返回最终输出张量和标记 id
return x, token_ids
class ViT(nn.Module):
# 定义 ViT 模型类,继承自 nn.Module
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
# 初始化函数,接收参数 image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels, dim_head, dropout, emb_dropout
super().__init__()
# 调用父类的初始化函数
image_height, image_width = pair(image_size)
# 获取图像的高度和宽度
patch_height, patch_width = pair(patch_size)
# 获取补丁的高度和宽度
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
# 断言,确保图像的尺寸能够被补丁的尺寸整除
num_patches = (image_height // patch_height) * (image_width // patch_width)
# 计算补丁的数量
patch_dim = channels * patch_height * patch_width
# 计算每个补丁的维度
self.to_patch_embedding = nn.Sequential(
# 定义将图像转换为补丁嵌入的序列
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
# 重新排列图像的通道和补丁的维度
nn.LayerNorm(patch_dim),
# 对每个补丁进行 LayerNorm
nn.Linear(patch_dim, dim),
# 线性变换将每个补丁的维度映射到指定的维度 dim
nn.LayerNorm(dim)
# 对映射后的维度进行 LayerNorm
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 初始化位置嵌入参数
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 初始化类别标记参数
self.dropout = nn.Dropout(emb_dropout)
# 定义丢弃层,用于嵌入的丢弃
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
# 初始化 Transformer 模型
self.mlp_head = nn.Sequential(
# 定义 MLP 头部
nn.LayerNorm(dim),
# 对输入进行 LayerNorm
nn.Linear(dim, num_classes)
# 线性变换将维度映射到类别数量
)
def forward(self, img, return_sampled_token_ids = False):
# 定义前向传播函数,接收图像和是否返回采样的令牌 ID
x = self.to_patch_embedding(img)
# 将图像转换为补丁嵌入
b, n, _ = x.shape
# 获取 x 的形状信息
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
# 重复类别标记,使其与补丁嵌入的形状相同
x = torch.cat((cls_tokens, x), dim=1)
# 拼接类别标记和补丁嵌入
x += self.pos_embedding[:, :(n + 1)]
# 添加位置嵌入
x = self.dropout(x)
# 对输入进行丢弃
x, token_ids = self.transformer(x)
# 使用 Transformer 进行转换
logits = self.mlp_head(x[:, 0])
# 使用 MLP 头部生成输出
if return_sampled_token_ids:
# 如果需要返回采样的令牌 ID
token_ids = token_ids[:, 1:] - 1
# 移除类别标记并减去 1 以使 -1 成为填充
return logits, token_ids
# 返回输出和令牌 ID
return logits
# 返回输出
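下面给出该自适应 token 采样版 ViT(ats_vit)的最小用法示意,参数为示例取值(参考仓库 README):max_tokens_per_depth 的长度必须等于 depth,且逐层递减。
from vit_pytorch.ats_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8),  # 每层保留的最大 token 数
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)  # (4, 1000)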
.\lucidrains\vit-pytorch\vit_pytorch\cait.py
from random import randrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
# 检查值是否存在
def exists(val):
return val is not None
# 对层应用 dropout
def dropout_layers(layers, dropout):
if dropout == 0:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
# 确保至少有一层保留
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
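dropout_layers 的效果可以用一个小示例说明(列表内容为假设):dropout = 0.5 时,本次前向大约随机跳过一半的层,但保证至少保留一层。
layers = list(range(6))
kept = dropout_layers(layers, dropout = 0.5)
# kept 可能是 [0, 2, 3, 5],每次调用结果不同,但不会为空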
# classes
# 缩放层
class LayerScale(nn.Module):
def __init__(self, dim, fn, depth):
super().__init__()
if depth <= 18: # 根据深度选择初始化值,详见论文第2节
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
# 前馈神经网络
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 注意力机制
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
attn = self.attend(dots)
attn = self.dropout(attn)
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# Transformer 模型
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.layer_dropout = layer_dropout
for ind in range(depth):
self.layers.append(nn.ModuleList([
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
for attn, ff in layers:
x = attn(x, context = context) + x
x = ff(x) + x
return x
class CaiT(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
*,
image_size, # 图像大小
patch_size, # 补丁大小
num_classes, # 类别数量
dim, # 特征维度
depth, # 深度
cls_depth, # 分类深度
heads, # 多头注意力头数
mlp_dim, # MLP隐藏层维度
dim_head = 64, # 头维度
dropout = 0., # 丢弃率
emb_dropout = 0., # 嵌入层丢弃率
layer_dropout = 0. # 层丢弃率
):
# 调用父类初始化函数
super().__init__()
# 检查图像尺寸是否能被补丁大小整除
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
# 计算补丁数量
num_patches = (image_size // patch_size) ** 2
# 计算补丁维度
patch_dim = 3 * patch_size ** 2
# 补丁嵌入层
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
# 位置嵌入
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
# 分类令牌
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 丢弃层
self.dropout = nn.Dropout(emb_dropout)
# 补丁Transformer
self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
# 分类Transformer
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
# MLP头
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# 前向传播函数
def forward(self, img):
# 补丁嵌入
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# 添加位置嵌入
x += self.pos_embedding[:, :n]
x = self.dropout(x)
# 补丁Transformer
x = self.patch_transformer(x)
# 重复分类令牌
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
# 分类Transformer
x = self.cls_transformer(cls_tokens, context = x)
# 返回MLP头的结果
return self.mlp_head(x[:, 0])
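下面给出 CaiT 的最小用法示意,参数为示例取值(参考仓库 README):
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,            # patch 之间自注意力的层数
    cls_depth = 2,         # CLS token 对 patch 做注意力的层数
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05   # 随机丢弃 5% 的层
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)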
.\lucidrains\vit-pytorch\vit_pytorch\cct.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 F 函数
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 定义辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 将输入转换为元组的函数
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# CCT 模型
# 定义导出的 CCT 模型名称列表
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
# 定义创建不同层数 CCT 模型的函数
def cct_2(*args, **kwargs):
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_4(*args, **kwargs):
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs)
def cct_6(*args, **kwargs):
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_7(*args, **kwargs):
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_8(*args, **kwargs):
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs)
def cct_14(*args, **kwargs):
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
def cct_16(*args, **kwargs):
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs)
# 创建 CCT 模型的内部函数
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
# 计算默认的步长和填充值
stride = default(stride, max(1, (kernel_size // 2) - 1))
padding = default(padding, max(1, (kernel_size // 2)))
# 返回 CCT 模型
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
*args, **kwargs)
# 位置编码
# 创建正弦位置编码的函数
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return rearrange(pe, '... -> 1 ...')
# 模块
# 定义注意力机制模块
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.heads = num_heads
head_dim = dim // self.heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_drop = nn.Dropout(attention_dropout)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q = q * self.scale
attn = einsum('b h i d, b h j d -> b h i j', q, k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = einsum('b h i j, b h j d -> b h i d', attn, v)
x = rearrange(x, 'b h n d -> b n (h d)')
return self.proj_drop(self.proj(x))
# 定义 Transformer 编码器层模块
class TransformerEncoderLayer(nn.Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
# 初始化函数,定义了 Transformer Encoder 层的结构
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
# 调用父类的初始化函数
super().__init__()
# 对输入进行 Layer Normalization
self.pre_norm = nn.LayerNorm(d_model)
# 定义自注意力机制
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
# 第一个线性层
self.linear1 = nn.Linear(d_model, dim_feedforward)
# 第一个 Dropout 层
self.dropout1 = nn.Dropout(dropout)
# 第一个 Layer Normalization 层
self.norm1 = nn.LayerNorm(d_model)
# 第二个线性层
self.linear2 = nn.Linear(dim_feedforward, d_model)
# 第二个 Dropout 层
self.dropout2 = nn.Dropout(dropout)
# DropPath 模块
self.drop_path = DropPath(drop_path_rate)
# 激活函数为 GELU
self.activation = F.gelu
# 前向传播函数
def forward(self, src, *args, **kwargs):
# 使用自注意力机制处理输入,并加上 DropPath 模块
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
# 对结果进行 Layer Normalization
src = self.norm1(src)
# 第一个线性层、激活函数、Dropout、第二个线性层的组合
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
# 将结果与 DropPath 模块处理后的结果相加
src = src + self.drop_path(self.dropout2(src2))
# 返回处理后的结果
return src
class DropPath(nn.Module):
# 初始化 DropPath 类
def __init__(self, drop_prob=None):
# 调用父类的初始化方法
super().__init__()
# 将传入的 drop_prob 转换为浮点数
self.drop_prob = float(drop_prob)
# 前向传播方法
def forward(self, x):
# 获取输入 x 的批次大小、drop_prob、设备和数据类型
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
# 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
if drop_prob <= 0. or not self.training:
return x
# 计算保留概率
keep_prob = 1 - self.drop_prob
# 构建形状元组
shape = (batch, *((1,) * (x.ndim - 1)))
# 生成保留掩码
keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
# 对输入 x 进行 DropPath 操作
output = x.div(keep_prob) * keep_mask.float()
return output
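DropPath(随机深度)的行为可以概括为:推理时是恒等映射;训练时按样本整条残差路径随机丢弃,并把保留的样本除以保留概率以维持期望不变。一个小示意如下(数值为假设):
drop_path = DropPath(0.2)
x = torch.randn(4, 10, 64)

drop_path.eval()
assert torch.equal(drop_path(x), x)  # 推理模式下不做任何修改

drop_path.train()
y = drop_path(x)  # 约 20% 的样本被整体置零,其余样本除以 0.8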
class Tokenizer(nn.Module):
# 初始化 Tokenizer 类
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
# 调用父类的初始化方法
super().__init__()
# 构建卷积层的通道数列表
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
# 构建通道数列表的配对
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
# 构建卷积层序列
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(chan_in, chan_out,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
# 对模型参数进行初始化
self.apply(self.init_weight)
# 计算序列长度
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
# 前向传播方法
def forward(self, x):
# 对卷积层的输出进行重排列
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
# 初始化权重方法
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
class TransformerClassifier(nn.Module):
# 初始化函数,设置模型的各种参数
def __init__(self,
seq_pool=True, # 是否使用序列池化
embedding_dim=768, # 嵌入维度
num_layers=12, # 编码器层数
num_heads=12, # 注意力头数
mlp_ratio=4.0, # MLP 扩展比例
num_classes=1000, # 类别数
dropout_rate=0.1, # Dropout 比例
attention_dropout=0.1, # 注意力 Dropout 比例
stochastic_depth_rate=0.1, # 随机深度比例
positional_embedding='sine', # 位置编码类型
sequence_length=None, # 序列长度
*args, **kwargs): # 其他参数
super().__init__() # 调用父类的初始化函数
assert positional_embedding in {'sine', 'learnable', 'none'} # 断言位置编码类型合法
dim_feedforward = int(embedding_dim * mlp_ratio) # 计算前馈网络维度
self.embedding_dim = embedding_dim # 设置嵌入维度
self.sequence_length = sequence_length # 设置序列长度
self.seq_pool = seq_pool # 设置是否使用序列池化
# 断言序列长度已指定,或位置编码为 'none'
assert exists(sequence_length) or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool: # 如果不使用序列池化
sequence_length += 1 # 序列长度加一
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True) # 创建类别嵌入参数
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1) # 创建注意力池化层
if positional_embedding == 'none': # 如果位置编码为'none'
self.positional_emb = None # 不使用位置编码
elif positional_embedding == 'learnable': # 如果位置编码为'learnable'
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim), # 创建可学习位置编码参数
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2) # 对位置编码参数进行初始化
else:
self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim), # 创建正弦位置编码参数
requires_grad=False)
self.dropout = nn.Dropout(p=dropout_rate) # 创建 Dropout 层
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] # 计算随机深度比例列表
self.blocks = nn.ModuleList([ # 创建 Transformer 编码器层列表
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
for layer_dpr in dpr])
self.norm = nn.LayerNorm(embedding_dim) # 创建 LayerNorm 层
self.fc = nn.Linear(embedding_dim, num_classes) # 创建全连接层
self.apply(self.init_weight) # 应用初始化权重函数
# 前向传播函数
def forward(self, x):
b = x.shape[0] # 获取 batch 大小
if not exists(self.positional_emb) and x.size(1) < self.sequence_length: # 如果位置编码不存在且序列长度小于指定长度
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) # 对输入进行填充
if not self.seq_pool: # 如果不使用序列池化
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b) # 重复类别嵌入
x = torch.cat((cls_token, x), dim=1) # 拼接类别嵌入和输入
if exists(self.positional_emb): # 如果位置编码存在
x += self.positional_emb # 加上位置编码
x = self.dropout(x) # Dropout
for blk in self.blocks: # 遍历编码器层
x = blk(x) # 应用编码器层
x = self.norm(x) # LayerNorm
if self.seq_pool: # 如果使用序列池化
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n') # 注意力权重计算
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x) # 加权池化
else:
x = x[:, 0] # 取第一个位置的输出作为结果
return self.fc(x) # 全连接层输出结果
# 初始化权重函数
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear): # 如果是线性层
nn.init.trunc_normal_(m.weight, std=.02) # 初始化权重
if isinstance(m, nn.Linear) and exists(m.bias): # 如果是线性层且存在偏置
nn.init.constant_(m.bias, 0) # 初始化偏置为0
elif isinstance(m, nn.LayerNorm): # 如果是 LayerNorm 层
nn.init.constant_(m.bias, 0) # 初始化偏置为0
nn.init.constant_(m.weight, 1.0) # 初始化权重为1.0
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
# 初始化函数,设置各种参数
def __init__(
self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs
):
# 调用父类的初始化函数
super().__init__()
# 获取图像的高度和宽度
img_height, img_width = pair(img_size)
# 初始化 Tokenizer 对象
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False)
# 初始化 TransformerClassifier 对象
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_height,
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs)
# 前向传播函数
def forward(self, x):
# 对输入数据进行编码
x = self.tokenizer(x)
# 使用 Transformer 进行分类
return self.classifier(x)
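下面给出 CCT 的最小用法示意,使用 cct_14 便捷构造函数,参数为示例取值(参考仓库 README)。注意 num_classes、positional_embedding 等未在 CCT 中显式声明的参数会经 **kwargs 一路转发给 TransformerClassifier。
from vit_pytorch.cct import cct_14

model = cct_14(
    img_size = 224,
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable'  # 也可以是 'sine' 或 'none'
)

img = torch.randn(1, 3, 224, 224)
pred = model(img)  # (1, 1000)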
.\lucidrains\vit-pytorch\vit_pytorch\cct_3d.py
import torch # 导入 PyTorch 库
from torch import nn, einsum # 从 PyTorch 库中导入 nn 模块和 einsum 函数
import torch.nn.functional as F # 从 PyTorch 库中导入 F 模块
from einops import rearrange, repeat # 从 einops 库中导入 rearrange 和 repeat 函数
# helpers
def exists(val):
return val is not None # 判断变量是否存在的辅助函数
def default(val, d):
return val if exists(val) else d # 如果变量存在则返回变量,否则返回默认值的辅助函数
def pair(t):
return t if isinstance(t, tuple) else (t, t) # 如果输入是元组则返回输入,否则返回包含输入两次的元组
# CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16'] # 定义导出的模型名称列表
# 定义不同层数的 CCT 模型函数
def cct_2(*args, **kwargs):
return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs) # 返回 2 层 CCT 模型
def cct_4(*args, **kwargs):
return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
*args, **kwargs) # 返回 4 层 CCT 模型
def cct_6(*args, **kwargs):
return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs) # 返回 6 层 CCT 模型
def cct_7(*args, **kwargs):
return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs) # 返回 7 层 CCT 模型
def cct_8(*args, **kwargs):
return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
*args, **kwargs) # 返回 8 层 CCT 模型
def cct_14(*args, **kwargs):
return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs) # 返回 14 层 CCT 模型
def cct_16(*args, **kwargs):
return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
*args, **kwargs) # 返回 16 层 CCT 模型
# 定义 CCT 模型函数
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = default(stride, max(1, (kernel_size // 2) - 1)) # 设置默认的步长
padding = default(padding, max(1, (kernel_size // 2))) # 设置默认的填充大小
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
*args, **kwargs) # 返回 CCT 模型
# positional
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)]) # 计算正弦余弦位置编码
pe[:, 0::2] = torch.sin(pe[:, 0::2]) # 偶数列使用正弦函数
pe[:, 1::2] = torch.cos(pe[:, 1::2]) # 奇数列使用余弦函数
return rearrange(pe, '... -> 1 ...') # 重新排列位置编码的维度
# modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.heads = num_heads # 设置注意力头数
head_dim = dim // self.heads # 计算每个头的维度
self.scale = head_dim ** -0.5 # 缩放因子
self.qkv = nn.Linear(dim, dim * 3, bias=False) # 线性变换层
self.attn_drop = nn.Dropout(attention_dropout) # 注意力丢弃层
self.proj = nn.Linear(dim, dim) # 投影层
self.proj_drop = nn.Dropout(projection_dropout) # 投影丢弃层
def forward(self, x):
B, N, C = x.shape # 获取输入张量的形状
qkv = self.qkv(x).chunk(3, dim = -1) # 将线性变换后的张量切分为 Q、K、V
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # 重排 Q、K、V 的维度
q = q * self.scale # 缩放 Q
attn = einsum('b h i d, b h j d -> b h i j', q, k) # 计算注意力分数
attn = attn.softmax(dim=-1) # 对注意力分数进行 softmax
attn = self.attn_drop(attn) # 使用注意力丢弃层
x = einsum('b h i j, b h j d -> b h i d', attn, v) # 计算加权后的 V
x = rearrange(x, 'b h n d -> b n (h d)') # 重排输出张量的维度
return self.proj_drop(self.proj(x)) # 使用投影丢弃层进行投影
class TransformerEncoderLayer(nn.Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and
rwightman's timm package.
"""
# 初始化函数,定义了 Transformer Encoder 层的结构
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
# 调用父类的初始化函数
super().__init__()
# 对输入进行 Layer Normalization
self.pre_norm = nn.LayerNorm(d_model)
# 定义自注意力机制
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
# 第一个线性层
self.linear1 = nn.Linear(d_model, dim_feedforward)
# 第一个 Dropout 层
self.dropout1 = nn.Dropout(dropout)
# 第一个 Layer Normalization 层
self.norm1 = nn.LayerNorm(d_model)
# 第二个线性层
self.linear2 = nn.Linear(dim_feedforward, d_model)
# 第二个 Dropout 层
self.dropout2 = nn.Dropout(dropout)
# DropPath 模块
self.drop_path = DropPath(drop_path_rate)
# 激活函数为 GELU
self.activation = F.gelu
# 前向传播函数
def forward(self, src, *args, **kwargs):
# 使用自注意力机制对输入进行处理,并加上 DropPath 模块
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
# 对结果进行 Layer Normalization
src = self.norm1(src)
# 第二个线性层的计算过程
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
# 将第二个线性层的结果加上 DropPath 模块
src = src + self.drop_path(self.dropout2(src2))
# 返回处理后的结果
return src
class DropPath(nn.Module):
# 初始化 DropPath 类
def __init__(self, drop_prob=None):
# 调用父类的初始化方法
super().__init__()
# 将传入的 drop_prob 转换为浮点数
self.drop_prob = float(drop_prob)
# 前向传播方法
def forward(self, x):
# 获取输入 x 的批量大小、drop_prob、设备和数据类型
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
# 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
if drop_prob <= 0. or not self.training:
return x
# 计算保留概率
keep_prob = 1 - self.drop_prob
# 构建形状元组
shape = (batch, *((1,) * (x.ndim - 1)))
# 生成保留掩码
keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
# 对输入 x 进行处理并返回输出
output = x.div(keep_prob) * keep_mask.float()
return output
class Tokenizer(nn.Module):
# 初始化 Tokenizer 类
def __init__(
self,
frame_kernel_size,
kernel_size,
stride,
padding,
frame_stride=1,
frame_pooling_stride=1,
frame_pooling_kernel_size=1,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False
):
# 调用父类的初始化方法
super().__init__()
# 构建卷积层的通道数列表
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
# 构建通道数列表的配对
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
# 构建卷积层序列
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
# 对模型进行权重初始化
self.apply(self.init_weight)
# 计算序列长度
def sequence_length(self, n_channels=3, frames=8, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1]
# 前向传播方法
def forward(self, x):
# 对输入 x 进行卷积操作并返回重排后的输出
x = self.conv_layers(x)
return rearrange(x, 'b c f h w -> b (f h w) c')
# 初始化权重方法
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight)
class TransformerClassifier(nn.Module):
# 初始化 TransformerClassifier 类
def __init__(
self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout_rate=0.1,
attention_dropout=0.1,
stochastic_depth_rate=0.1,
positional_embedding='sine',
sequence_length=None,
*args, **kwargs
):
# 调用父类的构造函数
super().__init__()
# 断言位置编码在{'sine', 'learnable', 'none'}中
assert positional_embedding in {'sine', 'learnable', 'none'}
# 计算前馈网络的维度
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
# 断言序列长度存在或者位置编码为'none'
assert exists(sequence_length) or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
# 如果不使用序列池化
if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim))
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
# 根据位置编码类型初始化位置编码
if positional_embedding == 'none':
self.positional_emb = None
elif positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim))
# 初始化Dropout层
self.dropout = nn.Dropout(p=dropout_rate)
# 生成随机Drop Path率
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
# 创建Transformer编码器层
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
for layer_dpr in dpr])
# 初始化LayerNorm层
self.norm = nn.LayerNorm(embedding_dim)
# 初始化全连接层
self.fc = nn.Linear(embedding_dim, num_classes)
# 应用初始化权重函数
self.apply(self.init_weight)
@staticmethod
def init_weight(m):
# 初始化线性层的权重
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
# 如果是线性层且存在偏置项,则初始化偏置项
if isinstance(m, nn.Linear) and exists(m.bias):
nn.init.constant_(m.bias, 0)
# 初始化LayerNorm层的权重和偏置项
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# 获取批量大小
b = x.shape[0]
# 如果位置编码不存在且输入序列长度小于设定的序列长度,则进行填充
if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
# 如果不使用序列池化,则在输入序列前添加类别标记
if not self.seq_pool:
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b=b)
x = torch.cat((cls_token, x), dim=1)
# 如果位置编码存在,则加上位置编码
if exists(self.positional_emb):
x += self.positional_emb
# Dropout层
x = self.dropout(x)
# 遍历Transformer编码器层
for blk in self.blocks:
x = blk(x)
# LayerNorm层
x = self.norm(x)
# 如果使用序列池化,则计算注意力权重并进行加权求和
if self.seq_pool:
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim=1), x)
else:
x = x[:, 0]
# 全连接层
return self.fc(x)
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
img_size=224, # 图像大小,默认为 224
num_frames=8, # 帧数,默认为 8
embedding_dim=768, # 嵌入维度,默认为 768
n_input_channels=3, # 输入通道数,默认为 3
n_conv_layers=1, # 卷积层数,默认为 1
frame_stride=1, # 帧步长,默认为 1
frame_kernel_size=3, # 帧卷积核大小,默认为 3
frame_pooling_kernel_size=1, # 帧池化核大小,默认为 1
frame_pooling_stride=1, # 帧池化步长,默认为 1
kernel_size=7, # 卷积核大小,默认为 7
stride=2, # 步长,默认为 2
padding=3, # 填充,默认为 3
pooling_kernel_size=3, # 池化核大小,默认为 3
pooling_stride=2, # 池化步长,默认为 2
pooling_padding=1, # 池化填充,默认为 1
*args, **kwargs # 其他参数
):
super().__init__() # 调用父类的初始化函数
img_height, img_width = pair(img_size) # 获取图像的高度和宽度
# 初始化 Tokenizer 对象
self.tokenizer = Tokenizer(
n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size,
frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False
)
# 初始化 TransformerClassifier 对象
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(
n_channels=n_input_channels,
frames=num_frames,
height=img_height,
width=img_width
),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
*args, **kwargs
)
# 前向传播函数
def forward(self, x):
x = self.tokenizer(x) # 对输入数据进行编码
return self.classifier(x) # 对编码后的数据进行分类
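下面给出 3D CCT 的最小用法示意(参数为示例取值,非 README 原文):输入是形状为 (batch, channels, frames, height, width) 的视频张量;num_layers、num_heads 等参数同样经 **kwargs 转发给 TransformerClassifier。
model = CCT(
    img_size = 224,
    num_frames = 8,
    embedding_dim = 384,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, frames, height, width)
pred = model(video)  # (1, 1000)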
.\lucidrains\vit-pytorch\vit_pytorch\crossformer.py
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
# 辅助函数
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# 交叉嵌入层
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
dim_out,
kernel_sizes,
stride = 2
):
super().__init__()
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# 计算每个尺度的维度
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
# 对输入进行卷积操作,并将结果拼接在一起
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
# 动态位置偏置
def DynamicPositionBias(dim):
return nn.Sequential(
nn.Linear(2, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, 1),
Rearrange('... () -> ...')
)
# transformer 类
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1)
)
class Attention(nn.Module):
def __init__(
self,
dim,
attn_type,
window_size,
dim_head = 32,
dropout = 0.
):
super().__init__()
assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant'
heads = dim // dim_head
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.attn_type = attn_type
self.window_size = window_size
self.norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
# 位置
self.dpb = DynamicPositionBias(dim // 4)
# 计算和存储用于检索偏置的索引
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
# 定义前向传播函数,接受输入 x
def forward(self, x):
# 解构 x 的形状,获取高度、宽度、头数、窗口大小和设备信息
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
# 对输入进行预处理
x = self.norm(x)
# 根据不同的注意力类型重新排列输入,以便进行短距离或长距离注意力
if self.attn_type == 'short':
x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
elif self.attn_type == 'long':
x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)
# 将输入转换为查询、键、值
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# 将查询、键、值按头数进行分割
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
q = q * self.scale
# 计算注意力矩阵
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# 添加动态位置偏置
pos = torch.arange(-wsz, wsz + 1, device = device)
rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
biases = self.dpb(rel_pos.float())
rel_pos_bias = biases[self.rel_pos_indices]
sim = sim + rel_pos_bias
# 注意力权重归一化
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# 合并头部
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
out = self.to_out(out)
# 根据不同的注意力类型重新排列输出
if self.attn_type == 'short':
out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
elif self.attn_type == 'long':
out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)
return out
# 定义一个名为 Transformer 的神经网络模块
class Transformer(nn.Module):
# 初始化函数,接受多个参数
def __init__(
self,
dim,
*,
local_window_size,
global_window_size,
depth = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
# 调用父类的初始化函数
super().__init__()
# 初始化一个空的神经网络模块列表
self.layers = nn.ModuleList([])
# 循环创建指定深度的神经网络层
for _ in range(depth):
# 每层包含两个注意力机制和两个前馈神经网络
self.layers.append(nn.ModuleList([
Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout),
Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))
# 前向传播函数
def forward(self, x):
# 遍历每一层的注意力机制和前馈神经网络
for short_attn, short_ff, long_attn, long_ff in self.layers:
# 执行短程注意力机制和前馈神经网络
x = short_attn(x) + x
x = short_ff(x) + x
# 执行长程注意力机制和前馈神经网络
x = long_attn(x) + x
x = long_ff(x) + x
# 返回处理后的数据
return x
# 定义一个名为 CrossFormer 的神经网络模块
class CrossFormer(nn.Module):
# 初始化函数,接受多个参数
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
global_window_size = (8, 4, 2, 1),
local_window_size = 7,
cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
cross_embed_strides = (4, 2, 2, 2),
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3
):
# 调用父类的初始化函数
super().__init__()
# 将参数转换为元组形式
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
global_window_size = cast_tuple(global_window_size, 4)
local_window_size = cast_tuple(local_window_size, 4)
cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
cross_embed_strides = cast_tuple(cross_embed_strides, 4)
# 断言确保参数长度为4
assert len(dim) == 4
assert len(depth) == 4
assert len(global_window_size) == 4
assert len(local_window_size) == 4
assert len(cross_embed_kernel_sizes) == 4
assert len(cross_embed_strides) == 4
# 定义维度相关变量
last_dim = dim[-1]
dims = [channels, *dim]
dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))
# 初始化一个空的神经网络模块列表
self.layers = nn.ModuleList([])
# 循环创建交叉嵌入层和 Transformer 层
for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
self.layers.append(nn.ModuleList([
CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
# 定义最终的逻辑层
self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(last_dim, num_classes)
)
# 前向传播函数
def forward(self, x):
# 遍历每一层的交叉嵌入层和 Transformer 层
for cel, transformer in self.layers:
# 执行交叉嵌入层
x = cel(x)
# 执行 Transformer 层
x = transformer(x)
# 返回最终的逻辑结果
return self.to_logits(x)
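下面给出 CrossFormer 的最小用法示意,参数为示例取值(参考仓库 README):四个阶段分别做交叉嵌入下采样,再交替进行短程(局部窗口)与长程(全局窗口)注意力。
from vit_pytorch.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,
    dim = (64, 128, 256, 512),          # 各阶段的特征维度
    depth = (2, 2, 8, 2),               # 各阶段的 Transformer 层数
    global_window_size = (8, 4, 2, 1),  # 长程注意力的窗口大小
    local_window_size = 7               # 短程注意力的窗口大小
)

img = torch.randn(1, 3, 224, 224)
pred = model(img)  # (1, 1000)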
.\lucidrains\vit-pytorch\vit_pytorch\cross_vit.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange
# 辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 前馈神经网络
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
# 定义神经网络结构
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 注意力机制
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = default(context, x)
if kv_include_self:
context = torch.cat((x, context), dim = 1) # 交叉注意力需要 CLS 标记包含自身作为键/值
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# Transformer 编码器,用于小和大补丁
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# 投影 CLS 标记,以防小和大补丁标记具有不同的维度
class ProjectInOut(nn.Module):
def __init__(self, dim_in, dim_out, fn):
super().__init__()
self.fn = fn
need_projection = dim_in != dim_out
self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()
def forward(self, x, *args, **kwargs):
x = self.project_in(x)
x = self.fn(x, *args, **kwargs)
x = self.project_out(x)
return x
# 交叉注意力 Transformer
class CrossTransformer(nn.Module):
def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
]))
# 定义一个前向传播函数,接受两个输入:sm_tokens和lg_tokens
def forward(self, sm_tokens, lg_tokens):
# 将输入的sm_tokens和lg_tokens分别拆分为(sm_cls, sm_patch_tokens)和(lg_cls, lg_patch_tokens)
(sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))
# 遍历self.layers中的每一层,每一层包含sm_attend_lg和lg_attend_sm
for sm_attend_lg, lg_attend_sm in self.layers:
# 对sm_cls进行注意力计算,使用lg_patch_tokens作为上下文,kv_include_self设置为True,然后加上原始sm_cls
sm_cls = sm_attend_lg(sm_cls, context=lg_patch_tokens, kv_include_self=True) + sm_cls
# 对lg_cls进行注意力计算,使用sm_patch_tokens作为上下文,kv_include_self设置为True,然后加上原始lg_cls
lg_cls = lg_attend_sm(lg_cls, context=sm_patch_tokens, kv_include_self=True) + lg_cls
# 将sm_cls和sm_patch_tokens在维度1上拼接起来
sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim=1)
# 将lg_cls和lg_patch_tokens在维度1上拼接起来
lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim=1)
# 返回拼接后的sm_tokens和lg_tokens
return sm_tokens, lg_tokens
# 定义多尺度编码器类
class MultiScaleEncoder(nn.Module):
def __init__(
self,
*,
depth, # 编码器深度
sm_dim, # 小尺度维度
lg_dim, # 大尺度维度
sm_enc_params, # 小尺度编码器参数
lg_enc_params, # 大尺度编码器参数
cross_attn_heads, # 跨尺度注意力头数
cross_attn_depth, # 跨尺度注意力深度
cross_attn_dim_head = 64, # 跨尺度注意力头维度
dropout = 0. # 丢弃率
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params), # 小尺度变换器
Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params), # 大尺度变换器
CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout) # 跨尺度变换器
]))
def forward(self, sm_tokens, lg_tokens):
for sm_enc, lg_enc, cross_attend in self.layers:
sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens) # 小尺度编码器和大尺度编码器
sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens) # 跨尺度注意力
return sm_tokens, lg_tokens
# 基于补丁的图像到标记嵌入器类
class ImageEmbedder(nn.Module):
def __init__(
self,
*,
dim, # 维度
image_size, # 图像尺寸
patch_size, # 补丁尺寸
dropout = 0. # 丢弃率
):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), # 图像转换为补丁
nn.LayerNorm(patch_dim), # 层归一化
nn.Linear(patch_dim, dim), # 线性变换
nn.LayerNorm(dim) # 层归一化
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 位置嵌入
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # 类别标记
self.dropout = nn.Dropout(dropout) # 丢弃层
def forward(self, img):
x = self.to_patch_embedding(img) # 图像转换为补丁嵌入
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # 重复类别标记
x = torch.cat((cls_tokens, x), dim=1) # 拼接类别标记和补丁嵌入
x += self.pos_embedding[:, :(n + 1)] # 加上位置嵌入
return self.dropout(x) # 返回结果经过丢弃层处理
# 跨ViT类
class CrossViT(nn.Module):
def __init__(
self,
*,
image_size, # 图像尺寸
num_classes, # 类别数
sm_dim, # 小尺度维度
lg_dim, # 大尺度维度
sm_patch_size = 12, # 小尺度补丁尺寸
sm_enc_depth = 1, # 小尺度编码器深度
sm_enc_heads = 8, # 小尺度编码器头数
sm_enc_mlp_dim = 2048, # 小尺度编码器MLP维度
sm_enc_dim_head = 64, # 小尺度编码器头维度
lg_patch_size = 16, # 大尺度补丁尺寸
lg_enc_depth = 4, # 大尺度编码器深度
lg_enc_heads = 8, # 大尺度编码器头数
lg_enc_mlp_dim = 2048, # 大尺度编码器MLP维度
lg_enc_dim_head = 64, # 大尺度编码器头维度
cross_attn_depth = 2, # 跨尺度注意力深度
cross_attn_heads = 8, # 跨尺度注意力头数
cross_attn_dim_head = 64, # 跨尺度注意力头维度
depth = 3, # 深度
dropout = 0.1, # 丢弃率
emb_dropout = 0.1 # 嵌入丢弃率
):
# 调用父类的初始化方法
super().__init__()
# 创建小尺寸图像嵌入器对象
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
# 创建大尺寸图像嵌入器对象
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
# 创建多尺度编码器对象
self.multi_scale_encoder = MultiScaleEncoder(
depth = depth,
sm_dim = sm_dim,
lg_dim = lg_dim,
cross_attn_heads = cross_attn_heads,
cross_attn_dim_head = cross_attn_dim_head,
cross_attn_depth = cross_attn_depth,
sm_enc_params = dict(
depth = sm_enc_depth,
heads = sm_enc_heads,
mlp_dim = sm_enc_mlp_dim,
dim_head = sm_enc_dim_head
),
lg_enc_params = dict(
depth = lg_enc_depth,
heads = lg_enc_heads,
mlp_dim = lg_enc_mlp_dim,
dim_head = lg_enc_dim_head
),
dropout = dropout
)
# 创建小尺寸MLP头部对象
self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
# 创建大尺寸MLP头部对象
self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))
# 前向传播函数
def forward(self, img):
# 获取小尺寸图像嵌入
sm_tokens = self.sm_image_embedder(img)
# 获取大尺寸图像嵌入
lg_tokens = self.lg_image_embedder(img)
# 多尺度编码器处理小尺寸和大尺寸图像嵌入
sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)
# 提取小尺寸和大尺寸的类别特征
sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))
# 小尺寸MLP头部处理小尺寸类别特征
sm_logits = self.sm_mlp_head(sm_cls)
# 大尺寸MLP头部处理大尺寸类别特征
lg_logits = self.lg_mlp_head(lg_cls)
# 返回小尺寸和大尺寸类别特征的加和
return sm_logits + lg_logits
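下面给出 CrossViT 的最小用法示意,参数为示例取值(参考仓库 README):小补丁分支与大补丁分支各自编码后,由 CLS token 通过交叉注意力交换信息,最终两个分支的 logits 相加。
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # 多尺度编码块的数量
    sm_dim = 192,            # 小补丁(高分辨率)分支的维度
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,            # 大补丁(低分辨率)分支的维度
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
pred = v(img)  # (1, 1000)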
.\lucidrains\vit-pytorch\vit_pytorch\cvt.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helper methods
# 根据条件将字典分组
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
# 根据前缀分组并移除前缀
def group_by_key_prefix_and_remove_prefix(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
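这两个辅助函数用于把按前缀命名的超参数拆分出来并去掉前缀,效果示意如下(字典内容为假设):
params = {'s1_emb_dim': 64, 's1_heads': 1, 's2_emb_dim': 192, 'dropout': 0.}
s1_config, rest = group_by_key_prefix_and_remove_prefix('s1_', params)
# s1_config == {'emb_dim': 64, 'heads': 1}
# rest      == {'s2_emb_dim': 192, 'dropout': 0.0}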
# classes
# 自定义 LayerNorm 类
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
# 自定义 FeedForward 类
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 自定义 DepthWiseConv2d 类
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.BatchNorm2d(dim_in),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
def forward(self, x):
return self.net(x)
# 自定义 Attention 类
class Attention(nn.Module):
def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
padding = proj_kernel // 2
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
# 自定义 Transformer 类
class Transformer(nn.Module):
def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
# 定义一个前向传播函数,接受输入 x
def forward(self, x):
# 遍历 self.layers 中的每个元素,每个元素包含一个注意力机制和一个前馈神经网络
for attn, ff in self.layers:
# 使用注意力机制处理输入 x,并将结果与原始输入相加
x = attn(x) + x
# 使用前馈神经网络处理上一步的结果,并将结果与原始输入相加
x = ff(x) + x
# 返回处理后的结果 x
return x
# 定义一个名为 CvT 的神经网络模型,继承自 nn.Module 类
class CvT(nn.Module):
# 初始化函数,接收一系列参数
def __init__(
self,
*,
num_classes, # 类别数量
s1_emb_dim = 64, # s1 阶段的嵌入维度
s1_emb_kernel = 7, # s1 阶段的卷积核大小
s1_emb_stride = 4, # s1 阶段的卷积步长
s1_proj_kernel = 3, # s1 阶段的投影卷积核大小
s1_kv_proj_stride = 2, # s1 阶段的键值投影步长
s1_heads = 1, # s1 阶段的注意力头数
s1_depth = 1, # s1 阶段的深度
s1_mlp_mult = 4, # s1 阶段的 MLP 扩展倍数
s2_emb_dim = 192, # s2 阶段的嵌入维度
s2_emb_kernel = 3, # s2 阶段的卷积核大小
s2_emb_stride = 2, # s2 阶段的卷积步长
s2_proj_kernel = 3, # s2 阶段的投影卷积核大小
s2_kv_proj_stride = 2, # s2 阶段的键值投影步长
s2_heads = 3, # s2 阶段的注意力头数
s2_depth = 2, # s2 阶段的深度
s2_mlp_mult = 4, # s2 阶段的 MLP 扩展倍数
s3_emb_dim = 384, # s3 阶段的嵌入维度
s3_emb_kernel = 3, # s3 阶段的卷积核大小
s3_emb_stride = 2, # s3 阶段的卷积步长
s3_proj_kernel = 3, # s3 阶段的投影卷积核大小
s3_kv_proj_stride = 2, # s3 阶段的键值投影步长
s3_heads = 6, # s3 阶段的注意力头数
s3_depth = 10, # s3 阶段的深度
s3_mlp_mult = 4, # s3 阶段的 MLP 扩展倍数
dropout = 0., # Dropout 概率
channels = 3 # 输入通道数
):
# 调用父类的初始化函数
super().__init__()
# 将参数保存到字典中
kwargs = dict(locals())
# 初始化维度为输入通道数
dim = channels
# 初始化层列表
layers = []
# 遍历 s1、s2、s3 三个阶段
for prefix in ('s1', 's2', 's3'):
# 根据前缀分组参数,并从参数字典中移除前缀
config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
# 将卷积、LayerNorm 和 Transformer 层添加到层列表中
layers.append(nn.Sequential(
nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
LayerNorm(config['emb_dim']),
Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
))
# 更新维度为当前阶段的嵌入维度
dim = config['emb_dim']
# 将所有层组成一个序列
self.layers = nn.Sequential(*layers)
# 定义输出层,包括全局平均池化、重排和全连接层
self.to_logits = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
# 前向传播函数
def forward(self, x):
# 经过所有层得到特征向量
latents = self.layers(x)
# 将特征向量传递给输出层得到预测结果
return self.to_logits(latents)
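cvt.py 的主体到此结束。下面是一个简要的使用示意(全部采用上面的默认三阶段超参数,输入分辨率 224 仅作演示):
```python
import torch
from vit_pytorch.cvt import CvT

v = CvT(num_classes = 1000)          # 其余超参数使用默认的三阶段配置

img = torch.randn(1, 3, 224, 224)    # 三个阶段的卷积嵌入会逐步下采样特征图
preds = v(img)                       # (1, 1000)
```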
.\lucidrains\vit-pytorch\vit_pytorch\deepvit.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# 定义一个前馈神经网络类
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim), # 对输入进行 Layer Normalization
nn.Linear(dim, hidden_dim), # 线性变换
nn.GELU(), # GELU 激活函数
nn.Dropout(dropout), # Dropout 正则化
nn.Linear(hidden_dim, dim), # 线性变换
nn.Dropout(dropout) # Dropout 正则化
)
def forward(self, x):
return self.net(x)
# 定义一个注意力机制类
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim) # 对输入进行 Layer Normalization
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) # 线性变换
self.dropout = nn.Dropout(dropout) # Dropout 正则化
self.reattn_weights = nn.Parameter(torch.randn(heads, heads)) # 定义可学习参数
self.reattn_norm = nn.Sequential(
Rearrange('b h i j -> b i j h'), # 重新排列张量维度
nn.LayerNorm(heads), # 对输入进行 Layer Normalization
Rearrange('b i j h -> b h i j') # 重新排列张量维度
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim), # 线性变换
nn.Dropout(dropout) # Dropout 正则化
)
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x) # 对输入进行 Layer Normalization
qkv = self.to_qkv(x).chunk(3, dim = -1) # 将线性变换后的结果切分成三部分
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) # 重新排列张量维度
# attention
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # 计算点积
attn = dots.softmax(dim=-1) # Softmax 操作
attn = self.dropout(attn) # Dropout 正则化
# re-attention
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) # 用可学习权重在注意力头之间重新混合注意力图(re-attention)
attn = self.reattn_norm(attn) # 在头维度上做 Layer Normalization
# aggregate and out
out = einsum('b h i j, b h j d -> b h i d', attn, v) # 点积操作
out = rearrange(out, 'b h n d -> b n (h d)') # 重新排列张量维度
out = self.to_out(out) # 线性变换
return out
# 定义一个 Transformer 类
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), # 注意力机制
FeedForward(dim, mlp_dim, dropout = dropout) # 前馈神经网络
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x # 注意力机制的输出与输入相加
x = ff(x) + x # 前馈神经网络的输出与输入相加
return x
# 定义一个 DeepViT 类
class DeepViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), # 重新排列张量维度
nn.LayerNorm(patch_dim), # 对输入进行 Layer Normalization
nn.Linear(patch_dim, dim), # 线性变换
nn.LayerNorm(dim) # 对输入进行 Layer Normalization
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 定义可学习参数
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # 定义可学习参数
self.dropout = nn.Dropout(emb_dropout) # Dropout 正则化
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) # Transformer 模块
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim), # 对输入进行 Layer Normalization
nn.Linear(dim, num_classes) # 线性变换
)
# 前向传播函数,接收输入图像并返回预测结果
def forward(self, img):
# 将输入图像转换为补丁嵌入
x = self.to_patch_embedding(img)
# 获取批量大小、补丁数量和嵌入维度
b, n, _ = x.shape
# 重复类别标记以匹配批量大小,并与补丁嵌入连接
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
# 添加位置嵌入到输入嵌入中
x += self.pos_embedding[:, :(n + 1)]
# 对输入进行 dropout 处理
x = self.dropout(x)
# 使用 Transformer 处理输入数据
x = self.transformer(x)
# 根据池化方式计算输出结果
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
# 将输出结果转换为潜在空间
x = self.to_latent(x)
# 使用 MLP 头部处理潜在空间的输出,并返回预测结果
return self.mlp_head(x)
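DeepViT 与标准 ViT 的接口一致,只是注意力模块换成了带 re-attention 的版本。下面是一个简要的使用示意(参数值仅作演示):
```python
import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```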
.\lucidrains\vit-pytorch\vit_pytorch\dino.py
# 导入所需的库
import copy
import random
from functools import wraps, partial
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms as T
# 辅助函数
# 检查值是否存在
def exists(val):
return val is not None
# 如果值存在,则返回该值,否则返回默认值
def default(val, default):
return val if exists(val) else default
# 单例装饰器,用于缓存结果
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
# 获取模块所在设备
def get_module_device(module):
return next(module.parameters()).device
# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# 损失函数(论文中的算法1)
def loss_fn(
teacher_logits,
student_logits,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
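该损失即 DINO 论文算法 1 中的交叉熵:教师分布先减去中心再用较低温度 softmax 锐化,学生分布用较高温度。下面用随机张量演示各输入的形状(数值仅作示意):
```python
import torch
from vit_pytorch.dino import loss_fn

teacher_logits = torch.randn(8, 100)   # (batch, K):教师投影头输出
student_logits = torch.randn(8, 100)   # (batch, K):学生投影头输出
centers = torch.zeros(1, 100)          # 教师输出的移动平均中心

loss = loss_fn(
    teacher_logits,
    student_logits,
    teacher_temp = 0.04,   # 教师温度较低,分布更尖锐
    student_temp = 0.9,    # 学生温度较高,分布更平滑
    centers = centers
)
print(loss)  # 标量损失
```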
# 数据增强工具类
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# 指数移动平均
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP类用于投影器和预测器
class L2Norm(nn.Module):
def forward(self, x, eps = 1e-6):
norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
return x / norm
class MLP(nn.Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# 用于基础神经网络的包装类
# 它负责截取指定隐藏层的输出,并将其送入投影器网络
class NetWrapper(nn.Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
# 定义一个私有方法,用于在 forward hook 中保存隐藏层的输出
def _hook(self, _, input, output):
# 获取输入数据的设备信息
device = input[0].device
# 将隐藏层的输出展平并保存到字典中
self.hidden[device] = output.flatten(1)
# 注册 forward hook,用于捕获隐藏层的输出
def _register_hook(self):
# 查找指定的隐藏层
layer = self._find_layer()
# 断言找到了隐藏层
assert layer is not None, f'hidden layer ({self.layer}) not found'
# 注册 forward hook
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
# 获取投影器,用于将隐藏层的输出投影到指定维度
@singleton('projector')
def _get_projector(self, hidden):
# 获取隐藏层输出的维度
_, dim = hidden.shape
# 创建 MLP 投影器
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
# 获取输入数据的隐藏层输出
def get_embedding(self, x):
# 如果隐藏层为最后一层,则直接返回网络的输出
if self.layer == -1:
return self.net(x)
# 如果 hook 没有注册,则注册 hook
if not self.hook_registered:
self._register_hook()
# 清空隐藏层输出字典
self.hidden.clear()
# 前向传播获取隐藏层输出
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
# 断言隐藏层输出不为空
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
# 网络的前向传播,可选择是否返回投影后的输出
def forward(self, x, return_projection = True):
# 获取输入数据的隐藏层输出
embed = self.get_embedding(x)
# 如果不需要返回投影后的输出,则直接返回隐藏层输出
if not return_projection:
return embed
# 获取投影器并对隐藏层输出进行投影
projector = self._get_projector(embed)
return projector(embed), embed
# 主类定义
class Dino(nn.Module):
# 初始化函数
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
student_temp = 0.9,
teacher_temp = 0.04,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
moving_average_decay = 0.9,
center_moving_average_decay = 0.9,
augment_fn = None,
augment_fn2 = None
):
# 调用父类的初始化函数
super().__init__()
# 设置网络
self.net = net
# 默认的 BYOL 数据增强
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
# 设置数据增强函数
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# 设置局部和全局裁剪
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
# 设置学生编码器
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
self.teacher_encoder = None
self.teacher_ema_updater = EMA(moving_average_decay)
# 注册缓冲区
self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# 获取网络设备并将包装器设置为相同设备
device = get_module_device(net)
self.to(device)
# 发送一个模拟图像张量以实例化单例参数
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
# 获取教师编码器的单例函数
@singleton('teacher_encoder')
def _get_teacher_encoder(self):
teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(teacher_encoder, False)
return teacher_encoder
# 重置移动平均值
def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None
# 更新移动平均值
def update_moving_average(self):
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
self.teacher_centers.copy_(new_teacher_centers)
# 前向传播函数
def forward(
self,
x,
return_embedding = False,
return_projection = True,
student_temp = None,
teacher_temp = None
):
# 如果需要返回嵌入向量,则调用学生编码器并返回结果
if return_embedding:
return self.student_encoder(x, return_projection = return_projection)
# 对输入数据进行两种不同的数据增强
image_one, image_two = self.augment1(x), self.augment2(x)
# 对增强后的图像进行局部裁剪
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
# 对增强后的图像进行全局裁剪
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
# 使用学生编码器对局部裁剪后的图像进行编码
student_proj_one, _ = self.student_encoder(local_image_one)
student_proj_two, _ = self.student_encoder(local_image_two)
# 使用torch.no_grad()上下文管理器,获取教师编码器并对全局裁剪后的图像进行编码
with torch.no_grad():
teacher_encoder = self._get_teacher_encoder()
teacher_proj_one, _ = teacher_encoder(global_image_one)
teacher_proj_two, _ = teacher_encoder(global_image_two)
# 部分应用损失函数,设置学生温度、教师温度和教师中心
loss_fn_ = partial(
loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_centers
)
# 计算教师投影的平均值,并将其复制到最后的教师中心
teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
self.last_teacher_centers.copy_(teacher_logits_avg)
# 计算损失,取两个损失函数的平均值
loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
return loss
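Dino 是一个自监督训练包装器:传入任意骨干网络和图像尺寸即可,训练循环中每步反向传播后再调用 update_moving_average 更新教师网络。下面是一个简要的训练示意(骨干用本仓库的 ViT,hidden_layer 取 'to_latent',数值均为演示用途):
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.dino import Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',   # 从哪一层截取特征作为嵌入
    projection_hidden_size = 256,
    projection_layers = 4,
    num_classes_K = 65336,        # 投影头输出维度
    student_temp = 0.9,
    teacher_temp = 0.04
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)  # 实际使用时替换为无标签数据
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()       # 每步更新教师网络与中心的 EMA
```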
.\lucidrains\vit-pytorch\vit_pytorch\distill.py
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 中的函数模块
from torch import nn # 从 PyTorch 中导入 nn 模块
from vit_pytorch.vit import ViT # 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch.t2t import T2TViT # 从 vit_pytorch 库中导入 T2TViT 类
from vit_pytorch.efficient import ViT as EfficientViT # 从 vit_pytorch 库中导入 EfficientViT 类
from einops import rearrange, repeat # 从 einops 库中导入 rearrange 和 repeat 函数
# helpers
def exists(val): # 定义 exists 函数,用于判断变量是否存在
return val is not None # 返回变量是否不为 None
# classes
class DistillMixin: # 定义 DistillMixin 类
def forward(self, img, distill_token = None): # 定义 forward 方法,接收图像和 distill_token 参数
distilling = exists(distill_token) # 判断 distill_token 是否存在
x = self.to_patch_embedding(img) # 将图像转换为 patch embedding
b, n, _ = x.shape # 获取 x 的形状信息
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # 重复添加 cls_token
x = torch.cat((cls_tokens, x), dim = 1) # 在维度 1 上拼接 cls_tokens 和 x
x += self.pos_embedding[:, :(n + 1)] # 添加位置编码
if distilling: # 如果进行蒸馏
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b) # 重复添加 distill_token
x = torch.cat((x, distill_tokens), dim = 1) # 在维度 1 上拼接 x 和 distill_tokens
x = self._attend(x) # 调用 _attend 方法进行注意力计算
if distilling: # 如果进行蒸馏
x, distill_tokens = x[:, :-1], x[:, -1] # 分割出 distill_tokens
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] # 计算平均值或取第一个值
x = self.to_latent(x) # 转换为 latent 表示
out = self.mlp_head(x) # 经过 MLP 头部处理得到输出
if distilling: # 如果进行蒸馏
return out, distill_tokens # 返回输出和 distill_tokens
return out # 返回输出
class DistillableViT(DistillMixin, ViT): # 定义 DistillableViT 类,继承自 DistillMixin 和 ViT
def __init__(self, *args, **kwargs): # 初始化方法
super(DistillableViT, self).__init__(*args, **kwargs) # 调用父类的初始化方法
self.args = args # 保存参数
self.kwargs = kwargs # 保存关键字参数
self.dim = kwargs['dim'] # 保存维度信息
self.num_classes = kwargs['num_classes'] # 保存类别数
def to_vit(self): # 定义 to_vit 方法
v = ViT(*self.args, **self.kwargs) # 创建 ViT 对象
v.load_state_dict(self.state_dict()) # 加载当前状态字典
return v # 返回 ViT 对象
def _attend(self, x): # 定义 _attend 方法
x = self.dropout(x) # 使用 dropout
x = self.transformer(x) # 经过 transformer 处理
return x # 返回处理后的结果
class DistillableT2TViT(DistillMixin, T2TViT): # 定义 DistillableT2TViT 类,继承自 DistillMixin 和 T2TViT
def __init__(self, *args, **kwargs): # 初始化方法
super(DistillableT2TViT, self).__init__(*args, **kwargs) # 调用父类的初始化方法
self.args = args # 保存参数
self.kwargs = kwargs # 保存关键字参数
self.dim = kwargs['dim'] # 保存维度信息
self.num_classes = kwargs['num_classes'] # 保存类别数
def to_vit(self): # 定义 to_vit 方法
v = T2TViT(*self.args, **self.kwargs) # 创建 T2TViT 对象
v.load_state_dict(self.state_dict()) # 加载当前状态字典
return v # 返回 T2TViT 对象
def _attend(self, x): # 定义 _attend 方法
x = self.dropout(x) # 使用 dropout
x = self.transformer(x) # 经过 transformer 处理
return x # 返回处理后的结果
class DistillableEfficientViT(DistillMixin, EfficientViT): # 定义 DistillableEfficientViT 类,继承自 DistillMixin 和 EfficientViT
def __init__(self, *args, **kwargs): # 初始化方法
super(DistillableEfficientViT, self).__init__(*args, **kwargs) # 调用父类的初始化方法
self.args = args # 保存参数
self.kwargs = kwargs # 保存关键字参数
self.dim = kwargs['dim'] # 保存维度信息
self.num_classes = kwargs['num_classes'] # 保存类别数
def to_vit(self): # 定义 to_vit 方法
v = EfficientViT(*self.args, **self.kwargs) # 创建 EfficientViT 对象
v.load_state_dict(self.state_dict()) # 加载当前状态字典
return v # 返回 EfficientViT 对象
def _attend(self, x): # 定义 _attend 方法
return self.transformer(x) # 经过 transformer 处理
# knowledge distillation wrapper
class DistillWrapper(nn.Module): # 定义 DistillWrapper 类,继承自 nn.Module
def __init__( # 初始化方法
self,
*,
teacher, # 教师模型
student, # 学生模型
temperature = 1., # 温度参数
alpha = 0.5, # alpha 参数
hard = False # 是否硬蒸馏
):
super().__init__() # 调用父类的初始化方法
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer' # 断言学生模型必须是视觉 transformer
self.teacher = teacher # 保存教师模型
self.student = student # 保存学生模型
dim = student.dim # 获取学生模型的维度信息
num_classes = student.num_classes # 获取学生模型的类别数
self.temperature = temperature # 保存温度参数
self.alpha = alpha # 保存 alpha 参数
self.hard = hard # 保存是否硬蒸馏
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim)) # 创建蒸馏 token
self.distill_mlp = nn.Sequential( # 创建 MLP 处理蒸馏信息
nn.LayerNorm(dim), # LayerNorm 处理
nn.Linear(dim, num_classes) # 线性层处理
)
# 定义一个前向传播函数,接受输入图像、标签、温度和权重参数
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
# 获取输入图像的批量大小
b, *_ = img.shape
# 如果 alpha 参数存在,则使用传入的值,否则使用类属性中的值
alpha = alpha if exists(alpha) else self.alpha
# 如果 temperature 参数存在,则使用传入的值,否则使用类属性中的值
T = temperature if exists(temperature) else self.temperature
# 在不计算梯度的情况下,通过教师模型获取教师网络的输出
with torch.no_grad():
teacher_logits = self.teacher(img)
# 通过学生模型获取学生网络的输出和蒸馏 token
student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
# 通过蒸馏 token 获取蒸馏网络的输出
distill_logits = self.distill_mlp(distill_tokens)
# 计算学生网络的交叉熵损失
loss = F.cross_entropy(student_logits, labels)
# 如果不是硬蒸馏,则计算软蒸馏损失
if not self.hard:
distill_loss = F.kl_div(
F.log_softmax(distill_logits / T, dim = -1),
F.softmax(teacher_logits / T, dim = -1).detach(),
reduction = 'batchmean')
distill_loss *= T ** 2
# 如果是硬蒸馏,则计算交叉熵损失
else:
teacher_labels = teacher_logits.argmax(dim = -1)
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
# 返回加权损失值,结合了学生网络的损失和蒸馏损失
return loss * (1 - alpha) + distill_loss * alpha
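DistillWrapper 把教师网络(任意输出相同类别数 logits 的模型)和可蒸馏的学生 ViT 组合成一个可直接反向传播的损失模块。下面是一个简要示意(教师用 torchvision 的 resnet50,数值仅作演示):
```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(num_classes = 1000)   # 任意输出 1000 类 logits 的教师网络

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    teacher = teacher,
    student = student,
    temperature = 3.,   # 软蒸馏温度
    alpha = 0.5,        # 蒸馏损失与分类损失的加权
    hard = False        # False 为软蒸馏,True 为硬蒸馏
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# 训练完成后可以把学生还原成普通的 ViT 用于推理
v = student.to_vit()
```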
.\lucidrains\vit-pytorch\vit_pytorch\efficient.py
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# 定义一个函数,用于确保输入是一个元组
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 定义一个名为 ViT 的类,继承自 nn.Module
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
image_size_h, image_size_w = pair(image_size)
# 检查图像尺寸是否能被 patch 大小整除
assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
# 检查池化类型是否为 'cls' 或 'mean'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
patch_dim = channels * patch_size ** 2
# 定义将图像切片转换为嵌入向量的序列
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
# 初始化位置编码和类别标记
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.pool = pool
self.to_latent = nn.Identity()
# 定义 MLP 头部
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# 前向传播函数
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# 重复类别标记以匹配批次大小
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
# 根据池化类型选择不同的池化方式
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
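这个版本的 ViT 把注意力主干抽象成 transformer 参数:任何接收并返回 (batch, seq, dim) 张量的模块都可以传入,从而替换成各种高效注意力实现(例如 Nystromformer、Linformer 等)。下面为了自包含,用 PyTorch 内置的 nn.TransformerEncoder 作示意:
```python
import torch
from torch import nn
from vit_pytorch.efficient import ViT

# 任意把 (batch, seq, dim) 映射回 (batch, seq, dim) 的模块都可以充当 transformer
encoder_layer = nn.TransformerEncoderLayer(d_model = 512, nhead = 8, batch_first = True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers = 6)

v = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    transformer = transformer
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)
```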
.\lucidrains\vit-pytorch\vit_pytorch\es_vit.py
# 导入所需的库
import copy
import random
from functools import wraps, partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange, reduce, repeat
# 辅助函数
# 检查值是否存在
def exists(val):
return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, default):
return val if exists(val) else default
# 单例装饰器,用于缓存结果
def singleton(cache_key):
def inner_fn(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
instance = getattr(self, cache_key)
if instance is not None:
return instance
instance = fn(self, *args, **kwargs)
setattr(self, cache_key, instance)
return instance
return wrapper
return inner_fn
# 获取模块所在设备
def get_module_device(module):
return next(module.parameters()).device
# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val
# 张量相关的辅助函数
# 对张量取对数
def log(t, eps = 1e-20):
return torch.log(t + eps)
# 损失函数
# 视图损失函数
def view_loss_fn(
teacher_logits,
student_logits,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
# 区域损失函数
def region_loss_fn(
teacher_logits,
student_logits,
teacher_latent,
student_latent,
teacher_temp,
student_temp,
centers,
eps = 1e-20
):
teacher_logits = teacher_logits.detach()
student_probs = (student_logits / student_temp).softmax(dim = -1)
teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
sim_indices = sim_matrix.max(dim = -1).indices
sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1])
max_sim_teacher_probs = teacher_probs.gather(1, sim_indices)
return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
# 数据增强工具
# 随机应用函数
class RandomApply(nn.Module):
def __init__(self, fn, p):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
return self.fn(x)
# 指数移动平均
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)
# MLP 类用于投影器和预测器
# L2范数
class L2Norm(nn.Module):
def forward(self, x, eps = 1e-6):
return F.normalize(x, dim = 1, eps = eps)
# 多层感知机
class MLP(nn.Module):
def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
super().__init__()
layers = []
dims = (dim, *((hidden_size,) * (num_layers - 1)))
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 1)
layers.extend([
nn.Linear(layer_dim_in, layer_dim_out),
nn.GELU() if not is_last else nn.Identity()
])
self.net = nn.Sequential(
*layers,
L2Norm(),
nn.Linear(hidden_size, dim_out)
)
def forward(self, x):
return self.net(x)
# 用于基础神经网络的包装类
# 它负责截取指定隐藏层的输出
# 并将其分别送入视图级(view)与区域级(region)投影器网络
class NetWrapper(nn.Module):
def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.view_projector = None
self.region_projector = None
self.projection_hidden_size = projection_hidden_size
self.projection_num_layers = projection_num_layers
self.output_dim = output_dim
self.hidden = {}
self.hook_registered = False
# 查找指定的层
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
# 钩子函数,用于获取隐藏层输出
def _hook(self, _, input, output):
device = input[0].device
self.hidden[device] = output
# 注册钩子函数
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
# 获取视图投影器
@singleton('view_projector')
def _get_view_projector(self, hidden):
dim = hidden.shape[1]
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
# 获取区域投影器
@singleton('region_projector')
def _get_region_projector(self, hidden):
dim = hidden.shape[1]
projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
return projector.to(hidden)
# 获取嵌入向量
def get_embedding(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
self.hidden.clear()
_ = self.net(x)
hidden = self.hidden[x.device]
self.hidden.clear()
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
# 前向传播函数
def forward(self, x, return_projection = True):
region_latents = self.get_embedding(x)
global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')
if not return_projection:
return global_latent, region_latents
view_projector = self._get_view_projector(global_latent)
region_projector = self._get_region_projector(region_latents)
region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')
return view_projector(global_latent), region_projector(region_latents), region_latents
# 主类
class EsViTTrainer(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_hidden_size = 256,
num_classes_K = 65336,
projection_layers = 4,
student_temp = 0.9,
teacher_temp = 0.04,
local_upper_crop_scale = 0.4,
global_lower_crop_scale = 0.5,
moving_average_decay = 0.9,
center_moving_average_decay = 0.9,
augment_fn = None,
augment_fn2 = None
# 定义一个继承自父类的子类,初始化网络
):
super().__init__()
self.net = net
# 默认的 BYOL 数据增强
DEFAULT_AUG = torch.nn.Sequential(
# 随机应用颜色抖动
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
# 随机转换为灰度图像
T.RandomGrayscale(p=0.2),
# 随机水平翻转
T.RandomHorizontalFlip(),
# 随机应用高斯模糊
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
# 归一化
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
# 初始化两种数据增强方式
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, DEFAULT_AUG)
# 定义局部和全局裁剪
self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
# 初始化学生编码器
self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
# 初始化教师编码器和指数移动平均更新器
self.teacher_encoder = None
self.teacher_ema_updater = EMA(moving_average_decay)
# 注册缓冲区,用于存储教师视图中心和区域中心
self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_view_centers', torch.zeros(1, num_classes_K))
self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
self.register_buffer('last_teacher_region_centers', torch.zeros(1, num_classes_K))
# 初始化教师中心指数移动平均更新器
self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# 获取网络设备并将包装器设备设置为相同
device = get_module_device(net)
self.to(device)
# 发送一个模拟图像张量以实例化单例参数
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
# 使用装饰器创建单例模式,获取教师编码器
@singleton('teacher_encoder')
def _get_teacher_encoder(self):
teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(teacher_encoder, False)
return teacher_encoder
# 重置移动平均值
def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None
# 更新移动平均值
def update_moving_average(self):
assert self.teacher_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
self.teacher_view_centers.copy_(new_teacher_view_centers)
new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
self.teacher_region_centers.copy_(new_teacher_region_centers)
# 前向传播函数
def forward(
self,
x,
return_embedding = False,
return_projection = True,
student_temp = None,
teacher_temp = None
):
# 如果需要返回嵌入向量,则调用学生编码器并返回结果
if return_embedding:
return self.student_encoder(x, return_projection = return_projection)
# 对输入数据进行两种不同的数据增强
image_one, image_two = self.augment1(x), self.augment2(x)
# 对增强后的数据进行局部裁剪和全局裁剪
local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
# 使用学生编码器对局部裁剪后的数据进行编码
student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)
# 使用torch.no_grad()上下文管理器,获取教师编码器的结果
with torch.no_grad():
teacher_encoder = self._get_teacher_encoder()
teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)
# 部分函数调用,设置视图级别损失函数和区域级别损失函数的参数
view_loss_fn_ = partial(
view_loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_view_centers
)
region_loss_fn_ = partial(
region_loss_fn,
student_temp = default(student_temp, self.student_temp),
teacher_temp = default(teacher_temp, self.teacher_temp),
centers = self.teacher_region_centers
)
# 计算视图级别损失
teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
self.last_teacher_view_centers.copy_(teacher_view_logits_avg)
teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
self.last_teacher_region_centers.copy_(teacher_region_logits_avg)
view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
+ view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2
# 计算区域级别损失
region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
+ region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2
# 返回视图级别损失和区域级别损失的平均值
return (view_loss + region_loss) / 2
.\lucidrains\vit-pytorch\vit_pytorch\extractor.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 检查值是否存在
def exists(val):
return val is not None
# 返回输入值
def identity(t):
return t
# 克隆并分离张量
def clone_and_detach(t):
return t.clone().detach()
# 应用函数到元组或单个值
def apply_tuple_or_single(fn, val):
if isinstance(val, tuple):
return tuple(map(fn, val))
return fn(val)
# 定义 Extractor 类,继承自 nn.Module
class Extractor(nn.Module):
def __init__(
self,
vit,
device = None,
layer = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False,
detach = True
):
super().__init__()
# 初始化属性
self.vit = vit
self.data = None
self.latents = None
self.hooks = []
self.hook_registered = False
self.ejected = False
self.device = device
self.layer = layer
self.layer_name = layer_name
self.layer_save_input = layer_save_input # 是否保存层的输入或输出
self.return_embeddings_only = return_embeddings_only
# 根据 detach 参数选择克隆并分离函数或返回输入值函数
self.detach_fn = clone_and_detach if detach else identity
# 钩子函数,用于提取特征
def _hook(self, _, inputs, output):
layer_output = inputs if self.layer_save_input else output
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
# 注册钩子函数
def _register_hook(self):
if not exists(self.layer):
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
layer = getattr(self.vit, self.layer_name)
else:
layer = self.layer
handle = layer.register_forward_hook(self._hook)
self.hooks.append(handle)
self.hook_registered = True
# 弹出钩子函数
def eject(self):
self.ejected = True
for hook in self.hooks:
hook.remove()
self.hooks.clear()
return self.vit
# 清除特征
def clear(self):
del self.latents
self.latents = None
# 前向传播函数
def forward(
self,
img,
return_embeddings_only = False
):
assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
self.clear()
if not self.hook_registered:
self._register_hook()
pred = self.vit(img)
target_device = self.device if exists(self.device) else img.device
latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)
if return_embeddings_only or self.return_embeddings_only:
return latents
return pred, latents
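Extractor 通过 forward hook 截取 ViT 中间层(默认是名为 transformer 的子模块)的输出,同时返回原始预测。下面是一个简要示意:
```python
import torch
from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

v = Extractor(v)   # 默认在名为 'transformer' 的子模块上注册 hook

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)   # logits: (1, 1000);embeddings: (1, 65, 1024),含 CLS 标记

v = v.eject()      # 用完后移除 hook,取回原始的 ViT
```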
.\lucidrains\vit-pytorch\vit_pytorch\learnable_memory_vit.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange
# 辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 将输入转换为元组的函数
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 控制层是否冻结的函数
# 设置模块参数是否需要梯度的函数
def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad
# 冻结所有层的函数
def freeze_all_layers_(module):
set_module_requires_grad_(module, False)
# 解冻所有层的函数
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
# 类
# 前馈神经网络类
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 注意力机制类
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, attn_mask = None, memories = None):
x = self.norm(x)
x_kv = x # input for key / values projection
if exists(memories):
# add memories to key / values if it is passed in
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
x_kv = torch.cat((x_kv, memories), dim = 1)
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
if exists(attn_mask):
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# Transformer 类
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, attn_mask = None, memories = None):
for ind, (attn, ff) in enumerate(self.layers):
layer_memories = memories[ind] if exists(memories) else None
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
x = ff(x) + x
return x
# ViT 类
class ViT(nn.Module):
# 初始化函数,设置模型参数和结构
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
# 调用父类的初始化函数
super().__init__()
# 获取图像的高度和宽度
image_height, image_width = pair(image_size)
# 获取patch的高度和宽度
patch_height, patch_width = pair(patch_size)
# 断言图像的高度和宽度能够被patch的高度和宽度整除
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
# 计算patch的数量
num_patches = (image_height // patch_height) * (image_width // patch_width)
# 计算每个patch的维度
patch_dim = channels * patch_height * patch_width
# 断言池化类型只能是'cls'或'mean'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
# 定义将图像转换为patch嵌入的层
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
# 初始化位置嵌入参数
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 初始化类别标记参数
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 初始化dropout层
self.dropout = nn.Dropout(emb_dropout)
# 初始化transformer模型
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
# 定义MLP头部
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# 将图像转换为tokens
def img_to_tokens(self, img):
# 将图像转换为patch嵌入
x = self.to_patch_embedding(img)
# 重复类别标记,拼接到patch嵌入中
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
x = torch.cat((cls_tokens, x), dim = 1)
# 添加位置嵌入并进行dropout
x += self.pos_embedding
x = self.dropout(x)
return x
# 前向传播函数
def forward(self, img):
# 将图像转换为tokens
x = self.img_to_tokens(img)
# 使用transformer模型处理tokens
x = self.transformer(x)
# 获取类别标记的输出
cls_tokens = x[:, 0]
return self.mlp_head(cls_tokens)
# 适配器模块,具有每层可学习的记忆、记忆 CLS 标记和可学习的适配器头部
class Adapter(nn.Module):
def __init__(
self,
*,
vit,
num_memories_per_layer = 10,
num_classes = 2,
):
super().__init__()
assert isinstance(vit, ViT)
# 提取一些需要的模型变量
dim = vit.cls_token.shape[-1]
layers = len(vit.transformer.layers)
num_patches = vit.pos_embedding.shape[-2]
self.vit = vit
# 冻结 ViT 主干 - 只有记忆会被微调
freeze_all_layers_(vit)
# 可学习的参数
self.memory_cls_token = nn.Parameter(torch.randn(dim))
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# 专门的注意力掩码以保留原始 ViT 的输出
# 它允许记忆 CLS 标记关注所有其他标记(以及每层可学习的记忆标记),但其他标记不能反过来关注记忆标记
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # 主要标记不能关注每层的可学习记忆
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # 记忆 CLS 标记可以关注所有内容
self.register_buffer('attn_mask', attn_mask)
def forward(self, img):
b = img.shape[0]
tokens = self.vit.img_to_tokens(img)
# 添加任务特定的记忆标记
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
# 通过变压器传递记忆以及图像标记进行关注
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
# 提取记忆 CLS 标记
memory_cls_tokens = out[:, 0]
# 通过任务特定的适配器头部传递
return self.mlp_head(memory_cls_tokens)
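Adapter 会冻结传入的 ViT 主干,只训练每层的可学习记忆、记忆 CLS 标记和新的分类头,从而用很少的参数把预训练模型适配到新任务。下面是一个简要示意(需使用本文件中定义的 ViT):
```python
import torch
from vit_pytorch.learnable_memory_vit import ViT, Adapter

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# 假设 v 已经在大规模数据上预训练,这里为新的二分类任务挂一个适配器
adapter = Adapter(
    vit = v,
    num_classes = 2,
    num_memories_per_layer = 5
)

img = torch.randn(4, 3, 256, 256)
logits = adapter(img)  # (4, 2),原 ViT 的参数不会产生梯度
```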
.\lucidrains\vit-pytorch\vit_pytorch\levit.py
# 从 math 模块中导入 ceil 函数
from math import ceil
# 导入 torch 模块及相关子模块
import torch
from torch import nn, einsum
import torch.nn.functional as F
# 导入 einops 模块中的 rearrange 和 repeat 函数,以及 torch 子模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# 辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 将输入值转换为元组的函数
def cast_tuple(val, l = 3):
val = val if isinstance(val, tuple) else (val,)
return (*val, *((val[-1],) * max(l - len(val), 0)))
# 返回固定值的函数
def always(val):
return lambda *args, **kwargs: val
# 类
# 前馈神经网络类
class FeedForward(nn.Module):
def __init__(self, dim, mult, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 注意力机制类
class Attention(nn.Module):
def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
super().__init__()
inner_dim_key = dim_key * heads
inner_dim_value = dim_value * heads
dim_out = default(dim_out, dim)
self.heads = heads
self.scale = dim_key ** -0.5
self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
self.to_out = nn.Sequential(
nn.GELU(),
nn.Conv2d(inner_dim_value, dim_out, 1),
out_batch_norm,
nn.Dropout(dropout)
)
# 位置偏置
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
x_rel, y_rel = rel_pos.unbind(dim = -1)
pos_indices = (x_rel * fmap_size) + y_rel
self.register_buffer('pos_indices', pos_indices)
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads
q = self.to_q(x)
y = q.shape[2]
qkv = (q, self.to_k(x), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
class Transformer(nn.Module):
# 初始化函数,设置模型参数和结构
def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
# 调用父类的初始化函数
super().__init__()
# 如果未指定输出维度,则默认为输入维度
dim_out = default(dim_out, dim)
# 初始化一个空的模块列表用于存储每个层
self.layers = nn.ModuleList([])
# 判断是否使用注意力机制的残差连接
self.attn_residual = (not downsample) and dim == dim_out
# 根据深度循环创建每个层
for _ in range(depth):
# 每个层包含一个注意力机制和一个前馈神经网络
self.layers.append(nn.ModuleList([
Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
FeedForward(dim_out, mlp_mult, dropout = dropout)
]))
# 前向传播函数,处理输入数据
def forward(self, x):
# 遍历每个层
for attn, ff in self.layers:
# 如果使用注意力机制的残差连接,则保存输入数据
attn_res = (x if self.attn_residual else 0)
# 经过注意力机制处理后,加上残差连接
x = attn(x) + attn_res
# 经过前馈神经网络处理后,加上残差连接
x = ff(x) + x
# 返回处理后的数据
return x
# 定义 LeViT 类,继承自 nn.Module
class LeViT(nn.Module):
# 初始化函数,接收多个参数
def __init__(
self,
*,
image_size, # 图像大小
num_classes, # 类别数量
dim, # 维度
depth, # 深度
heads, # 头数
mlp_mult, # MLP 倍数
stages = 3, # 阶段数,默认为 3
dim_key = 32, # 键维度,默认为 32
dim_value = 64, # 值维度,默认为 64
dropout = 0., # Dropout,默认为 0
num_distill_classes = None # 蒸馏类别数量,默认为 None
):
# 调用父类的初始化函数
super().__init__()
# 将 dim、depth、heads 转换为元组
dims = cast_tuple(dim, stages)
depths = cast_tuple(depth, stages)
layer_heads = cast_tuple(heads, stages)
# 断言确保 dimensions、depths、heads 必须是小于指定阶段数的元组
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
# 定义卷积嵌入层
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
)
# 计算特征图大小
fmap_size = image_size // (2 ** 4)
layers = []
# 遍历阶段,构建 Transformer 层
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
is_last = ind == (stages - 1)
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
if not is_last:
next_dim = dims[ind + 1]
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
fmap_size = ceil(fmap_size / 2)
# 构建骨干网络
self.backbone = nn.Sequential(*layers)
# 定义池化层
self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...')
)
# 定义蒸馏头部
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
# 定义 MLP 头部
self.mlp_head = nn.Linear(dim, num_classes)
# 前向传播函数
def forward(self, img):
# 图像经过卷积嵌入层
x = self.conv_embedding(img)
# 特征图经过骨干网络
x = self.backbone(x)
# 特征图经过池化层
x = self.pool(x)
# 输出结果经过 MLP 头部
out = self.mlp_head(x)
# 蒸馏结果经过蒸馏头部
distill = self.distill_head(x)
# 如果存在蒸馏结果,则返回输出结果和蒸馏结果
if exists(distill):
return out, distill
# 否则只返回输出结果
return out
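LeViT 把卷积嵌入与逐阶段下采样的注意力结合,dim、depth、heads 都可以按阶段分别指定。下面是一个简要示意(数值仅作演示):
```python
import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # 阶段数
    dim = (256, 384, 512),  # 每个阶段的维度
    depth = 4,              # 每个阶段的 Transformer 深度
    heads = (4, 6, 8),      # 每个阶段的注意力头数
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = levit(img)  # (1, 1000)
```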
.\lucidrains\vit-pytorch\vit_pytorch\local_vit.py
# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn, einsum
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 einops 模块中导入 rearrange, repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange
# classes
# 定义 Residual 类,继承自 nn.Module
class Residual(nn.Module):
# 初始化函数
def __init__(self, fn):
super().__init__()
self.fn = fn
# 前向传播函数
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# 定义 ExcludeCLS 类,继承自 nn.Module
class ExcludeCLS(nn.Module):
# 初始化函数
def __init__(self, fn):
super().__init__()
self.fn = fn
# 前向传播函数
def forward(self, x, **kwargs):
cls_token, x = x[:, :1], x[:, 1:]
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)
# feed forward related classes
# 定义 DepthWiseConv2d 类,继承自 nn.Module
class DepthWiseConv2d(nn.Module):
# 初始化函数
def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
)
# 前向传播函数
def forward(self, x):
return self.net(x)
# 定义 FeedForward 类,继承自 nn.Module
class FeedForward(nn.Module):
# 初始化函数
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(hidden_dim, dim, 1),
nn.Dropout(dropout)
)
# 前向传播函数
def forward(self, x):
h = w = int(sqrt(x.shape[-2]))
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
x = self.net(x)
x = rearrange(x, 'b c h w -> b (h w) c')
return x
# attention
# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
# 初始化函数
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
# 前向传播函数
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 定义 Transformer 类,继承自 nn.Module
class Transformer(nn.Module):
# 初始化函数
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
]))
# 前向传播函数
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return x
# main class
# 定义 LocalViT 类,继承自 nn.Module
class LocalViT(nn.Module):
# 初始化函数,设置模型参数和层结构
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
# 调用父类的初始化函数
super().__init__()
# 检查图像尺寸是否能被分块尺寸整除
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
# 计算图像分块数量
num_patches = (image_size // patch_size) ** 2
# 计算每个分块的维度
patch_dim = channels * patch_size ** 2
# 定义将图像分块转换为嵌入向量的层序列
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
# 初始化位置编码参数
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 初始化类别标记参数
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 初始化丢弃层
self.dropout = nn.Dropout(emb_dropout)
# 初始化 Transformer 模型
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
# 定义 MLP 头部层序列
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
# 前向传播函数
def forward(self, img):
# 将图像转换为分块嵌入向量
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# 重复类别标记以匹配批次大小
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
# 将类别标记和分块嵌入向量拼接在一起
x = torch.cat((cls_tokens, x), dim=1)
# 添加位置编码
x += self.pos_embedding[:, :(n + 1)]
# 对结果进行丢弃
x = self.dropout(x)
# 使用 Transformer 进行特征变换
x = self.transformer(x)
# 返回 MLP 头部的输出结果
return self.mlp_head(x[:, 0])
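LocalViT 的前馈层改成了作用在二维特征图上的深度可分离卷积,其余接口与标准 ViT 相同。下面是一个简要示意(参数值仅作演示):
```python
import torch
from vit_pytorch.local_vit import LocalViT

v = LocalViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```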
.\lucidrains\vit-pytorch\vit_pytorch\mae.py
# 导入 PyTorch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 einops 库中导入 repeat 函数
from einops import repeat
# 从 vit_pytorch.vit 模块中导入 Transformer 类
from vit_pytorch.vit import Transformer
# 定义一个名为 MAE 的 nn.Module 类
class MAE(nn.Module):
# 初始化函数,接收一系列参数
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio = 0.75,
decoder_depth = 1,
decoder_heads = 8,
decoder_dim_head = 64
):
super().__init__()
# 断言确保 masking_ratio 在 0 和 1 之间
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
# 将 masking_ratio 存储在对象中
self.masking_ratio = masking_ratio
# 从编码器中提取一些超参数和函数(待训练的视觉变换器)
# 存储编码器对象
self.encoder = encoder
# 获取补丁数量和编码器维度
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
# 获取从图像到补丁的转换函数
self.to_patch = encoder.to_patch_embedding[0]
# 获取从补丁到嵌入的序列
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
# 获取每个补丁的像素值
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
# 解码器参数
# 存储解码器维度
self.decoder_dim = decoder_dim
# 如果编码器维度与解码器维度不同,则使用 nn.Linear 进行映射,否则使用 nn.Identity
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
# 初始化一个可学习的遮罩令牌
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
# 创建一个 Transformer 解码器
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
# 创建一个嵌入层用于解码器位置编码
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
# 创建一个线性层用于将解码器输出映射回像素值
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
# 定义一个前向传播函数,接收输入图像
def forward(self, img):
# 获取输入图像所在设备
device = img.device
# 获取图像的补丁
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# 将补丁转换为编码器标记并添加位置信息
tokens = self.patch_to_emb(patches)
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# 计算需要屏蔽的补丁数量,并获取随机索引,将其分为屏蔽和未屏蔽的部分
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
# 获取要编码的未屏蔽标记
batch_range = torch.arange(batch, device=device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# 获取用于最终重建损失的要屏蔽的补丁
masked_patches = patches[batch_range, masked_indices]
# 使用视觉变换器进行注意力
encoded_tokens = self.encoder.transformer(tokens)
# 投影编码器到解码器维度,如果它们不相等 - 论文中说可以使用较小的维度进行解码器
decoder_tokens = self.enc_to_dec(encoded_tokens)
# 重新应用解码器位置嵌入到未屏蔽标记
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# 重复屏蔽标记以匹配屏蔽数量,并使用上面得到的屏蔽索引添加位置
mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# 将屏蔽标记连接到解码器标记并使用解码器进行注意力
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens
decoded_tokens = self.decoder(decoder_tokens)
# 剪切出屏蔽标记并投影到像素值
mask_tokens = decoded_tokens[batch_range, masked_indices]
pred_pixel_values = self.to_pixels(mask_tokens)
# 计算重建损失
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss
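MAE 包装一个待训练的 ViT 编码器,随机遮住大部分 patch,只对被遮住的部分计算像素重建损失。下面是一个简要的预训练示意(数值仅作演示):
```python
import torch
from vit_pytorch import ViT
from vit_pytorch.mae import MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,          # 被预训练的 ViT 编码器
    masking_ratio = 0.75, # 遮挡比例
    decoder_dim = 512,    # 解码器可以比编码器小
    decoder_depth = 6
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)   # 仅在被遮挡的 patch 上计算 MSE 重建损失
loss.backward()
# 预训练结束后,直接取出 v(编码器)用于下游任务
```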
.\lucidrains\vit-pytorch\vit_pytorch\max_vit.py
# import the necessary libraries
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helper functions
# check whether a value exists
def exists(val):
return val is not None
# return the value if it exists, otherwise return the default
def default(val, d):
return val if exists(val) else d
# cast a value to a tuple, repeating it if it is not already one
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
# residual connection
class Residual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
# feedforward network
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# MBConv
# Squeeze-and-Excitation module
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
# MBConv residual block
class MBConvResidual(nn.Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
# stochastic sample dropping (drops the residual branch for whole samples)
class Dropsample(nn.Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0. or (not self.training):
return x
# per-sample keep mask; torch.empty(...).uniform_() replaces the legacy FloatTensor constructor, which neither builds a (b, 1, 1, 1) tensor from a shape tuple nor accepts a device argument
keep_mask = torch.empty(x.shape[0], 1, 1, 1, device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
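A quick sketch of the drop-sample behaviour (not part of the repository; it reuses the torch import and the Dropsample class above): during training each sample's residual branch is either zeroed entirely or rescaled by 1 / (1 - prob), so its expectation is preserved, and at inference the layer is the identity.

layer = Dropsample(prob = 0.25)
layer.train()

x = torch.ones(4, 8, 14, 14)
out = layer(x)
# every sample is either all zeros or all 1 / 0.75
print(out.flatten(1)[:, 0])

layer.eval()
assert torch.equal(layer(x), x)  # identity when not training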
# MBConv builder function
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
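A short shape check (illustrative only, reusing the torch import and the classes above) showing that the builder returns a plain nn.Sequential, wrapped in MBConvResidual only when resolution and channel count are preserved:

block = MBConv(64, 64, downsample = False)     # residual inverted-bottleneck block
down  = MBConv(64, 128, downsample = True)     # strided block, halves the spatial size

x = torch.randn(2, 64, 56, 56)
print(block(x).shape)   # torch.Size([2, 64, 56, 56])
print(down(x).shape)    # torch.Size([2, 128, 28, 28])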
# attention related classes
# attention mechanism
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
dropout = 0.,
window_size = 7
):
# call the parent constructor
super().__init__()
# the dimension must be divisible by the dimension per head
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
# number of heads
self.heads = dim // dim_head
# scaling factor
self.scale = dim_head ** -0.5
# LayerNorm layer
self.norm = nn.LayerNorm(dim)
# linear projection producing queries, keys and values
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
# attention
self.attend = nn.Sequential(
nn.Softmax(dim = -1), # softmax activation
nn.Dropout(dropout) # dropout layer
)
# output projection
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias = False), # linear projection
nn.Dropout(dropout) # dropout layer
)
# relative positional bias
# embedding table holding the relative positional bias
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
# precompute the relative position indices
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
# register a non-persistent buffer holding the relative position indices
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
# unpack the shape of the input tensor
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
# LayerNorm
x = self.norm(x)
# flatten the windows
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
# project to queries, keys and values
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split out the heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# compute similarity
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add the relative positional bias
bias = self.rel_pos_bias(self.rel_pos_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
attn = self.attend(sim)
# aggregate the values
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge the heads back into windows
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
# combine the head outputs
out = self.to_out(out)
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
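To make the relative positional bias concrete, here is a small standalone sketch (not part of the repository) that reruns the index computation from the constructor above for window_size = 3: every ordered pair of positions inside a 3 x 3 window is mapped to one of (2 * 3 - 1) ** 2 = 25 learned bias slots per head.

import torch
from einops import rearrange

window_size = 3
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')                      # 9 (row, col) coordinates
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1                                       # shift offsets into [0, 2w - 2]
indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

print(indices.shape)         # torch.Size([9, 9])
print(indices.max().item())  # 24, i.e. the last of the 25 distinct bias slots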
# the MaxViT model
class MaxViT(nn.Module):
# constructor, accepting a series of keyword arguments
def __init__(
self,
*,
num_classes,
dim,
depth,
dim_head = 32,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
dropout = 0.1,
channels = 3
):
# call the parent constructor
super().__init__()
# depth must be a tuple of integers giving the number of transformer blocks at each stage
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
# convolutional stem
# default dim_conv_stem to dim if not given
dim_conv_stem = default(dim_conv_stem, dim)
# define the convolutional stem
self.conv_stem = nn.Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
# variables
# number of stages
num_stages = len(depth)
# dimension of each stage
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
# layers start as an empty nn.ModuleList
self.layers = nn.ModuleList([])
# shorthand for the window size, used for efficient block-like and grid-like attention
w = window_size
# iterate through the stages
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
# define one block
block = nn.Sequential(
MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)
# append the block to the layers
self.layers.append(block)
# MLP head
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
# forward pass
def forward(self, x):
# pass through the convolutional stem
x = self.conv_stem(x)
# pass through the blocks of every stage
for stage in self.layers:
x = stage(x)
# pass through the MLP head
return self.mlp_head(x)
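A usage sketch, assuming the import path vit_pytorch.max_vit implied by the file header above (the hyperparameter values are illustrative; the input side length must stay divisible by the window size after the stem and each stage's downsampling):

import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),        # number of blocks per stage
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)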
.\lucidrains\vit-pytorch\vit_pytorch\max_vit_with_registers.py
# import the necessary libraries
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
# helper functions
# check whether a value exists
def exists(val):
return val is not None
# return the value if it exists, otherwise return the default
def default(val, d):
return val if exists(val) else d
# pack a single element according to the given pattern
def pack_one(x, pattern):
return pack([x], pattern)
# unpack back into a single element
def unpack_one(x, ps, pattern):
return unpack(x, ps, pattern)[0]
# cast a value to a tuple, repeating it if it is not already one
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
# feedforward network
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(dim * mult)
return Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
# MBConv
# Squeeze-and-Excitation module
class SqueezeExcitation(Module):
def __init__(self, dim, shrinkage_rate = 0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, hidden_dim, bias = False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias = False),
nn.Sigmoid(),
Rearrange('b c -> b c 1 1')
)
def forward(self, x):
return x * self.gate(x)
# MBConv residual block
class MBConvResidual(Module):
def __init__(self, fn, dropout = 0.):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
# stochastic sample dropping (drops the residual branch for whole samples)
class Dropsample(Module):
def __init__(self, prob = 0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0. or (not self.training):
return x
# per-sample keep mask; torch.empty(...).uniform_() replaces the legacy FloatTensor constructor, which neither builds a (b, 1, 1, 1) tensor from a shape tuple nor accepts a device argument
keep_mask = torch.empty(x.shape[0], 1, 1, 1, device = device).uniform_() > self.prob
return x * keep_mask / (1 - self.prob)
# MBConv builder function
def MBConv(
dim_in,
dim_out,
*,
downsample,
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout = dropout)
return net
# attention related classes
# attention mechanism (extended with register tokens)
class Attention(Module):
def __init__(
self,
dim,
dim_head = 32,
dropout = 0.,
window_size = 7,
num_registers = 1
):
# call the parent constructor
super().__init__()
# there must be at least one register token
assert num_registers > 0
# the dimension must be divisible by the dimension per head
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
# number of heads
self.heads = dim // dim_head
# scaling factor
self.scale = dim_head ** -0.5
# LayerNorm layer
self.norm = nn.LayerNorm(dim)
# linear projection producing queries, keys and values (3x the input dimension)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
# attention
self.attend = nn.Sequential(
nn.Softmax(dim = -1), # softmax activation
nn.Dropout(dropout) # dropout layer
)
# output projection
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias = False), # linear projection
nn.Dropout(dropout) # dropout layer
)
# relative positional bias
# number of spatial relative positional bias entries
num_rel_pos_bias = (2 * window_size - 1) ** 2
# embedding table holding the relative positional bias, plus one shared slot for the register tokens
self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
# generate the relative position indices
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
# pad the index matrix so the rows and columns belonging to register tokens all point at the extra bias slot
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
def forward(self, x):
# device, number of heads, and relative position indices
device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
# LayerNorm
x = self.norm(x)
# project to queries, keys and values
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# split out the heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# compute similarity
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add the positional bias
bias = self.rel_pos_bias(bias_indices)
sim = sim + rearrange(bias, 'i j h -> h i j')
# attention
attn = self.attend(sim)
# aggregate the values
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge the head outputs
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
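The F.pad call above is the only change to the bias logic relative to plain MaxViT: every row and column of the index matrix that corresponds to a register token is mapped to the single extra embedding slot at index num_rel_pos_bias, so registers receive one shared learned bias rather than a spatial one. A small sketch (not from the repository, using a stand-in index matrix) illustrates the padding for window_size = 2 and one register:

import torch
import torch.nn.functional as F

window_size, num_registers = 2, 1
num_rel_pos_bias = (2 * window_size - 1) ** 2      # 9 spatial bias slots

# stand-in 4 x 4 index matrix for the 2 x 2 window (values 0..8, as in the real module)
spatial_indices = torch.arange(window_size ** 2 * window_size ** 2).reshape(4, 4) % num_rel_pos_bias

padded = F.pad(spatial_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
print(padded.shape)   # torch.Size([5, 5]) - register row/column prepended
print(padded[0])      # tensor([9, 9, 9, 9, 9]) - the shared register bias slot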
# the MaxViT model, extended with register tokens
class MaxViT(Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
dim_head = 32,
dim_conv_stem = None,
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
dropout = 0.1,
channels = 3,
num_register_tokens = 4
):
# call the parent constructor
super().__init__()
# depth must be a tuple of integers, and there must be at least one register token
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
assert num_register_tokens > 0
# convolutional stem
# default dim_conv_stem to dim if not given
dim_conv_stem = default(dim_conv_stem, dim)
# two convolutional layers acting as the stem
self.conv_stem = Sequential(
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
)
# variables
# number of stages
num_stages = len(depth)
# dimension of each stage
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
dims = (dim_conv_stem, *dims)
# pair up consecutive dimensions
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
# layers start as an empty ModuleList
self.layers = nn.ModuleList([])
# window size
self.window_size = window_size
# empty ParameterList that will hold the register tokens
self.register_tokens = nn.ParameterList([])
# iterate through stages
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
for stage_ind in range(layer_depth):
# only the first block of a stage downsamples
is_first = stage_ind == 0
stage_dim_in = layer_dim_in if is_first else layer_dim
# the MBConv block
conv = MBConv(
stage_dim_in,
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
)
# attention and feedforward for block-like attention
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
block_ff = FeedForward(dim = layer_dim, dropout = dropout)
# attention and feedforward for grid-like attention
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
# randomly initialized register tokens for this block
register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
# add the MBConv, block attention and grid attention modules as one layer
self.layers.append(ModuleList([
conv,
ModuleList([block_attn, block_ff]),
ModuleList([grid_attn, grid_ff])
]))
# store the register tokens alongside the layer
self.register_tokens.append(register_tokens)
# mlp head out
self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)
# forward pass, taking an input tensor x
def forward(self, x):
# batch size b and window size w
b, w = x.shape[0], self.window_size
# pass through the convolutional stem
x = self.conv_stem(x)
# iterate through the layers, each consisting of an MBConv, block attention and grid attention
for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
# MBConv
x = conv(x)
# block-like attention
# rearrange into non-overlapping windows
x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare the register tokens
# repeat the register tokens so every window gets its own copy
r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1], y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
# block-like attention with a residual connection
x = block_attn(x) + x
# feedforward with a residual connection
x = block_ff(x) + x
r, x = unpack(x, register_ps, 'b * d')
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')
r = unpack_one(r, register_batch_ps, '* n d')
# grid-like attention
# rearrange into a grid of strided windows
x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)
# prepare the register tokens
# average the register tokens over the windows, then repeat them for the grid layout
r = reduce(r, 'b x y n d -> b n d', 'mean')
r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
r, register_batch_ps = pack_one(r, '* n d')
x, window_ps = pack_one(x, 'b x y * d')
x, batch_ps = pack_one(x, '* n d')
x, register_ps = pack([r, x], 'b * d')
# grid-like attention with a residual connection
x = grid_attn(x) + x
r, x = unpack(x, register_ps, 'b * d')
# feedforward with a residual connection
x = grid_ff(x) + x
x = unpack_one(x, batch_ps, '* n d')
x = unpack_one(x, window_ps, 'b x y * d')
x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')
# return the output of the MLP head
return self.mlp_head(x)
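A usage sketch, assuming the import path vit_pytorch.max_vit_with_registers implied by the file header above (hyperparameter values are illustrative and mirror the plain MaxViT example, with the added num_register_tokens argument):

import torch
from vit_pytorch.max_vit_with_registers import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    num_register_tokens = 4,     # register tokens shared across the windows of each block
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)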