Lucidrains Series: Source Code Analysis (Part 47)

.\lucidrains\vit-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'vit-pytorch',
  # 查找除了 'examples' 文件夹之外的所有包
  packages = find_packages(exclude=['examples']),
  # 版本号
  version = '1.6.5',
  # 许可证类型
  license='MIT',
  # 描述
  description = 'Vision Transformer (ViT) - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/vit-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'image recognition'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.7.0',
    'torch>=1.10',
    'torchvision'
  ],
  # 设置需要的依赖
  setup_requires=[
    'pytest-runner',
  ],
  # 测试需要的依赖
  tests_require=[
    'pytest',
    'torch==1.12.1',
    'torchvision==0.13.1'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
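
# A minimal post-install sanity check (not part of the repository; the hyperparameter values
# below are illustrative only). The package can be installed with `pip install vit-pytorch`,
# or with `pip install -e .` from a checkout of this repository.
import torch
from vit_pytorch import ViT

model = ViT(image_size = 224, patch_size = 16, num_classes = 10, dim = 256, depth = 2, heads = 4, mlp_dim = 512)
assert model(torch.randn(1, 3, 224, 224)).shape == (1, 10)
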

.\lucidrains\vit-pytorch\tests\test.py

# 导入 torch 库
import torch
# 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch import ViT

# 定义测试函数
def test():
    # 创建 ViT 模型对象,设置参数:图像大小为 256,patch 大小为 32,类别数为 1000,特征维度为 1024,深度为 6,注意力头数为 16,MLP 隐藏层维度为 2048,dropout 概率为 0.1,嵌入层 dropout 概率为 0.1
    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    # 生成一个形状为 (1, 3, 256, 256) 的随机张量作为输入图像
    img = torch.randn(1, 3, 256, 256)

    # 将输入图像传入 ViT 模型进行预测
    preds = v(img)
    # 断言预测结果的形状为 (1, 1000),如果不符合则抛出异常信息 'correct logits outputted'
    assert preds.shape == (1, 1000), 'correct logits outputted'

.\lucidrains\vit-pytorch\vit_pytorch\ats_vit.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.nn.utils.rnn 中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat 函数和 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 自适应令牌采样函数和类

# 计算输入张量的自然对数,避免输入为 0 时出现错误
def log(t, eps = 1e-6):
    return torch.log(t + eps)

# 生成服从 Gumbel 分布的随机数
def sample_gumbel(shape, device, dtype, eps = 1e-6):
    u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
    return -log(-log(u, eps), eps)

# 在指定维度上对输入张量进行批量索引选择
def batched_index_select(values, indices, dim = 1):
    # 获取值张量和索引张量的维度信息
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # 将索引张量扩展到与值张量相同的维度
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
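
# A small hypothetical shape check for batched_index_select (not part of the original file;
# the shapes are made up for illustration). Given per-batch indices, the function gathers the
# corresponding rows while broadcasting over any trailing feature dimensions.
values = torch.randn(2, 5, 8)                      # (batch, tokens, dim)
indices = torch.tensor([[0, 3], [4, 1]])           # (batch, k) - which tokens to keep per batch element
selected = batched_index_select(values, indices, dim = 1)
assert selected.shape == (2, 2, 8)                 # one 8-dim vector per selected token
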

# 自适应令牌采样类
class AdaptiveTokenSampling(nn.Module):
    def __init__(self, output_num_tokens, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.output_num_tokens = output_num_tokens
    # 定义一个前向传播函数,接收注意力值、数值、掩码作为输入
    def forward(self, attn, value, mask):
        # 获取注意力值的头数、输出的标记数、eps值、设备和数据类型
        heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype

        # 获取CLS标记到所有其他标记的注意力值
        cls_attn = attn[..., 0, 1:]

        # 计算数值的范数,用于加权得分,如论文中所述
        value_norms = value[..., 1:, :].norm(dim=-1)

        # 通过数值的范数加权注意力得分,对所有头求和
        cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)

        # 归一化为1
        normed_cls_attn = cls_attn / (cls_attn.sum(dim=-1, keepdim=True) + eps)

        # 不使用逆变换采样,而是反转softmax并使用gumbel-max采样
        pseudo_logits = log(normed_cls_attn)

        # 为gumbel-max采样屏蔽伪对数
        mask_without_cls = mask[:, 1:]
        mask_value = -torch.finfo(attn.dtype).max / 2
        pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)

        # 扩展k次,k为自适应采样数
        pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k=output_num_tokens)
        pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device=device, dtype=dtype)

        # gumbel-max采样并加一以保留0用于填充/掩码
        sampled_token_ids = pseudo_logits.argmax(dim=-1) + 1

        # 使用torch.unique计算唯一值,然后从右侧填充序列
        unique_sampled_token_ids_list = [torch.unique(t, sorted=True) for t in torch.unbind(sampled_token_ids)]
        unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first=True)

        # 基于填充计算新的掩码
        new_mask = unique_sampled_token_ids != 0

        # CLS标记永远不会被屏蔽(得到True值)
        new_mask = F.pad(new_mask, (1, 0), value=True)

        # 在前面添加一个0标记ID以保留CLS注意力得分
        unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value=0)
        expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h=heads)

        # 收集新的注意力得分
        new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim=2)

        # 返回采样的注意力得分、新掩码(表示填充)以及采样的标记索引(用于残差)
        return new_attn, new_mask, unique_sampled_token_ids
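
# A standalone sketch of how AdaptiveTokenSampling is driven (not part of the original file;
# shapes are illustrative). Inside the model it is called from Attention with the post-softmax
# attention map, the value tensor and the current padding mask.
ats = AdaptiveTokenSampling(output_num_tokens = 8)
attn = torch.randn(2, 4, 17, 17).softmax(dim = -1)         # (batch, heads, tokens, tokens), token 0 is the CLS token
value = torch.randn(2, 4, 17, 64)                          # (batch, heads, tokens, dim_head)
mask = torch.ones(2, 17, dtype = torch.bool)               # no padding yet
new_attn, new_mask, token_ids = ats(attn, value, mask = mask)
# new_attn has shape (2, 4, k, 17) with k <= 9: the CLS token plus at most 8 unique sampled tokens
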
# 定义前馈神经网络类
class FeedForward(nn.Module):
    # 初始化函数,定义网络结构
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # 使用 nn.Sequential 定义网络层次结构
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # Layer normalization
            nn.Linear(dim, hidden_dim),  # 线性变换
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(dropout),  # Dropout 正则化
            nn.Linear(hidden_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )
    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义注意力机制类
class Attention(nn.Module):
    # 初始化函数,定义注意力机制结构
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # Layer normalization
        self.attend = nn.Softmax(dim = -1)  # Softmax 注意力权重计算
        self.dropout = nn.Dropout(dropout)  # Dropout 正则化

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # 线性变换

        self.output_num_tokens = output_num_tokens
        self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )

    # 前向传播函数
    def forward(self, x, *, mask):
        num_tokens = x.shape[1]

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if exists(mask):
            dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~dots_mask, mask_value)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        sampled_token_ids = None

        # 如果启用了自适应令牌采样,并且令牌数量大于输出令牌数量
        if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
            attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out), mask, sampled_token_ids

# 定义 Transformer 类
class Transformer(nn.Module):
    # 初始化函数,定义 Transformer 结构
    def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
        assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
        assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'

        self.layers = nn.ModuleList([])
        for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        # 获取输入张量 x 的形状的前两个维度大小和设备信息
        b, n, device = *x.shape[:2], x.device

        # 使用掩码来跟踪填充位置,以便在采样标记时移除重复项
        mask = torch.ones((b, n), device=device, dtype=torch.bool)

        # 创建一个包含从 0 到 n-1 的张量,设备信息与输入张量 x 一致
        token_ids = torch.arange(n, device=device)
        token_ids = repeat(token_ids, 'n -> b n', b=b)

        # 遍历每个注意力层和前馈层
        for attn, ff in self.layers:
            # 调用注意力层的前向传播函数,获取注意力输出、更新后的掩码和采样的标记
            attn_out, mask, sampled_token_ids = attn(x, mask=mask)

            # 当进行标记采样时,需要使用采样的标记 id 从输入张量中选择对应的标记
            if exists(sampled_token_ids):
                x = batched_index_select(x, sampled_token_ids, dim=1)
                token_ids = batched_index_select(token_ids, sampled_token_ids, dim=1)

            # 更新输入张量,加上注意力输出
            x = x + attn_out

            # 经过前馈层处理后再加上原始输入,得到最终输出
            x = ff(x) + x

        # 返回最终输出张量和标记 id
        return x, token_ids
class ViT(nn.Module):
    # 定义 ViT 模型类,继承自 nn.Module
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # 初始化函数,接收参数 image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels, dim_head, dropout, emb_dropout
        super().__init__()
        # 调用父类的初始化函数

        image_height, image_width = pair(image_size)
        # 获取图像的高度和宽度
        patch_height, patch_width = pair(patch_size)
        # 获取补丁的高度和宽度

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 断言,确保图像的尺寸能够被补丁的尺寸整除

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 计算补丁的数量
        patch_dim = channels * patch_height * patch_width
        # 计算每个补丁的维度

        self.to_patch_embedding = nn.Sequential(
            # 定义将图像转换为补丁嵌入的序列
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # 重新排列图像的通道和补丁的维度
            nn.LayerNorm(patch_dim),
            # 对每个补丁进行 LayerNorm
            nn.Linear(patch_dim, dim),
            # 线性变换将每个补丁的维度映射到指定的维度 dim
            nn.LayerNorm(dim)
            # 对映射后的维度进行 LayerNorm
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 初始化位置嵌入参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化类别标记参数
        self.dropout = nn.Dropout(emb_dropout)
        # 定义丢弃层,用于嵌入的丢弃

        self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
        # 初始化 Transformer 模型

        self.mlp_head = nn.Sequential(
            # 定义 MLP 头部
            nn.LayerNorm(dim),
            # 对输入进行 LayerNorm
            nn.Linear(dim, num_classes)
            # 线性变换将维度映射到类别数量
        )

    def forward(self, img, return_sampled_token_ids = False):
        # 定义前向传播函数,接收图像和是否返回采样的令牌 ID

        x = self.to_patch_embedding(img)
        # 将图像转换为补丁嵌入
        b, n, _ = x.shape
        # 获取 x 的形状信息

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 重复类别标记,使其与补丁嵌入的形状相同
        x = torch.cat((cls_tokens, x), dim=1)
        # 拼接类别标记和补丁嵌入
        x += self.pos_embedding[:, :(n + 1)]
        # 添加位置嵌入
        x = self.dropout(x)
        # 对输入进行丢弃

        x, token_ids = self.transformer(x)
        # 使用 Transformer 进行转换

        logits = self.mlp_head(x[:, 0])
        # 使用 MLP 头部生成输出

        if return_sampled_token_ids:
            # 如果需要返回采样的令牌 ID
            token_ids = token_ids[:, 1:] - 1
            # 移除类别标记并减去 1 以使 -1 成为填充
            return logits, token_ids
            # 返回输出和令牌 ID

        return logits
        # 返回输出
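
# A usage sketch for this adaptive-token-sampling ViT (not part of the original file; the
# hyperparameter values are illustrative). max_tokens_per_depth must be non-increasing and
# have one entry per transformer layer.
v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8),   # token budget per layer
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
logits, token_ids = v(img, return_sampled_token_ids = True)
# logits: (4, 1000); token_ids: indices of the patches that survived sampling, with -1 marking padding
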

.\lucidrains\vit-pytorch\vit_pytorch\cait.py

from random import randrange
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# 检查值是否存在
def exists(val):
    return val is not None

# 对层应用 dropout
def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # 确保至少有一层保留
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers
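
# A tiny hypothetical illustration of dropout_layers (not part of the original file): on each
# forward pass every layer is independently dropped with the given probability, but at least
# one layer is always kept.
blocks = list(range(10))
kept = dropout_layers(blocks, dropout = 0.5)    # roughly half of the blocks survive
assert 1 <= len(kept) <= len(blocks)
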

# classes

# 缩放层
class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # 根据深度选择初始化值,详见论文第2节
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
        self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn)    # talking heads, pre-softmax

        attn = self.attend(dots)
        attn = self.dropout(attn)

        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn)   # talking heads, post-softmax

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 模型
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
            ]))
    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CaiT(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        image_size,  # 图像大小
        patch_size,  # 补丁大小
        num_classes,  # 类别数量
        dim,  # 特征维度
        depth,  # 深度
        cls_depth,  # 分类深度
        heads,  # 多头注意力头数
        mlp_dim,  # MLP隐藏层维度
        dim_head = 64,  # 头维度
        dropout = 0.,  # 丢弃率
        emb_dropout = 0.,  # 嵌入层丢弃率
        layer_dropout = 0.  # 层丢弃率
    ):
        # 调用父类初始化函数
        super().__init__()
        # 检查图像尺寸是否能被补丁大小整除
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算补丁数量
        num_patches = (image_size // patch_size) ** 2
        # 计算补丁维度
        patch_dim = 3 * patch_size ** 2

        # 补丁嵌入层
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 位置嵌入
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        # 分类令牌
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # 丢弃层
        self.dropout = nn.Dropout(emb_dropout)

        # 补丁Transformer
        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        # 分类Transformer
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        # MLP头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数
    def forward(self, img):
        # 补丁嵌入
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 添加位置嵌入
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        # 补丁Transformer
        x = self.patch_transformer(x)

        # 重复分类令牌
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 分类Transformer
        x = self.cls_transformer(cls_tokens, context = x)

        # 返回MLP头的结果
        return self.mlp_head(x[:, 0])
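
# A usage sketch for CaiT (not part of the original file; hyperparameter values are illustrative).
# depth controls the patch-to-patch transformer, cls_depth the shallow transformer in which only
# the CLS token cross-attends to the patch tokens.
v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,            # patch-to-patch attention layers
    cls_depth = 2,         # CLS-to-patch cross-attention layers
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05   # randomly skip roughly 5% of the layers per forward pass
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)
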

.\lucidrains\vit-pytorch\vit_pytorch\cct.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 F 函数
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 将输入转换为元组的函数
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# CCT 模型

# 定义导出的 CCT 模型名称列表
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']

# 定义创建不同层数 CCT 模型的函数
def cct_2(*args, **kwargs):
    return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)

def cct_4(*args, **kwargs):
    return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)

def cct_6(*args, **kwargs):
    return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_7(*args, **kwargs):
    return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_8(*args, **kwargs):
    return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_14(*args, **kwargs):
    return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)

def cct_16(*args, **kwargs):
    return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)

# 创建 CCT 模型的内部函数
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
    # 计算默认的步长和填充值
    stride = default(stride, max(1, (kernel_size // 2) - 1))
    padding = default(padding, max(1, (kernel_size // 2)))

    # 返回 CCT 模型
    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
               embedding_dim=embedding_dim,
               kernel_size=kernel_size,
               stride=stride,
               padding=padding,
               *args, **kwargs)

# 位置编码

# 创建正弦位置编码的函数
def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return rearrange(pe, '... -> 1 ...')
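
# A quick illustrative check of sinusoidal_embedding (not part of the original file): the table
# has a leading batch-like dimension of 1, one row per position, and alternating sine / cosine
# channels across the embedding dimension.
pe = sinusoidal_embedding(n_channels = 64, dim = 128)
assert pe.shape == (1, 64, 128)
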

# 模块

# 定义注意力机制模块
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.heads = num_heads
        head_dim = dim // self.heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(attention_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(projection_dropout)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = q * self.scale

        attn = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = einsum('b h i j, b h j d -> b h i d', attn, v)
        x = rearrange(x, 'b h n d -> b n (h d)')

        return self.proj_drop(self.proj(x))

# 定义 Transformer 编码器层模块
class TransformerEncoderLayer(nn.Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.
    """
    # 初始化函数,定义了 Transformer Encoder 层的结构
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        # 调用父类的初始化函数
        super().__init__()

        # 对输入进行 Layer Normalization
        self.pre_norm = nn.LayerNorm(d_model)
        # 定义自注意力机制
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)

        # 第一个线性层
        self.linear1  = nn.Linear(d_model, dim_feedforward)
        # 第一个 Dropout 层
        self.dropout1 = nn.Dropout(dropout)
        # 第一个 Layer Normalization 层
        self.norm1    = nn.LayerNorm(d_model)
        # 第二个线性层
        self.linear2  = nn.Linear(dim_feedforward, d_model)
        # 第二个 Dropout 层
        self.dropout2 = nn.Dropout(dropout)

        # DropPath 模块
        self.drop_path = DropPath(drop_path_rate)

        # 激活函数为 GELU
        self.activation = F.gelu

    # 前向传播函数
    def forward(self, src, *args, **kwargs):
        # 使用自注意力机制处理输入,并加上 DropPath 模块
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        # 对结果进行 Layer Normalization
        src = self.norm1(src)
        # 第一个线性层、激活函数、Dropout、第二个线性层的组合
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        # 将结果与 DropPath 模块处理后的结果相加
        src = src + self.drop_path(self.dropout2(src2))
        # 返回处理后的结果
        return src
class DropPath(nn.Module):
    # 初始化 DropPath 类
    def __init__(self, drop_prob=None):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的 drop_prob 转换为浮点数
        self.drop_prob = float(drop_prob)

    # 前向传播方法
    def forward(self, x):
        # 获取输入 x 的批次大小、drop_prob、设备和数据类型
        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype

        # 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
        if drop_prob <= 0. or not self.training:
            return x

        # 计算保留概率
        keep_prob = 1 - self.drop_prob
        # 构建形状元组
        shape = (batch, *((1,) * (x.ndim - 1)))

        # 生成保留掩码
        keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
        # 对输入 x 进行 DropPath 操作
        output = x.div(keep_prob) * keep_mask.float()
        return output
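
# A hypothetical illustration of DropPath / stochastic depth (not part of the original file):
# in training mode whole samples have their residual branch zeroed with probability drop_prob
# (and the survivors rescaled by 1 / keep_prob); in eval mode the module is the identity.
drop_path = DropPath(drop_prob = 0.1)
x = torch.randn(8, 16, 128)

drop_path.train()
assert drop_path(x).shape == x.shape       # some samples zeroed, the rest scaled by 1 / 0.9

drop_path.eval()
assert torch.equal(drop_path(x), x)        # identity at inference time
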

class Tokenizer(nn.Module):
    # 初始化 Tokenizer 类
    def __init__(self,
                 kernel_size, stride, padding,
                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
                 n_conv_layers=1,
                 n_input_channels=3,
                 n_output_channels=64,
                 in_planes=64,
                 activation=None,
                 max_pool=True,
                 conv_bias=False):
        # 调用父类的初始化方法
        super().__init__()

        # 构建卷积层的通道数列表
        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        # 构建通道数列表的配对
        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

        # 构建卷积层序列
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(chan_in, chan_out,
                          kernel_size=(kernel_size, kernel_size),
                          stride=(stride, stride),
                          padding=(padding, padding), bias=conv_bias),
                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool2d(kernel_size=pooling_kernel_size,
                             stride=pooling_stride,
                             padding=pooling_padding) if max_pool else nn.Identity()
            )
                for chan_in, chan_out in n_filter_list_pairs
            ])

        # 对模型参数进行初始化
        self.apply(self.init_weight)

    # 计算序列长度
    def sequence_length(self, n_channels=3, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    # 前向传播方法
    def forward(self, x):
        # 对卷积层的输出进行重排列
        return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

    # 初始化权重方法
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)


class TransformerClassifier(nn.Module):
    # 初始化函数,设置模型的各种参数
    def __init__(self,
                 seq_pool=True,  # 是否使用序列池化
                 embedding_dim=768,  # 嵌入维度
                 num_layers=12,  # 编码器层数
                 num_heads=12,  # 注意力头数
                 mlp_ratio=4.0,  # MLP 扩展比例
                 num_classes=1000,  # 类别数
                 dropout_rate=0.1,  # Dropout 比例
                 attention_dropout=0.1,  # 注意力 Dropout 比例
                 stochastic_depth_rate=0.1,  # 随机深度比例
                 positional_embedding='sine',  # 位置编码类型
                 sequence_length=None,  # 序列长度
                 *args, **kwargs):  # 其他参数
        super().__init__()  # 调用父类的初始化函数
        assert positional_embedding in {'sine', 'learnable', 'none'}  # 断言位置编码类型合法

        dim_feedforward = int(embedding_dim * mlp_ratio)  # 计算前馈网络维度
        self.embedding_dim = embedding_dim  # 设置嵌入维度
        self.sequence_length = sequence_length  # 设置序列长度
        self.seq_pool = seq_pool  # 设置是否使用序列池化

        # the sequence length must be provided unless positional embeddings are disabled
        assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        if not seq_pool:  # 如果不使用序列池化
            sequence_length += 1  # 序列长度加一
            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)  # 创建类别嵌入参数
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)  # 创建注意力池化层

        if positional_embedding == 'none':  # 如果位置编码为'none'
            self.positional_emb = None  # 不使用位置编码
        elif positional_embedding == 'learnable':  # 如果位置编码为'learnable'
            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),  # 创建可学习位置编码参数
                                               requires_grad=True)
            nn.init.trunc_normal_(self.positional_emb, std=0.2)  # 对位置编码参数进行初始化
        else:
            self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),  # 创建正弦位置编码参数
                                               requires_grad=False)

        self.dropout = nn.Dropout(p=dropout_rate)  # 创建 Dropout 层

        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]  # 计算随机深度比例列表

        self.blocks = nn.ModuleList([  # 创建 Transformer 编码器层列表
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
                                    attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
            for layer_dpr in dpr])

        self.norm = nn.LayerNorm(embedding_dim)  # 创建 LayerNorm 层

        self.fc = nn.Linear(embedding_dim, num_classes)  # 创建全连接层
        self.apply(self.init_weight)  # 应用初始化权重函数

    # 前向传播函数
    def forward(self, x):
        b = x.shape[0]  # 获取 batch 大小

        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:  # 如果位置编码不存在且序列长度小于指定长度
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)  # 对输入进行填充

        if not self.seq_pool:  # 如果不使用序列池化
            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)  # 重复类别嵌入
            x = torch.cat((cls_token, x), dim=1)  # 拼接类别嵌入和输入

        if exists(self.positional_emb):  # 如果位置编码存在
            x += self.positional_emb  # 加上位置编码

        x = self.dropout(x)  # Dropout

        for blk in self.blocks:  # 遍历编码器层
            x = blk(x)  # 应用编码器层

        x = self.norm(x)  # LayerNorm

        if self.seq_pool:  # 如果使用序列池化
            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')  # 注意力权重计算
            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)  # 加权池化
        else:
            x = x[:, 0]  # 取第一个位置的输出作为结果

        return self.fc(x)  # 全连接层输出结果

    # 初始化权重函数
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Linear):  # 如果是线性层
            nn.init.trunc_normal_(m.weight, std=.02)  # 初始化权重
            if isinstance(m, nn.Linear) and exists(m.bias):  # 如果是线性层且存在偏置
                nn.init.constant_(m.bias, 0)  # 初始化偏置为0
        elif isinstance(m, nn.LayerNorm):  # 如果是 LayerNorm 层
            nn.init.constant_(m.bias, 0)  # 初始化偏置为0
            nn.init.constant_(m.weight, 1.0)  # 初始化权重为1.0
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
    # 初始化函数,设置各种参数
    def __init__(
        self,
        img_size=224,
        embedding_dim=768,
        n_input_channels=3,
        n_conv_layers=1,
        kernel_size=7,
        stride=2,
        padding=3,
        pooling_kernel_size=3,
        pooling_stride=2,
        pooling_padding=1,
        *args, **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        img_height, img_width = pair(img_size)

        # 初始化 Tokenizer 对象
        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
                                   n_output_channels=embedding_dim,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   pooling_kernel_size=pooling_kernel_size,
                                   pooling_stride=pooling_stride,
                                   pooling_padding=pooling_padding,
                                   max_pool=True,
                                   activation=nn.ReLU,
                                   n_conv_layers=n_conv_layers,
                                   conv_bias=False)

        # 初始化 TransformerClassifier 对象
        self.classifier = TransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
                                                           height=img_height,
                                                           width=img_width),
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout_rate=0.,
            attention_dropout=0.1,
            stochastic_depth=0.1,
            *args, **kwargs)

    # 前向传播函数
    def forward(self, x):
        # 对输入数据进行编码
        x = self.tokenizer(x)
        # 使用 Transformer 进行分类
        return self.classifier(x)
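
# A usage sketch via one of the factory helpers defined above (not part of the original file;
# settings are illustrative). Anything not consumed by CCT itself, such as num_classes or
# positional_embedding, is forwarded through **kwargs to TransformerClassifier.
model = cct_7(
    img_size = 224,
    n_conv_layers = 2,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

img = torch.randn(1, 3, 224, 224)
logits = model(img)   # (1, 1000)
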

.\lucidrains\vit-pytorch\vit_pytorch\cct_3d.py

import torch  # 导入 PyTorch 库
from torch import nn, einsum  # 从 PyTorch 库中导入 nn 模块和 einsum 函数
import torch.nn.functional as F  # 从 PyTorch 库中导入 F 模块

from einops import rearrange, repeat  # 从 einops 库中导入 rearrange 和 repeat 函数

# helpers

def exists(val):
    return val is not None  # 判断变量是否存在的辅助函数

def default(val, d):
    return val if exists(val) else d  # 如果变量存在则返回变量,否则返回默认值的辅助函数

def pair(t):
    return t if isinstance(t, tuple) else (t, t)  # 如果输入是元组则返回输入,否则返回包含输入两次的元组

# CCT Models

__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']  # 定义导出的模型名称列表

# 定义不同层数的 CCT 模型函数

def cct_2(*args, **kwargs):
    return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)  # 返回 2 层 CCT 模型

def cct_4(*args, **kwargs):
    return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)  # 返回 4 层 CCT 模型

def cct_6(*args, **kwargs):
    return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 6 层 CCT 模型

def cct_7(*args, **kwargs):
    return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 7 层 CCT 模型

def cct_8(*args, **kwargs):
    return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 8 层 CCT 模型

def cct_14(*args, **kwargs):
    return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)  # 返回 14 层 CCT 模型

def cct_16(*args, **kwargs):
    return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)  # 返回 16 层 CCT 模型

# 定义 CCT 模型函数

def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
    stride = default(stride, max(1, (kernel_size // 2) - 1))  # 设置默认的步长
    padding = default(padding, max(1, (kernel_size // 2)))  # 设置默认的填充大小

    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
               embedding_dim=embedding_dim,
               kernel_size=kernel_size,
               stride=stride,
               padding=padding,
               *args, **kwargs)  # 返回 CCT 模型

# positional

def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])  # 计算正弦余弦位置编码
    pe[:, 0::2] = torch.sin(pe[:, 0::2])  # 偶数列使用正弦函数
    pe[:, 1::2] = torch.cos(pe[:, 1::2])  # 奇数列使用余弦函数
    return rearrange(pe, '... -> 1 ...')  # 重新排列位置编码的维度

# modules

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.heads = num_heads  # 设置注意力头数
        head_dim = dim // self.heads  # 计算每个头的维度
        self.scale = head_dim ** -0.5  # 缩放因子

        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # 线性变换层
        self.attn_drop = nn.Dropout(attention_dropout)  # 注意力丢弃层
        self.proj = nn.Linear(dim, dim)  # 投影层
        self.proj_drop = nn.Dropout(projection_dropout)  # 投影丢弃层

    def forward(self, x):
        B, N, C = x.shape  # 获取输入张量的形状

        qkv = self.qkv(x).chunk(3, dim = -1)  # 将线性变换后的张量切分为 Q、K、V
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # 重排 Q、K、V 的维度

        q = q * self.scale  # 缩放 Q

        attn = einsum('b h i d, b h j d -> b h i j', q, k)  # 计算注意力分数
        attn = attn.softmax(dim=-1)  # 对注意力分数进行 softmax
        attn = self.attn_drop(attn)  # 使用注意力丢弃层

        x = einsum('b h i j, b h j d -> b h i d', attn, v)  # 计算加权后的 V
        x = rearrange(x, 'b h n d -> b n (h d)')  # 重排输出张量的维度

        return self.proj_drop(self.proj(x))  # 使用投影丢弃层进行投影

class TransformerEncoderLayer(nn.Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.
    """
    # 初始化函数,定义了 Transformer Encoder 层的结构
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        # 调用父类的初始化函数
        super().__init__()

        # 对输入进行 Layer Normalization
        self.pre_norm = nn.LayerNorm(d_model)
        # 定义自注意力机制
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)

        # 第一个线性层
        self.linear1  = nn.Linear(d_model, dim_feedforward)
        # 第一个 Dropout 层
        self.dropout1 = nn.Dropout(dropout)
        # 第一个 Layer Normalization 层
        self.norm1    = nn.LayerNorm(d_model)
        # 第二个线性层
        self.linear2  = nn.Linear(dim_feedforward, d_model)
        # 第二个 Dropout 层
        self.dropout2 = nn.Dropout(dropout)

        # DropPath 模块
        self.drop_path = DropPath(drop_path_rate)

        # 激活函数为 GELU
        self.activation = F.gelu

    # 前向传播函数
    def forward(self, src, *args, **kwargs):
        # 使用自注意力机制对输入进行处理,并加上 DropPath 模块
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        # 对结果进行 Layer Normalization
        src = self.norm1(src)
        # 第二个线性层的计算过程
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        # 将第二个线性层的结果加上 DropPath 模块
        src = src + self.drop_path(self.dropout2(src2))
        # 返回处理后的结果
        return src
class DropPath(nn.Module):
    # 初始化 DropPath 类
    def __init__(self, drop_prob=None):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的 drop_prob 转换为浮点数
        self.drop_prob = float(drop_prob)

    # 前向传播方法
    def forward(self, x):
        # 获取输入 x 的批量大小、drop_prob、设备和数据类型
        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype

        # 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
        if drop_prob <= 0. or not self.training:
            return x

        # 计算保留概率
        keep_prob = 1 - self.drop_prob
        # 构建形状元组
        shape = (batch, *((1,) * (x.ndim - 1)))

        # 生成保留掩码
        keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
        # 对输入 x 进行处理并返回输出
        output = x.div(keep_prob) * keep_mask.float()
        return output

class Tokenizer(nn.Module):
    # 初始化 Tokenizer 类
    def __init__(
        self,
        frame_kernel_size,
        kernel_size,
        stride,
        padding,
        frame_stride=1,
        frame_pooling_stride=1,
        frame_pooling_kernel_size=1,
        pooling_kernel_size=3,
        pooling_stride=2,
        pooling_padding=1,
        n_conv_layers=1,
        n_input_channels=3,
        n_output_channels=64,
        in_planes=64,
        activation=None,
        max_pool=True,
        conv_bias=False
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 构建卷积层的通道数列表
        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        # 构建通道数列表的配对
        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

        # 构建卷积层序列
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv3d(chan_in, chan_out,
                          kernel_size=(frame_kernel_size, kernel_size, kernel_size),
                          stride=(frame_stride, stride, stride),
                          padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
                             stride=(frame_pooling_stride, pooling_stride, pooling_stride),
                             padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
            )
                for chan_in, chan_out in n_filter_list_pairs
            ])

        # 对模型进行权重初始化
        self.apply(self.init_weight)

    # 计算序列长度
    def sequence_length(self, n_channels=3, frames=8, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1]

    # 前向传播方法
    def forward(self, x):
        # 对输入 x 进行卷积操作并返回重排后的输出
        x = self.conv_layers(x)
        return rearrange(x, 'b c f h w -> b (f h w) c')

    # 初始化权重方法
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight)


class TransformerClassifier(nn.Module):
    # 初始化 TransformerClassifier 类
    def __init__(
        self,
        seq_pool=True,
        embedding_dim=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        num_classes=1000,
        dropout_rate=0.1,
        attention_dropout=0.1,
        stochastic_depth_rate=0.1,
        positional_embedding='sine',
        sequence_length=None,
        *args, **kwargs
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言位置编码在{'sine', 'learnable', 'none'}中
        assert positional_embedding in {'sine', 'learnable', 'none'}

        # 计算前馈网络的维度
        dim_feedforward = int(embedding_dim * mlp_ratio)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.seq_pool = seq_pool

        # 断言序列长度存在或者位置编码为'none'
        assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        # 如果不使用序列池化
        if not seq_pool:
            sequence_length += 1
            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim))
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)

        # 根据位置编码类型初始化位置编码
        if positional_embedding == 'none':
            self.positional_emb = None
        elif positional_embedding == 'learnable':
            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))
            nn.init.trunc_normal_(self.positional_emb, std=0.2)
        else:
            self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim))

        # 初始化Dropout层
        self.dropout = nn.Dropout(p=dropout_rate)

        # 生成随机Drop Path率
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]

        # 创建Transformer编码器层
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
                                    attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
            for layer_dpr in dpr])

        # 初始化LayerNorm层
        self.norm = nn.LayerNorm(embedding_dim)

        # 初始化全连接层
        self.fc = nn.Linear(embedding_dim, num_classes)
        # 应用初始化权重函数
        self.apply(self.init_weight)

    @staticmethod
    def init_weight(m):
        # 初始化线性层的权重
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            # 如果是线性层且存在偏置项,则初始化偏置项
            if isinstance(m, nn.Linear) and exists(m.bias):
                nn.init.constant_(m.bias, 0)
        # 初始化LayerNorm层的权重和偏置项
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # 获取批量大小
        b = x.shape[0]

        # 如果位置编码不存在且输入序列长度小于设定的序列长度,则进行填充
        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)

        # 如果不使用序列池化,则在输入序列前添加类别标记
        if not self.seq_pool:
            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b=b)
            x = torch.cat((cls_token, x), dim=1)

        # 如果位置编码存在,则加上位置编码
        if exists(self.positional_emb):
            x += self.positional_emb

        # Dropout层
        x = self.dropout(x)

        # 遍历Transformer编码器层
        for blk in self.blocks:
            x = blk(x)

        # LayerNorm层
        x = self.norm(x)

        # 如果使用序列池化,则计算注意力权重并进行加权求和
        if self.seq_pool:
            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim=1), x)
        else:
            x = x[:, 0]

        # 全连接层
        return self.fc(x)
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        img_size=224,  # 图像大小,默认为 224
        num_frames=8,  # 帧数,默认为 8
        embedding_dim=768,  # 嵌入维度,默认为 768
        n_input_channels=3,  # 输入通道数,默认为 3
        n_conv_layers=1,  # 卷积层数,默认为 1
        frame_stride=1,  # 帧步长,默认为 1
        frame_kernel_size=3,  # 帧卷积核大小,默认为 3
        frame_pooling_kernel_size=1,  # 帧池化核大小,默认为 1
        frame_pooling_stride=1,  # 帧池化步长,默认为 1
        kernel_size=7,  # 卷积核大小,默认为 7
        stride=2,  # 步长,默认为 2
        padding=3,  # 填充,默认为 3
        pooling_kernel_size=3,  # 池化核大小,默认为 3
        pooling_stride=2,  # 池化步长,默认为 2
        pooling_padding=1,  # 池化填充,默认为 1
        *args, **kwargs  # 其他参数
    ):
        super().__init__()  # 调用父类的初始化函数

        img_height, img_width = pair(img_size)  # 获取图像的高度和宽度

        # 初始化 Tokenizer 对象
        self.tokenizer = Tokenizer(
            n_input_channels=n_input_channels,
            n_output_channels=embedding_dim,
            frame_stride=frame_stride,
            frame_kernel_size=frame_kernel_size,
            frame_pooling_stride=frame_pooling_stride,
            frame_pooling_kernel_size=frame_pooling_kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            pooling_kernel_size=pooling_kernel_size,
            pooling_stride=pooling_stride,
            pooling_padding=pooling_padding,
            max_pool=True,
            activation=nn.ReLU,
            n_conv_layers=n_conv_layers,
            conv_bias=False
        )

        # 初始化 TransformerClassifier 对象
        self.classifier = TransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(
                n_channels=n_input_channels,
                frames=num_frames,
                height=img_height,
                width=img_width
            ),
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout_rate=0.,
            attention_dropout=0.1,
            stochastic_depth=0.1,
            *args, **kwargs
        )

    # 前向传播函数
    def forward(self, x):
        x = self.tokenizer(x)  # 对输入数据进行编码
        return self.classifier(x)  # 对编码后的数据进行分类
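
# A usage sketch for the 3D variant (not part of the original file; settings are illustrative).
# The input is a video tensor shaped (batch, channels, frames, height, width); the frame-related
# kernel and stride arguments control how aggressively the temporal axis is downsampled.
model = cct_7(
    img_size = 112,
    num_frames = 8,
    n_conv_layers = 2,
    frame_stride = 2,          # downsample the temporal axis in each conv layer
    num_classes = 400,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 112, 112)
logits = model(video)   # (1, 400)
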

.\lucidrains\vit-pytorch\vit_pytorch\crossformer.py

import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# 辅助函数

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 交叉嵌入层

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_sizes,
        stride = 2
    ):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # 计算每个尺度的维度
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        # 对输入进行卷积操作,并将结果拼接在一起
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

# 动态位置偏置

def DynamicPositionBias(dim):
    return nn.Sequential(
        nn.Linear(2, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, 1),
        Rearrange('... () -> ...')
    )

# transformer 类

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv2d(dim * mult, dim, 1)
    )

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        attn_type,
        window_size,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant'
        heads = dim // dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attn_type = attn_type
        self.window_size = window_size

        self.norm = LayerNorm(dim)

        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

        # 位置

        self.dpb = DynamicPositionBias(dim // 4)

        # 计算和存储用于检索偏置的索引

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = grid[:, None] - grid[None, :]
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
    # 定义前向传播函数,接受输入 x
    def forward(self, x):
        # 解构 x 的形状,获取高度、宽度、头数、窗口大小和设备信息
        *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device

        # 对输入进行预处理
        x = self.norm(x)

        # 根据不同的注意力类型重新排列输入,以便进行短距离或长距离注意力
        if self.attn_type == 'short':
            x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
        elif self.attn_type == 'long':
            x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)

        # 将输入转换为查询、键、值
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # 将查询、键、值按头数进行分割
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
        q = q * self.scale

        # 计算注意力矩阵
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 添加动态位置偏置
        pos = torch.arange(-wsz, wsz + 1, device = device)
        rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
        biases = self.dpb(rel_pos.float())
        rel_pos_bias = biases[self.rel_pos_indices]
        sim = sim + rel_pos_bias

        # 注意力权重归一化
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # 合并头部
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
        out = self.to_out(out)

        # 根据不同的注意力类型重新排列输出
        if self.attn_type == 'short':
            out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
        elif self.attn_type == 'long':
            out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)

        return out
# 定义一个名为 Transformer 的神经网络模块
class Transformer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        *,
        local_window_size,
        global_window_size,
        depth = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])

        # 循环创建指定深度的神经网络层
        for _ in range(depth):
            # 每层包含两个注意力机制和两个前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
                Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))

    # 前向传播函数
    def forward(self, x):
        # 遍历每一层的注意力机制和前馈神经网络
        for short_attn, short_ff, long_attn, long_ff in self.layers:
            # 执行短程注意力机制和前馈神经网络
            x = short_attn(x) + x
            x = short_ff(x) + x
            # 执行长程注意力机制和前馈神经网络
            x = long_attn(x) + x
            x = long_ff(x) + x

        # 返回处理后的数据
        return x

# 定义一个名为 CrossFormer 的神经网络模块
class CrossFormer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        global_window_size = (8, 4, 2, 1),
        local_window_size = 7,
        cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
        cross_embed_strides = (4, 2, 2, 2),
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 将参数转换为元组形式
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        global_window_size = cast_tuple(global_window_size, 4)
        local_window_size = cast_tuple(local_window_size, 4)
        cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
        cross_embed_strides = cast_tuple(cross_embed_strides, 4)

        # 断言确保参数长度为4
        assert len(dim) == 4
        assert len(depth) == 4
        assert len(global_window_size) == 4
        assert len(local_window_size) == 4
        assert len(cross_embed_kernel_sizes) == 4
        assert len(cross_embed_strides) == 4

        # 定义维度相关变量
        last_dim = dim[-1]
        dims = [channels, *dim]
        dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))

        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])

        # 循环创建交叉嵌入层和 Transformer 层
        for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
            self.layers.append(nn.ModuleList([
                CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
                Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

        # 定义最终的逻辑层
        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    # 前向传播函数
    def forward(self, x):
        # 遍历每一层的交叉嵌入层和 Transformer 层
        for cel, transformer in self.layers:
            # 执行交叉嵌入层
            x = cel(x)
            # 执行 Transformer 层
            x = transformer(x)

        # 返回最终的逻辑结果
        return self.to_logits(x)
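
# 用法示例(假设性演示,非原文件内容):用默认超参数实例化 CrossFormer 并做一次前向推理。
# 默认各阶段下采样步长为 4/2/2/2、局部窗口大小为 7,224x224 的输入恰好可以整除。
model = CrossFormer(num_classes = 1000)
img = torch.randn(1, 3, 224, 224)
logits = model(img)  # 输出形状为 (1, 1000)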

.\lucidrains\vit-pytorch\vit_pytorch\cross_vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块并重命名为 F
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 前馈神经网络

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # 定义神经网络结构
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        b, n, _, h = *x.shape, self.heads
        x = self.norm(x)
        context = default(context, x)

        if kv_include_self:
            context = torch.cat((x, context), dim = 1) # 交叉注意力需要 CLS 标记包含自身作为键/值

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 编码器,用于小和大补丁

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# 投影 CLS 标记,以防小和大补丁标记具有不同的维度

class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# 交叉注意力 Transformer

class CrossTransformer(nn.Module):
    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
            ]))
    # 定义一个前向传播函数,接受两个输入:sm_tokens和lg_tokens
    def forward(self, sm_tokens, lg_tokens):
        # 将输入的sm_tokens和lg_tokens分别拆分为(sm_cls, sm_patch_tokens)和(lg_cls, lg_patch_tokens)
        (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))

        # 遍历self.layers中的每一层,每一层包含sm_attend_lg和lg_attend_sm
        for sm_attend_lg, lg_attend_sm in self.layers:
            # 对sm_cls进行注意力计算,使用lg_patch_tokens作为上下文,kv_include_self设置为True,然后加上原始sm_cls
            sm_cls = sm_attend_lg(sm_cls, context=lg_patch_tokens, kv_include_self=True) + sm_cls
            # 对lg_cls进行注意力计算,使用sm_patch_tokens作为上下文,kv_include_self设置为True,然后加上原始lg_cls
            lg_cls = lg_attend_sm(lg_cls, context=sm_patch_tokens, kv_include_self=True) + lg_cls

        # 将sm_cls和sm_patch_tokens在维度1上拼接起来
        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim=1)
        # 将lg_cls和lg_patch_tokens在维度1上拼接起来
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim=1)
        # 返回拼接后的sm_tokens和lg_tokens
        return sm_tokens, lg_tokens

# 定义多尺度编码器类
class MultiScaleEncoder(nn.Module):
    def __init__(
        self,
        *,
        depth,  # 编码器深度
        sm_dim,  # 小尺度维度
        lg_dim,  # 大尺度维度
        sm_enc_params,  # 小尺度编码器参数
        lg_enc_params,  # 大尺度编码器参数
        cross_attn_heads,  # 跨尺度注意力头数
        cross_attn_depth,  # 跨尺度注意力深度
        cross_attn_dim_head = 64,  # 跨尺度注意力头维度
        dropout = 0.  # 丢弃率
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),  # 小尺度变换器
                Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),  # 大尺度变换器
                CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)  # 跨尺度变换器
            ]))

    def forward(self, sm_tokens, lg_tokens):
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)  # 小尺度编码器和大尺度编码器
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)  # 跨尺度注意力

        return sm_tokens, lg_tokens

# 基于补丁的图像到标记嵌入器类
class ImageEmbedder(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 维度
        image_size,  # 图像尺寸
        patch_size,  # 补丁尺寸
        dropout = 0.  # 丢弃率
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),  # 图像转换为补丁
            nn.LayerNorm(patch_dim),  # 层归一化
            nn.Linear(patch_dim, dim),  # 线性变换
            nn.LayerNorm(dim)  # 层归一化
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 位置嵌入
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # 类别标记
        self.dropout = nn.Dropout(dropout)  # 丢弃层

    def forward(self, img):
        x = self.to_patch_embedding(img)  # 图像转换为补丁嵌入
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # 重复类别标记
        x = torch.cat((cls_tokens, x), dim=1)  # 拼接类别标记和补丁嵌入
        x += self.pos_embedding[:, :(n + 1)]  # 加上位置嵌入

        return self.dropout(x)  # 返回结果经过丢弃层处理

# CrossViT 类,结合小补丁分支与大补丁分支
class CrossViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,  # 图像尺寸
        num_classes,  # 类别数
        sm_dim,  # 小尺度维度
        lg_dim,  # 大尺度维度
        sm_patch_size = 12,  # 小尺度补丁尺寸
        sm_enc_depth = 1,  # 小尺度编码器深度
        sm_enc_heads = 8,  # 小尺度编码器头数
        sm_enc_mlp_dim = 2048,  # 小尺度编码器MLP维度
        sm_enc_dim_head = 64,  # 小尺度编码器头维度
        lg_patch_size = 16,  # 大尺度补丁尺寸
        lg_enc_depth = 4,  # 大尺度编码器深度
        lg_enc_heads = 8,  # 大尺度编码器头数
        lg_enc_mlp_dim = 2048,  # 大尺度编码器MLP维度
        lg_enc_dim_head = 64,  # 大尺度编码器头维度
        cross_attn_depth = 2,  # 跨尺度注意力深度
        cross_attn_heads = 8,  # 跨尺度注意力头数
        cross_attn_dim_head = 64,  # 跨尺度注意力头维度
        depth = 3,  # 深度
        dropout = 0.1,  # 丢弃率
        emb_dropout = 0.1  # 嵌入丢弃率
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 创建小尺寸图像嵌入器对象
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        # 创建大尺寸图像嵌入器对象
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)

        # 创建多尺度编码器对象
        self.multi_scale_encoder = MultiScaleEncoder(
            depth = depth,
            sm_dim = sm_dim,
            lg_dim = lg_dim,
            cross_attn_heads = cross_attn_heads,
            cross_attn_dim_head = cross_attn_dim_head,
            cross_attn_depth = cross_attn_depth,
            sm_enc_params = dict(
                depth = sm_enc_depth,
                heads = sm_enc_heads,
                mlp_dim = sm_enc_mlp_dim,
                dim_head = sm_enc_dim_head
            ),
            lg_enc_params = dict(
                depth = lg_enc_depth,
                heads = lg_enc_heads,
                mlp_dim = lg_enc_mlp_dim,
                dim_head = lg_enc_dim_head
            ),
            dropout = dropout
        )

        # 创建小尺寸MLP头部对象
        self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
        # 创建大尺寸MLP头部对象
        self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))

    # 前向传播函数
    def forward(self, img):
        # 获取小尺寸图像嵌入
        sm_tokens = self.sm_image_embedder(img)
        # 获取大尺寸图像嵌入
        lg_tokens = self.lg_image_embedder(img)

        # 多尺度编码器处理小尺寸和大尺寸图像嵌入
        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        # 提取小尺寸和大尺寸的类别特征
        sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))

        # 小尺寸MLP头部处理小尺寸类别特征
        sm_logits = self.sm_mlp_head(sm_cls)
        # 大尺寸MLP头部处理大尺寸类别特征
        lg_logits = self.lg_mlp_head(lg_cls)

        # 返回小尺寸和大尺寸类别特征的加和
        return sm_logits + lg_logits
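
# 用法示例(假设性演示,非原文件内容):小补丁分支与大补丁分支使用不同的 patch 大小、维度与深度,
# 最终分类结果为两个分支 CLS 头 logits 之和。
v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,                # 多尺度编码器重复次数
    sm_dim = 192,             # 小补丁分支维度
    sm_patch_size = 16,
    sm_enc_depth = 2,
    lg_dim = 384,             # 大补丁分支维度
    lg_patch_size = 64,
    lg_enc_depth = 3,
    cross_attn_depth = 2,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
pred = v(img)  # 输出形状为 (1, 1000)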

.\lucidrains\vit-pytorch\vit_pytorch\cvt.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 根据前缀分组并移除前缀
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
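
# 用法示意(假设性示例,非原文件内容):按前缀 's1_' 拆分参数字典,匹配到的键会去掉前缀
# group_by_key_prefix_and_remove_prefix('s1_', {'s1_emb_dim': 64, 's2_emb_dim': 192})
# 返回 ({'emb_dim': 64}, {'s2_emb_dim': 192})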

# classes

# 自定义 LayerNorm 类
class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# 自定义 FeedForward 类
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 自定义 DepthWiseConv2d 类
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

# 自定义 Attention 类
class Attention(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

# 自定义 Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_mult, dropout = dropout)
            ]))
    # 定义一个前向传播函数,接受输入 x
    def forward(self, x):
        # 遍历 self.layers 中的每个元素,每个元素包含一个注意力机制和一个前馈神经网络
        for attn, ff in self.layers:
            # 使用注意力机制处理输入 x,并将结果与原始输入相加
            x = attn(x) + x
            # 使用前馈神经网络处理上一步的结果,并将结果与原始输入相加
            x = ff(x) + x
        # 返回处理后的结果 x
        return x

# 定义一个名为 CvT 的神经网络模型,继承自 nn.Module 类
class CvT(nn.Module):
    # 初始化函数,接收一系列参数
    def __init__(
        self,
        *,
        num_classes,  # 类别数量
        s1_emb_dim = 64,  # s1 阶段的嵌入维度
        s1_emb_kernel = 7,  # s1 阶段的卷积核大小
        s1_emb_stride = 4,  # s1 阶段的卷积步长
        s1_proj_kernel = 3,  # s1 阶段的投影卷积核大小
        s1_kv_proj_stride = 2,  # s1 阶段的键值投影步长
        s1_heads = 1,  # s1 阶段的注意力头数
        s1_depth = 1,  # s1 阶段的深度
        s1_mlp_mult = 4,  # s1 阶段的 MLP 扩展倍数
        s2_emb_dim = 192,  # s2 阶段的嵌入维度
        s2_emb_kernel = 3,  # s2 阶段的卷积核大小
        s2_emb_stride = 2,  # s2 阶段的卷积步长
        s2_proj_kernel = 3,  # s2 阶段的投影卷积核大小
        s2_kv_proj_stride = 2,  # s2 阶段的键值投影步长
        s2_heads = 3,  # s2 阶段的注意力头数
        s2_depth = 2,  # s2 阶段的深度
        s2_mlp_mult = 4,  # s2 阶段的 MLP 扩展倍数
        s3_emb_dim = 384,  # s3 阶段的嵌入维度
        s3_emb_kernel = 3,  # s3 阶段的卷积核大小
        s3_emb_stride = 2,  # s3 阶段的卷积步长
        s3_proj_kernel = 3,  # s3 阶段的投影卷积核大小
        s3_kv_proj_stride = 2,  # s3 阶段的键值投影步长
        s3_heads = 6,  # s3 阶段的注意力头数
        s3_depth = 10,  # s3 阶段的深度
        s3_mlp_mult = 4,  # s3 阶段的 MLP 扩展倍数
        dropout = 0.,  # Dropout 概率
        channels = 3  # 输入通道数
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 将参数保存到字典中
        kwargs = dict(locals())

        # 初始化维度为输入通道数
        dim = channels
        # 初始化层列表
        layers = []

        # 遍历 s1、s2、s3 三个阶段
        for prefix in ('s1', 's2', 's3'):
            # 根据前缀分组参数,并从参数字典中移除前缀
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)

            # 将卷积、LayerNorm 和 Transformer 层添加到层列表中
            layers.append(nn.Sequential(
                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
                LayerNorm(config['emb_dim']),
                Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
            ))

            # 更新维度为当前阶段的嵌入维度
            dim = config['emb_dim']

        # 将所有层组成一个序列
        self.layers = nn.Sequential(*layers)

        # 定义输出层,包括全局平均池化、重排和全连接层
        self.to_logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数
    def forward(self, x):
        # 经过所有层得到特征向量
        latents = self.layers(x)
        # 将特征向量传递给输出层得到预测结果
        return self.to_logits(latents)
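
# 用法示例(假设性演示,非原文件内容):使用默认的三阶段配置;由于末尾采用自适应平均池化,
# 输入分辨率不必固定,这里以 224x224 为例。
v = CvT(num_classes = 1000)
img = torch.randn(1, 3, 224, 224)
pred = v(img)  # 输出形状为 (1, 1000)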

.\lucidrains\vit-pytorch\vit_pytorch\deepvit.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 定义一个前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # 对输入进行 Layer Normalization
            nn.Linear(dim, hidden_dim),  # 线性变换
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(dropout),  # Dropout 正则化
            nn.Linear(hidden_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )
    def forward(self, x):
        return self.net(x)

# 定义一个注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # 对输入进行 Layer Normalization
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # 线性变换

        self.dropout = nn.Dropout(dropout)  # Dropout 正则化

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))  # 可学习的头间混合矩阵 (heads x heads),用于 re-attention

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),  # 重新排列张量维度
            nn.LayerNorm(heads),  # 对输入进行 Layer Normalization
            Rearrange('b i j h -> b h i j')  # 重新排列张量维度
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        x = self.norm(x)  # 对输入进行 Layer Normalization

        qkv = self.to_qkv(x).chunk(3, dim = -1)  # 将线性变换后的结果切分成三部分
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)  # 重新排列张量维度

        # attention

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale  # 计算点积
        attn = dots.softmax(dim=-1)  # Softmax 操作
        attn = self.dropout(attn)  # Dropout 正则化

        # re-attention

        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)  # 用可学习矩阵在注意力头之间重新混合注意力图(re-attention)
        attn = self.reattn_norm(attn)  # 在头维度上做 Layer Normalization 后再换回原来的维度顺序

        # aggregate and out

        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # 用注意力权重对 value 加权聚合
        out = rearrange(out, 'b h n d -> b n (h d)')  # 重新排列张量维度
        out =  self.to_out(out)  # 线性变换
        return out

# 定义一个 Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),  # 注意力机制
                FeedForward(dim, mlp_dim, dropout = dropout)  # 前馈神经网络
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x  # 注意力机制的输出与输入相加
            x = ff(x) + x  # 前馈神经网络的输出与输入相加
        return x

# 定义一个 DeepViT 类
class DeepViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),  # 重新排列张量维度
            nn.LayerNorm(patch_dim),  # 对输入进行 Layer Normalization
            nn.Linear(patch_dim, dim),  # 线性变换
            nn.LayerNorm(dim)  # 对输入进行 Layer Normalization
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 定义可学习参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # 定义可学习参数
        self.dropout = nn.Dropout(emb_dropout)  # Dropout 正则化

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # Transformer 模块

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # 对输入进行 Layer Normalization
            nn.Linear(dim, num_classes)  # 线性变换
        )
    # 前向传播函数,接收输入图像并返回预测结果
    def forward(self, img):
        # 将输入图像转换为补丁嵌入
        x = self.to_patch_embedding(img)
        # 获取批量大小、补丁数量和嵌入维度
        b, n, _ = x.shape

        # 重复类别标记以匹配批量大小,并与补丁嵌入连接
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        # 添加位置嵌入到输入嵌入中
        x += self.pos_embedding[:, :(n + 1)]
        # 对输入进行 dropout 处理
        x = self.dropout(x)

        # 使用 Transformer 处理输入数据
        x = self.transformer(x)

        # 根据池化方式计算输出结果
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # 将输出结果转换为潜在空间
        x = self.to_latent(x)
        # 使用 MLP 头部处理潜在空间的输出,并返回预测结果
        return self.mlp_head(x)
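
# 用法示例(假设性演示,非原文件内容):接口与标准 ViT 相同,区别仅在于注意力层内部的 re-attention。
v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img)  # 输出形状为 (1, 1000)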

.\lucidrains\vit-pytorch\vit_pytorch\dino.py

# 导入所需的库
import copy
import random
from functools import wraps, partial

import torch
from torch import nn
import torch.nn.functional as F

from torchvision import transforms as T

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在,则返回该值,否则返回默认值
def default(val, default):
    return val if exists(val) else default

# 单例装饰器,用于缓存结果
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

# 获取模块所在设备
def get_module_device(module):
    return next(module.parameters()).device

# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

# 损失函数(论文中的算法1)

def loss_fn(
    teacher_logits,
    student_logits,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    teacher_logits = teacher_logits.detach()
    student_probs = (student_logits / student_temp).softmax(dim = -1)
    teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
    return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
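
# 补充说明:上式等价于交叉熵 H(p_teacher, p_student)。教师 logits 先减去中心向量 centers,
# 再以较低的 teacher_temp 锐化;学生 logits 仅以 student_temp 软化;detach 保证梯度只回传到学生网络。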

# 数据增强工具类

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# 指数移动平均

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

# MLP类用于投影器和预测器

class L2Norm(nn.Module):
    def forward(self, x, eps = 1e-6):
        norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
        return x / norm

class MLP(nn.Module):
    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        layers = []
        dims = (dim, *((hidden_size,) * (num_layers - 1)))

        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            is_last = ind == (len(dims) - 1)

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            nn.Linear(hidden_size, dim_out)
        )

    def forward(self, x):
        return self.net(x)

# 用于基础神经网络的包装类
# 将管理隐藏层输出的拦截并将其传递到投影器和预测器网络中

class NetWrapper(nn.Module):
    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None
    # 定义一个私有方法,用于在 forward hook 中保存隐藏层的输出
    def _hook(self, _, input, output):
        # 获取输入数据的设备信息
        device = input[0].device
        # 将隐藏层的输出展平并保存到字典中
        self.hidden[device] = output.flatten(1)

    # 注册 forward hook,用于捕获隐藏层的输出
    def _register_hook(self):
        # 查找指定的隐藏层
        layer = self._find_layer()
        # 断言找到了隐藏层
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        # 注册 forward hook
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    # 获取投影器,用于将隐藏层的输出投影到指定维度
    @singleton('projector')
    def _get_projector(self, hidden):
        # 获取隐藏层输出的维度
        _, dim = hidden.shape
        # 创建 MLP 投影器
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    # 获取输入数据的隐藏层输出
    def get_embedding(self, x):
        # 如果隐藏层为最后一层,则直接返回网络的输出
        if self.layer == -1:
            return self.net(x)

        # 如果 hook 没有注册,则注册 hook
        if not self.hook_registered:
            self._register_hook()

        # 清空隐藏层输出字典
        self.hidden.clear()
        # 前向传播获取隐藏层输出
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        # 断言隐藏层输出不为空
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    # 网络的前向传播,可选择是否返回投影后的输出
    def forward(self, x, return_projection = True):
        # 获取输入数据的隐藏层输出
        embed = self.get_embedding(x)
        # 如果不需要返回投影后的输出,则直接返回隐藏层输出
        if not return_projection:
            return embed

        # 获取投影器并对隐藏层输出进行投影
        projector = self._get_projector(embed)
        return projector(embed), embed

# 主类定义
class Dino(nn.Module):
    # 初始化函数
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_hidden_size = 256,
        num_classes_K = 65336,
        projection_layers = 4,
        student_temp = 0.9,
        teacher_temp = 0.04,
        local_upper_crop_scale = 0.4,
        global_lower_crop_scale = 0.5,
        moving_average_decay = 0.9,
        center_moving_average_decay = 0.9,
        augment_fn = None,
        augment_fn2 = None
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 设置网络
        self.net = net

        # 默认的 BYOL 数据增强
        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        # 设置数据增强函数
        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, DEFAULT_AUG)

        # 设置局部和全局裁剪
        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))

        # 设置学生编码器
        self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)

        self.teacher_encoder = None
        self.teacher_ema_updater = EMA(moving_average_decay)

        # 注册缓冲区
        self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_centers',  torch.zeros(1, num_classes_K))

        self.teacher_centering_ema_updater = EMA(center_moving_average_decay)

        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        # 获取网络设备并将包装器设置为相同设备
        device = get_module_device(net)
        self.to(device)

        # 发送一个模拟图像张量以实例化单例参数
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    # 获取教师编码器的单例函数
    @singleton('teacher_encoder')
    def _get_teacher_encoder(self):
        teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(teacher_encoder, False)
        return teacher_encoder

    # 重置移动平均值
    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    # 更新移动平均值
    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

        new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
        self.teacher_centers.copy_(new_teacher_centers)

    # 前向传播函数
    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True,
        student_temp = None,
        teacher_temp = None
        ):
        # 如果需要返回嵌入向量,则调用学生编码器并返回结果
        if return_embedding:
            return self.student_encoder(x, return_projection = return_projection)

        # 对输入数据进行两种不同的数据增强
        image_one, image_two = self.augment1(x), self.augment2(x)

        # 对增强后的图像进行局部裁剪
        local_image_one, local_image_two   = self.local_crop(image_one),  self.local_crop(image_two)
        # 对增强后的图像进行全局裁剪
        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)

        # 使用学生编码器对局部裁剪后的图像进行编码
        student_proj_one, _ = self.student_encoder(local_image_one)
        student_proj_two, _ = self.student_encoder(local_image_two)

        # 使用torch.no_grad()上下文管理器,获取教师编码器并对全局裁剪后的图像进行编码
        with torch.no_grad():
            teacher_encoder = self._get_teacher_encoder()
            teacher_proj_one, _ = teacher_encoder(global_image_one)
            teacher_proj_two, _ = teacher_encoder(global_image_two)

        # 部分应用损失函数,设置学生温度、教师温度和教师中心
        loss_fn_ = partial(
            loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_centers
        )

        # 计算教师投影的平均值,并将其复制到最后的教师中心
        teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
        self.last_teacher_centers.copy_(teacher_logits_avg)

        # 计算损失,取两个损失函数的平均值
        loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
        return loss
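
# 用法示例(假设性演示,非原文件内容):以本仓库的 ViT 作为骨干网络构建 DINO 自蒸馏训练器,
# hidden_layer 指定取哪一层的输出作为特征;每次反向传播后调用 update_moving_average() 更新教师网络。
from vit_pytorch import ViT

model = ViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',   # 以 ViT 的 to_latent 层输出作为特征
    projection_hidden_size = 256,
    projection_layers = 4,
    num_classes_K = 65336
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()   # 用指数滑动平均更新教师编码器与中心向量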

.\lucidrains\vit-pytorch\vit_pytorch\distill.py

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
from torch import nn  # 从 PyTorch 中导入 nn 模块
from vit_pytorch.vit import ViT  # 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch.t2t import T2TViT  # 从 vit_pytorch 库中导入 T2TViT 类
from vit_pytorch.efficient import ViT as EfficientViT  # 从 vit_pytorch 库中导入 EfficientViT 类

from einops import rearrange, repeat  # 从 einops 库中导入 rearrange 和 repeat 函数

# helpers

def exists(val):  # 定义 exists 函数,用于判断变量是否存在
    return val is not None  # 返回变量是否不为 None

# classes

class DistillMixin:  # 定义 DistillMixin 类
    def forward(self, img, distill_token = None):  # 定义 forward 方法,接收图像和 distill_token 参数
        distilling = exists(distill_token)  # 判断 distill_token 是否存在
        x = self.to_patch_embedding(img)  # 将图像转换为 patch embedding
        b, n, _ = x.shape  # 获取 x 的形状信息

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # 重复添加 cls_token
        x = torch.cat((cls_tokens, x), dim = 1)  # 在维度 1 上拼接 cls_tokens 和 x
        x += self.pos_embedding[:, :(n + 1)]  # 添加位置编码

        if distilling:  # 如果进行蒸馏
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)  # 重复添加 distill_token
            x = torch.cat((x, distill_tokens), dim = 1)  # 在维度 1 上拼接 x 和 distill_tokens

        x = self._attend(x)  # 调用 _attend 方法进行注意力计算

        if distilling:  # 如果进行蒸馏
            x, distill_tokens = x[:, :-1], x[:, -1]  # 分割出 distill_tokens

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]  # 计算平均值或取第一个值

        x = self.to_latent(x)  # 转换为 latent 表示
        out = self.mlp_head(x)  # 经过 MLP 头部处理得到输出

        if distilling:  # 如果进行蒸馏
            return out, distill_tokens  # 返回输出和 distill_tokens

        return out  # 返回输出

class DistillableViT(DistillMixin, ViT):  # 定义 DistillableViT 类,继承自 DistillMixin 和 ViT
    def __init__(self, *args, **kwargs):  # 初始化方法
        super(DistillableViT, self).__init__(*args, **kwargs)  # 调用父类的初始化方法
        self.args = args  # 保存参数
        self.kwargs = kwargs  # 保存关键字参数
        self.dim = kwargs['dim']  # 保存维度信息
        self.num_classes = kwargs['num_classes']  # 保存类别数

    def to_vit(self):  # 定义 to_vit 方法
        v = ViT(*self.args, **self.kwargs)  # 创建 ViT 对象
        v.load_state_dict(self.state_dict())  # 加载当前状态字典
        return v  # 返回 ViT 对象

    def _attend(self, x):  # 定义 _attend 方法
        x = self.dropout(x)  # 使用 dropout
        x = self.transformer(x)  # 经过 transformer 处理
        return x  # 返回处理后的结果

class DistillableT2TViT(DistillMixin, T2TViT):  # 定义 DistillableT2TViT 类,继承自 DistillMixin 和 T2TViT
    def __init__(self, *args, **kwargs):  # 初始化方法
        super(DistillableT2TViT, self).__init__(*args, **kwargs)  # 调用父类的初始化方法
        self.args = args  # 保存参数
        self.kwargs = kwargs  # 保存关键字参数
        self.dim = kwargs['dim']  # 保存维度信息
        self.num_classes = kwargs['num_classes']  # 保存类别数

    def to_vit(self):  # 定义 to_vit 方法
        v = T2TViT(*self.args, **self.kwargs)  # 创建 T2TViT 对象
        v.load_state_dict(self.state_dict())  # 加载当前状态字典
        return v  # 返回 T2TViT 对象

    def _attend(self, x):  # 定义 _attend 方法
        x = self.dropout(x)  # 使用 dropout
        x = self.transformer(x)  # 经过 transformer 处理
        return x  # 返回处理后的结果

class DistillableEfficientViT(DistillMixin, EfficientViT):  # 定义 DistillableEfficientViT 类,继承自 DistillMixin 和 EfficientViT
    def __init__(self, *args, **kwargs):  # 初始化方法
        super(DistillableEfficientViT, self).__init__(*args, **kwargs)  # 调用父类的初始化方法
        self.args = args  # 保存参数
        self.kwargs = kwargs  # 保存关键字参数
        self.dim = kwargs['dim']  # 保存维度信息
        self.num_classes = kwargs['num_classes']  # 保存类别数

    def to_vit(self):  # 定义 to_vit 方法
        v = EfficientViT(*self.args, **self.kwargs)  # 创建 EfficientViT 对象
        v.load_state_dict(self.state_dict())  # 加载当前状态字典
        return v  # 返回 EfficientViT 对象

    def _attend(self, x):  # 定义 _attend 方法
        return self.transformer(x)  # 经过 transformer 处理

# knowledge distillation wrapper

class DistillWrapper(nn.Module):  # 定义 DistillWrapper 类,继承自 nn.Module
    def __init__(  # 初始化方法
        self,
        *,
        teacher,  # 教师模型
        student,  # 学生模型
        temperature = 1.,  # 温度参数
        alpha = 0.5,  # alpha 参数
        hard = False  # 是否硬蒸馏
    ):
        super().__init__()  # 调用父类的初始化方法
        assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'  # 断言学生模型必须是视觉 transformer

        self.teacher = teacher  # 保存教师模型
        self.student = student  # 保存学生模型

        dim = student.dim  # 获取学生模型的维度信息
        num_classes = student.num_classes  # 获取学生模型的类别数
        self.temperature = temperature  # 保存温度参数
        self.alpha = alpha  # 保存 alpha 参数
        self.hard = hard  # 保存是否硬蒸馏

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))  # 创建蒸馏 token

        self.distill_mlp = nn.Sequential(  # 创建 MLP 处理蒸馏信息
            nn.LayerNorm(dim),  # LayerNorm 处理
            nn.Linear(dim, num_classes)  # 线性层处理
        )
    # 定义一个前向传播函数,接受输入图像、标签、温度和权重参数
    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        # 获取输入图像的批量大小
        b, *_ = img.shape
        # 如果 alpha 参数存在,则使用传入的值,否则使用类属性中的值
        alpha = alpha if exists(alpha) else self.alpha
        # 如果 temperature 参数存在,则使用传入的值,否则使用类属性中的值
        T = temperature if exists(temperature) else self.temperature

        # 在不计算梯度的情况下,通过教师模型获取教师网络的输出
        with torch.no_grad():
            teacher_logits = self.teacher(img)

        # 通过学生模型获取学生网络的输出和蒸馏 token
        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        # 通过蒸馏 token 获取蒸馏网络的输出
        distill_logits = self.distill_mlp(distill_tokens)

        # 计算学生网络的交叉熵损失
        loss = F.cross_entropy(student_logits, labels)

        # 如果不是硬蒸馏,则计算软蒸馏损失
        if not self.hard:
            distill_loss = F.kl_div(
                F.log_softmax(distill_logits / T, dim = -1),
                F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')
            distill_loss *= T ** 2

        # 如果是硬蒸馏,则计算交叉熵损失
        else:
            teacher_labels = teacher_logits.argmax(dim = -1)
            distill_loss = F.cross_entropy(distill_logits, teacher_labels)

        # 返回加权损失值,结合了学生网络的损失和蒸馏损失
        return loss * (1 - alpha) + distill_loss * alpha
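
# 用法示例(假设性演示,非原文件内容):教师可以是任意分类网络(这里以 torchvision 的 resnet50 为例,
# 实际蒸馏时应加载预训练权重),学生必须是本文件定义的三种 Distillable 视觉 transformer 之一。
from torchvision.models import resnet50

teacher = resnet50()

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,   # 软蒸馏温度
    alpha = 0.5,       # 蒸馏损失权重
    hard = False       # False 为软蒸馏,True 为硬蒸馏
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)
loss.backward()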

.\lucidrains\vit-pytorch\vit_pytorch\efficient.py

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 定义一个函数,用于确保输入是一个元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 定义一个名为 ViT 的类,继承自 nn.Module
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
        super().__init__()
        image_size_h, image_size_w = pair(image_size)
        # 检查图像尺寸是否能被 patch 大小整除
        assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
        # 检查池化类型是否为 'cls' 或 'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
        patch_dim = channels * patch_size ** 2

        # 定义将图像切片转换为嵌入向量的序列
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 初始化位置编码和类别标记
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        # 定义 MLP 头部
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 重复类别标记以匹配批次大小
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.transformer(x)

        # 根据池化类型选择不同的池化方式
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
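
# 用法示例(假设性演示,非原文件内容,假设已另外安装 linformer 包):
# 这个 ViT 不自带 transformer,而是把任意序列模型作为 transformer 参数传入,适合处理高分辨率图像。
from linformer import Linformer

efficient_transformer = Linformer(
    dim = 512,
    seq_len = 64 * 64 + 1,   # 2048 / 32 = 64,共 64 * 64 个补丁,外加一个 CLS 标记
    depth = 12,
    heads = 8,
    k = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048)
preds = v(img)  # 输出形状为 (1, 1000)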

.\lucidrains\vit-pytorch\vit_pytorch\es_vit.py

# 导入所需的库
import copy
import random
from functools import wraps, partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange, reduce, repeat

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, default):
    return val if exists(val) else default

# 单例装饰器,用于缓存结果
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance
            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

# 获取模块所在设备
def get_module_device(module):
    return next(module.parameters()).device

# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

# 张量相关的辅助函数

# 对张量取对数
def log(t, eps = 1e-20):
    return torch.log(t + eps)

# 损失函数

# 视图损失函数
def view_loss_fn(
    teacher_logits,
    student_logits,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    teacher_logits = teacher_logits.detach()
    student_probs = (student_logits / student_temp).softmax(dim = -1)
    teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
    return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()

# 区域损失函数
def region_loss_fn(
    teacher_logits,
    student_logits,
    teacher_latent,
    student_latent,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    teacher_logits = teacher_logits.detach()
    student_probs = (student_logits / student_temp).softmax(dim = -1)
    teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)

    sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
    sim_indices = sim_matrix.max(dim = -1).indices
    sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1])
    max_sim_teacher_probs = teacher_probs.gather(1, sim_indices)

    return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
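
# 补充说明:区域级损失先用学生与教师的区域特征做内积得到相似度矩阵,
# 为每个学生区域选出最相似的教师区域,再对这一对区域的概率分布计算交叉熵。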

# 数据增强工具

# 随机应用函数
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# 指数移动平均

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

# MLP 类用于投影器和预测器

# L2范数
class L2Norm(nn.Module):
    def forward(self, x, eps = 1e-6):
        return F.normalize(x, dim = 1, eps = eps)

# 多层感知机
class MLP(nn.Module):
    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        layers = []
        dims = (dim, *((hidden_size,) * (num_layers - 1)))

        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            is_last = ind == (len(dims) - 1)

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            nn.Linear(hidden_size, dim_out)
        )

    def forward(self, x):
        return self.net(x)

# 用于基础神经网络的包装类
# 将管理隐藏层输出的拦截
# 创建一个包装器类,用于将输入传递到投影器和预测器网络中
class NetWrapper(nn.Module):
    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.view_projector = None
        self.region_projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        self.hidden = {}
        self.hook_registered = False

    # 查找指定的层
    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    # 钩子函数,用于获取隐藏层输出
    def _hook(self, _, input, output):
        device = input[0].device
        self.hidden[device] = output

    # 注册钩子函数
    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    # 获取视图投影器
    @singleton('view_projector')
    def _get_view_projector(self, hidden):
        dim = hidden.shape[1]
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    # 获取区域投影器
    @singleton('region_projector')
    def _get_region_projector(self, hidden):
        dim = hidden.shape[1]
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    # 获取嵌入向量
    def get_embedding(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    # 前向传播函数
    def forward(self, x, return_projection = True):
        region_latents = self.get_embedding(x)
        global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')

        if not return_projection:
            return global_latent, region_latents

        view_projector = self._get_view_projector(global_latent)
        region_projector = self._get_region_projector(region_latents)

        region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')

        return view_projector(global_latent), region_projector(region_latents), region_latents

# 主类
class EsViTTrainer(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_hidden_size = 256,
        num_classes_K = 65336,
        projection_layers = 4,
        student_temp = 0.9,
        teacher_temp = 0.04,
        local_upper_crop_scale = 0.4,
        global_lower_crop_scale = 0.5,
        moving_average_decay = 0.9,
        center_moving_average_decay = 0.9,
        augment_fn = None,
        augment_fn2 = None
    ):
        # 调用父类的初始化函数
        super().__init__()
        self.net = net

        # 默认的 BYOL 数据增强

        DEFAULT_AUG = torch.nn.Sequential(
            # 随机应用颜色抖动
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            # 随机转换为灰度图像
            T.RandomGrayscale(p=0.2),
            # 随机水平翻转
            T.RandomHorizontalFlip(),
            # 随机应用高斯模糊
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            # 归一化
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        # 初始化两种数据增强方式
        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, DEFAULT_AUG)

        # 定义局部和全局裁剪
        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))

        # 初始化学生编码器
        self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)

        # 初始化教师编码器和指数移动平均更新器
        self.teacher_encoder = None
        self.teacher_ema_updater = EMA(moving_average_decay)

        # 注册缓冲区,用于存储教师视图中心和区域中心
        self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_view_centers',  torch.zeros(1, num_classes_K))

        self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_region_centers',  torch.zeros(1, num_classes_K))

        # 初始化教师中心指数移动平均更新器
        self.teacher_centering_ema_updater = EMA(center_moving_average_decay)

        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        # 获取网络设备并将包装器设备设置为相同
        device = get_module_device(net)
        self.to(device)

        # 发送一个模拟图像张量以实例化单例参数
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    # 使用装饰器创建单例模式,获取教师编码器
    @singleton('teacher_encoder')
    def _get_teacher_encoder(self):
        teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(teacher_encoder, False)
        return teacher_encoder

    # 重置移动平均值
    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    # 更新移动平均值
    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

        new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
        self.teacher_view_centers.copy_(new_teacher_view_centers)

        new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
        self.teacher_region_centers.copy_(new_teacher_region_centers)

    # 前向传播函数
    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True,
        student_temp = None,
        teacher_temp = None
        ):
        # 如果需要返回嵌入向量,则调用学生编码器并返回结果
        if return_embedding:
            return self.student_encoder(x, return_projection = return_projection)

        # 对输入数据进行两种不同的数据增强
        image_one, image_two = self.augment1(x), self.augment2(x)

        # 对增强后的数据进行局部裁剪和全局裁剪
        local_image_one, local_image_two   = self.local_crop(image_one),  self.local_crop(image_two)
        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)

        # 使用学生编码器对局部裁剪后的数据进行编码
        student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
        student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)

        # 使用torch.no_grad()上下文管理器,获取教师编码器的结果
        with torch.no_grad():
            teacher_encoder = self._get_teacher_encoder()
            teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
            teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)

        # 部分函数调用,设置视图级别损失函数和区域级别损失函数的参数
        view_loss_fn_ = partial(
            view_loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_view_centers
        )

        region_loss_fn_ = partial(
            region_loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_region_centers
        )

        # 计算视图级别损失
        teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
        self.last_teacher_view_centers.copy_(teacher_view_logits_avg)

        teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
        self.last_teacher_region_centers.copy_(teacher_region_logits_avg)

        view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
                   + view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2

        # 计算区域级别损失
        region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
                     + region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2

        # 返回视图级别损失和区域级别损失的平均值
        return (view_loss + region_loss) / 2
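
# 用法示例(假设性演示,非原文件内容):骨干网络需要输出 'b c h w' 形状的特征图以便计算区域级损失,
# 这里以本仓库的 CvT 为例,取其 layers 模块的输出作为特征。
from vit_pytorch.cvt import CvT

cvt = CvT(num_classes = 1000)

learner = EsViTTrainer(
    cvt,
    image_size = 256,
    hidden_layer = 'layers',      # CvT 的 layers 输出形状为 (b, 384, h, w)
    projection_hidden_size = 256,
    projection_layers = 4,
    num_classes_K = 65336
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)
loss = learner(images)            # 视图级损失与区域级损失的平均
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()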

.\lucidrains\vit-pytorch\vit_pytorch\extractor.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 检查值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t):
    return t

# 克隆并分离张量
def clone_and_detach(t):
    return t.clone().detach()

# 应用函数到元组或单个值
def apply_tuple_or_single(fn, val):
    if isinstance(val, tuple):
        return tuple(map(fn, val))
    return fn(val)

# 定义 Extractor 类,继承自 nn.Module
class Extractor(nn.Module):
    def __init__(
        self,
        vit,
        device = None,
        layer = None,
        layer_name = 'transformer',
        layer_save_input = False,
        return_embeddings_only = False,
        detach = True
    ):
        super().__init__()
        # 初始化属性
        self.vit = vit

        self.data = None
        self.latents = None
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

        self.layer = layer
        self.layer_name = layer_name
        self.layer_save_input = layer_save_input # 是否保存层的输入或输出
        self.return_embeddings_only = return_embeddings_only

        # 根据 detach 参数选择克隆并分离函数或返回输入值函数
        self.detach_fn = clone_and_detach if detach else identity

    # 钩子函数,用于提取特征
    def _hook(self, _, inputs, output):
        layer_output = inputs if self.layer_save_input else output
        self.latents = apply_tuple_or_single(self.detach_fn, layer_output)

    # 注册钩子函数
    def _register_hook(self):
        if not exists(self.layer):
            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
            layer = getattr(self.vit, self.layer_name)
        else:
            layer = self.layer

        handle = layer.register_forward_hook(self._hook)
        self.hooks.append(handle)
        self.hook_registered = True

    # 弹出钩子函数
    def eject(self):
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    # 清除特征
    def clear(self):
        del self.latents
        self.latents = None

    # 前向传播函数
    def forward(
        self,
        img,
        return_embeddings_only = False
    ):
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        target_device = self.device if exists(self.device) else img.device
        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

        if return_embeddings_only or self.return_embeddings_only:
            return latents

        return pred, latents
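
使用示意(演示性质,非 extractor.py 文件本身的内容;ViT 的构造参数沿用前文 test.py 中的设置):

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# 包装后,前向会同时返回分类 logits 和被钩住的 transformer 层输出
v = Extractor(v)

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)   # logits: (1, 1000), embeddings: (1, 65, 1024),其中 65 = 64 个 patch + 1 个 CLS

# 不再需要提取特征时,可调用 eject 移除钩子并取回原始 ViT
vit = v.eject()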

.\lucidrains\vit-pytorch\vit_pytorch\learnable_memory_vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 将输入转换为元组的函数
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 控制层是否冻结的函数

# 设置模块参数是否需要梯度的函数
def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

# 冻结所有层的函数
def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

# 解冻所有层的函数
def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

# 类

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, attn_mask = None, memories = None):
        x = self.norm(x)

        x_kv = x # input for key / values projection

        if exists(memories):
            # add memories to key / values if it is passed in
            memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
            x_kv = torch.cat((x_kv, memories), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if exists(attn_mask):
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
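
为直观起见,下面用一个小例子说明 memories 是如何进入 key / value 序列的(此段仅为示意,并非 learnable_memory_vit.py 的内容,假设已按上文定义好 Attention 类):

import torch

attn = Attention(dim = 256, heads = 8, dim_head = 32)

x = torch.randn(2, 65, 256)        # (batch, token 数, dim)
memories = torch.randn(10, 256)    # 10 个可学习的记忆标记,形状 (n, d),会被 repeat 到 batch 维

# query 仍然只来自 x,而 key / value 的序列长度变为 65 + 10 = 75
out = attn(x, memories = memories)
assert out.shape == (2, 65, 256)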

# Transformer 类
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x, attn_mask = None, memories = None):
        for ind, (attn, ff) in enumerate(self.layers):
            layer_memories = memories[ind] if exists(memories) else None

            x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
            x = ff(x) + x
        return x

# ViT 类
class ViT(nn.Module):
    # 初始化函数,设置模型参数和结构
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        image_height, image_width = pair(image_size)
        # 获取patch的高度和宽度
        patch_height, patch_width = pair(patch_size)

        # 断言图像的高度和宽度能够被patch的高度和宽度整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # 计算patch的数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 计算每个patch的维度
        patch_dim = channels * patch_height * patch_width
        # 断言池化类型只能是'cls'或'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # 定义将图像转换为patch嵌入的层
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 初始化位置嵌入参数
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 初始化类别标记参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化dropout层
        self.dropout = nn.Dropout(emb_dropout)

        # 初始化transformer模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # 定义MLP头部
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 将图像转换为tokens
    def img_to_tokens(self, img):
        # 将图像转换为patch嵌入
        x = self.to_patch_embedding(img)

        # 重复类别标记,拼接到patch嵌入中
        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
        x = torch.cat((cls_tokens, x), dim = 1)

        # 添加位置嵌入并进行dropout
        x += self.pos_embedding
        x = self.dropout(x)
        return x

    # 前向传播函数
    def forward(self, img):
        # 将图像转换为tokens
        x = self.img_to_tokens(img)        

        # 使用transformer模型处理tokens
        x = self.transformer(x)

        # 获取类别标记的输出
        cls_tokens = x[:, 0]
        return self.mlp_head(cls_tokens)
# 适配器模块,具有每层可学习的记忆、记忆 CLS 标记和可学习的适配器头部

class Adapter(nn.Module):
    def __init__(
        self,
        *,
        vit,
        num_memories_per_layer = 10,
        num_classes = 2,   
    ):
        super().__init__()
        assert isinstance(vit, ViT)

        # 提取一些需要的模型变量

        dim = vit.cls_token.shape[-1]
        layers = len(vit.transformer.layers)
        num_patches = vit.pos_embedding.shape[-2]

        self.vit = vit

        # 冻结 ViT 主干 - 只有记忆会被微调

        freeze_all_layers_(vit)

        # 可学习的参数

        self.memory_cls_token = nn.Parameter(torch.randn(dim))
        self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # 专门的注意力掩码以保留原始 ViT 的输出
        # 它允许记忆 CLS 标记关注所有其他标记(以及每层可学习的记忆标记),但反过来则不行,从而保留原始 ViT 的输出

        attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
        attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False)  # 主要标记不能关注每层的可学习记忆
        attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True)                  # 记忆 CLS 标记可以关注所有内容
        self.register_buffer('attn_mask', attn_mask)

    def forward(self, img):
        b = img.shape[0]

        tokens = self.vit.img_to_tokens(img)

        # 添加任务特定的记忆标记

        memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
        tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)        

        # 通过变压器传递记忆以及图像标记进行关注

        out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)

        # 提取记忆 CLS 标记

        memory_cls_tokens = out[:, 0]

        # 通过任务特定的适配器头部传递

        return self.mlp_head(memory_cls_tokens)
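
使用示意(参数为演示性假设):先准备一个(通常已预训练的)ViT,再为下游任务挂载 Adapter;此时 ViT 主干被冻结,只有记忆标记、记忆 CLS 标记与适配器头部参与训练:

import torch
from vit_pytorch.learnable_memory_vit import ViT, Adapter

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

adapter = Adapter(
    vit = v,
    num_classes = 2,               # 新任务的类别数
    num_memories_per_layer = 5     # 每层 5 个可学习的记忆标记
)

img = torch.randn(1, 3, 256, 256)
logits = adapter(img)   # (1, 2)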

.\lucidrains\vit-pytorch\vit_pytorch\levit.py

# 从 math 模块中导入 ceil 函数
from math import ceil

# 导入 torch 模块及相关子模块
import torch
from torch import nn, einsum
import torch.nn.functional as F

# 导入 einops 模块中的 rearrange 和 repeat 函数,以及 torch 子模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 将输入值转换为元组的函数
def cast_tuple(val, l = 3):
    val = val if isinstance(val, tuple) else (val,)
    return (*val, *((val[-1],) * max(l - len(val), 0)))

# 返回固定值的函数
def always(val):
    return lambda *args, **kwargs: val

# 类

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制类
class Attention(nn.Module):
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key *  heads
        inner_dim_value = dim_value *  heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # 位置偏置

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + (bias / self.scale)

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads

        q = self.to_q(x)
        y = q.shape[2]

        qkv = (q, self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    # 初始化函数,设置模型参数和结构
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        # 调用父类的初始化函数
        super().__init__()
        # 如果未指定输出维度,则默认为输入维度
        dim_out = default(dim_out, dim)
        # 初始化一个空的模块列表用于存储每个层
        self.layers = nn.ModuleList([])
        # 判断是否使用注意力机制的残差连接
        self.attn_residual = (not downsample) and dim == dim_out

        # 根据深度循环创建每个层
        for _ in range(depth):
            # 每个层包含一个注意力机制和一个前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))
    
    # 前向传播函数,处理输入数据
    def forward(self, x):
        # 遍历每个层
        for attn, ff in self.layers:
            # 如果使用注意力机制的残差连接,则保存输入数据
            attn_res = (x if self.attn_residual else 0)
            # 经过注意力机制处理后,加上残差连接
            x = attn(x) + attn_res
            # 经过前馈神经网络处理后,加上残差连接
            x = ff(x) + x
        # 返回处理后的数据
        return x
# 定义 LeViT 类,继承自 nn.Module
class LeViT(nn.Module):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        image_size,  # 图像大小
        num_classes,  # 类别数量
        dim,  # 维度
        depth,  # 深度
        heads,  # 头数
        mlp_mult,  # MLP 倍数
        stages = 3,  # 阶段数,默认为 3
        dim_key = 32,  # 键维度,默认为 32
        dim_value = 64,  # 值维度,默认为 64
        dropout = 0.,  # Dropout,默认为 0
        num_distill_classes = None  # 蒸馏类别数量,默认为 None
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 将 dim、depth、heads 转换为元组
        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        # 断言确保 dims、depths、layer_heads 的长度等于指定的阶段数 stages
        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'

        # 定义卷积嵌入层
        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        # 计算特征图大小
        fmap_size = image_size // (2 ** 4)
        layers = []

        # 遍历阶段,构建 Transformer 层
        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        # 构建骨干网络
        self.backbone = nn.Sequential(*layers)

        # 定义池化层
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        # 定义蒸馏头部
        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        # 定义 MLP 头部
        self.mlp_head = nn.Linear(dim, num_classes)

    # 前向传播函数
    def forward(self, img):
        # 图像经过卷积嵌入层
        x = self.conv_embedding(img)

        # 特征图经过骨干网络
        x = self.backbone(x)        

        # 特征图经过池化层
        x = self.pool(x)

        # 输出结果经过 MLP 头部
        out = self.mlp_head(x)
        # 蒸馏结果经过蒸馏头部
        distill = self.distill_head(x)

        # 如果存在蒸馏结果,则返回输出结果和蒸馏结果
        if exists(distill):
            return out, distill

        # 否则只返回输出结果
        return out
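
使用示意(演示性参数):这里为简单起见让三个阶段使用相同的维度与头数;dim、depth、heads 也可以按阶段分别传入元组,由 cast_tuple 统一展开:

import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    dim = 256,        # 三个阶段均为 256
    depth = 2,        # 每个阶段 2 层 transformer
    heads = 4,
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = levit(img)   # (1, 1000)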

.\lucidrains\vit-pytorch\vit_pytorch\local_vit.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn, einsum
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange, repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange

# classes

# 定义 Residual 类,继承自 nn.Module
class Residual(nn.Module):
    # 初始化函数
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    # 前向传播函数
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 定义 ExcludeCLS 类,继承自 nn.Module
class ExcludeCLS(nn.Module):
    # 初始化函数
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    # 前向传播函数
    def forward(self, x, **kwargs):
        cls_token, x = x[:, :1], x[:, 1:]
        x = self.fn(x, **kwargs)
        return torch.cat((cls_token, x), dim = 1)

# feed forward related classes

# 定义 DepthWiseConv2d 类,继承自 nn.Module
class DepthWiseConv2d(nn.Module):
    # 初始化函数
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义 FeedForward 类,继承自 nn.Module
class FeedForward(nn.Module):
    # 初始化函数
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Conv2d(dim, hidden_dim, 1),
            nn.Hardswish(),
            DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout)
        )
    # 前向传播函数
    def forward(self, x):
        h = w = int(sqrt(x.shape[-2]))
        x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
        x = self.net(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x

# attention

# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
    # 初始化函数
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    # 前向传播函数
    def forward(self, x):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义 Transformer 类,继承自 nn.Module
class Transformer(nn.Module):
    # 初始化函数
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    # 前向传播函数
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)
        return x

# main class

# 定义 LocalViT 类,继承自 nn.Module
class LocalViT(nn.Module):
    # 初始化函数,设置模型参数和层结构
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # 调用父类的初始化函数
        super().__init__()
        # 检查图像尺寸是否能被分块尺寸整除
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算图像分块数量
        num_patches = (image_size // patch_size) ** 2
        # 计算每个分块的维度
        patch_dim = channels * patch_size ** 2

        # 定义将图像分块转换为嵌入向量的层序列
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # 初始化位置编码参数
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 初始化类别标记参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化丢弃层
        self.dropout = nn.Dropout(emb_dropout)

        # 初始化 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # 定义 MLP 头部层序列
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数
    def forward(self, img):
        # 将图像转换为分块嵌入向量
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 重复类别标记以匹配批次大小
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 将类别标记和分块嵌入向量拼接在一起
        x = torch.cat((cls_tokens, x), dim=1)
        # 添加位置编码
        x += self.pos_embedding[:, :(n + 1)]
        # 对结果进行丢弃
        x = self.dropout(x)

        # 使用 Transformer 进行特征变换
        x = self.transformer(x)

        # 返回 MLP 头部的输出结果
        return self.mlp_head(x[:, 0])
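
使用示意(演示性参数),接口与标准 ViT 基本一致:

import torch
from vit_pytorch.local_vit import LocalViT

vit = LocalViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = vit(img)   # (1, 1000)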

.\lucidrains\vit-pytorch\vit_pytorch\mae.py

# 导入 PyTorch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 einops 库中导入 repeat 函数
from einops import repeat

# 从 vit_pytorch.vit 模块中导入 Transformer 类
from vit_pytorch.vit import Transformer

# 定义一个名为 MAE 的 nn.Module 类
class MAE(nn.Module):
    # 初始化函数,接收一系列参数
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        # 断言确保 masking_ratio 在 0 和 1 之间
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        # 将 masking_ratio 存储在对象中
        self.masking_ratio = masking_ratio

        # 从编码器中提取一些超参数和函数(待训练的视觉变换器)

        # 存储编码器对象
        self.encoder = encoder
        # 获取补丁数量和编码器维度
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # 获取从图像到补丁的转换函数
        self.to_patch = encoder.to_patch_embedding[0]
        # 获取从补丁到嵌入的序列
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        # 获取每个补丁的像素值
        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # 解码器参数
        # 存储解码器维度
        self.decoder_dim = decoder_dim
        # 如果编码器维度与解码器维度不同,则使用 nn.Linear 进行映射,否则使用 nn.Identity
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        # 初始化一个可学习的遮罩令牌
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        # 创建一个 Transformer 解码器
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        # 创建一个嵌入层用于解码器位置编码
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        # 创建一个线性层用于将解码器输出映射回像素值
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
    # 定义一个前向传播函数,接收输入图像
    def forward(self, img):
        # 获取输入图像所在设备
        device = img.device

        # 获取图像的补丁

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # 将补丁转换为编码器标记并添加位置信息

        tokens = self.patch_to_emb(patches)
        if self.encoder.pool == "cls":
            tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
        elif self.encoder.pool == "mean":
            tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype) 

        # 计算需要屏蔽的补丁数量,并获取随机索引,将其分为屏蔽和未屏蔽的部分

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # 获取要编码的未屏蔽标记

        batch_range = torch.arange(batch, device=device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # 获取用于最终重建损失的要屏蔽的补丁

        masked_patches = patches[batch_range, masked_indices]

        # 使用视觉变换器进行注意力

        encoded_tokens = self.encoder.transformer(tokens)

        # 投影编码器到解码器维度,如果它们不相等 - 论文中说可以使用较小的维度进行解码器

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # 重新应用解码器位置嵌入到未屏蔽标记

        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        # 重复屏蔽标记以匹配屏蔽数量,并使用上面得到的屏蔽索引添加位置

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # 将屏蔽标记连接到解码器标记并使用解码器进行注意力

        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens
        decoded_tokens = self.decoder(decoder_tokens)

        # 剪切出屏蔽标记并投影到像素值

        mask_tokens = decoded_tokens[batch_range, masked_indices]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # 计算重建损失

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss
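
使用示意(演示性参数):MAE 以一个现成的 ViT 作为编码器,forward 直接返回重建损失,可反复调用并反向传播;自监督预训练结束后,保留编码器用于下游任务:

import torch
from vit_pytorch import ViT
from vit_pytorch.mae import MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # 论文中推荐遮挡 75% 的 patch
    decoder_dim = 512,      # 解码器维度可以小于编码器
    decoder_depth = 6
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()
# 训练完成后,保存编码器权重即可用于下游任务,例如 torch.save(v.state_dict(), './trained-vit.pt')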

.\lucidrains\vit-pytorch\vit_pytorch\max_vit.py

# 导入必要的库
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# 辅助函数

# 检查变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将变量转换为元组,如果不是元组则重复多次
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 辅助类

# 残差连接
class Residual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# MBConv

# Squeeze-and-Excitation 模块
class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

# MBConv 残差块
class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

# 随机丢弃采样
class Dropsample(nn.Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob
  
    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)

# MBConv 构建函数
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# 注意力相关类

# 注意力机制
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言维度应该能够被每个头的维度整除
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # 计算头的数量
        self.heads = dim // dim_head
        # 缩放因子
        self.scale = dim_head ** -0.5

        # LayerNorm 层
        self.norm = nn.LayerNorm(dim)
        # 线性变换,将输入维度转换为查询、键、值的维度
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # 注意力机制
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),  # Softmax 激活函数
            nn.Dropout(dropout)  # Dropout 层
        )

        # 输出层
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),  # 线性变换
            nn.Dropout(dropout)  # Dropout 层
        )

        # 相对位置偏置

        # Embedding 层,用于存储相对位置偏置
        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        # 计算相对位置偏置
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # 注册缓冲区,存储相对位置索引
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        # 获取输入张量的形状信息
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # LayerNorm 层
        x = self.norm(x)

        # 展开张量
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # 为查询、键、值进行投影
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # 分割头
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 缩放
        q = q * self.scale

        # 计算相似度
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 添加相对位置偏置
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # 注意力机制
        attn = self.attend(sim)

        # 聚合
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # 合并头
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        # 合并头的输出
        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
# 定义一个名为 MaxViT 的神经网络模型类,继承自 nn.Module
class MaxViT(nn.Module):
    # 初始化函数,接受一系列参数
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 断言 depth 是一个元组,用于指定每个阶段的 transformer 块数量
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        # 卷积 stem

        # 如果未指定 dim_conv_stem,则设为 dim
        dim_conv_stem = default(dim_conv_stem, dim)

        # 定义卷积 stem
        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # 变量

        # 计算阶段数量
        num_stages = len(depth)

        # 计算每个阶段的维度
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        # 初始化 layers 为一个空的 nn.ModuleList
        self.layers = nn.ModuleList([])

        # 为了高效的块状 - 网格状注意力,设置窗口大小的简写
        w = window_size

        # 遍历每个阶段
        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                # 定义一个块
                block = nn.Sequential(
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # 块状注意力
                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # 网格状注意力
                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                # 将块添加到 layers 中
                self.layers.append(block)

        # MLP 头部

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    # 前向传播函数
    def forward(self, x):
        # 经过卷积 stem
        x = self.conv_stem(x)

        # 遍历每个阶段的块
        for stage in self.layers:
            x = stage(x)

        # 经过 MLP 头部
        return self.mlp_head(x)
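
使用示意(参数取自常见配置,属演示性质);注意输入尺寸需保证每个阶段下采样后仍能被 window_size 整除,例如 224 经过卷积 stem 和第一阶段下采样后为 56,可被 7 整除:

import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,            # 卷积 stem 的维度
    dim = 96,                      # 第一阶段的维度,之后每阶段翻倍
    dim_head = 32,
    depth = (2, 2, 5, 2),          # 每个阶段的 block 数
    window_size = 7,               # 块状 / 网格状注意力的窗口大小
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)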

.\lucidrains\vit-pytorch\vit_pytorch\max_vit_with_registers.py

# 导入必要的库
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# 辅助函数

# 检查变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将单个元素打包成指定模式的数据
def pack_one(x, pattern):
    return pack([x], pattern)

# 将数据解包成单个元素
def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

# 将变量转换为元组,如果不是元组则重复多次
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 辅助类

# 定义前馈神经网络结构
def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(dim * mult)
    return Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim),
        nn.Dropout(dropout)
    )

# MBConv

# 定义Squeeze-and-Excitation模块
class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

# 定义MBConv残差模块
class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

# 定义DropSample模块
class Dropsample(Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob
  
    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)

# 定义MBConv模块
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# 注意力相关类

# 定义注意力机制模块
class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_registers = 1
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言寄存器数量大于0
        assert num_registers > 0
        # 断言维度应该可以被每个头的维度整除
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # 计算头的数量
        self.heads = dim // dim_head
        # 缩放因子
        self.scale = dim_head ** -0.5

        # LayerNorm层
        self.norm = nn.LayerNorm(dim)
        # 线性变换层,将输入维度转换为3倍的维度,用于计算Q、K、V
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # 注意力机制
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),  # Softmax激活函数
            nn.Dropout(dropout)  # Dropout层
        )

        # 输出层
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),  # 线性变换层
            nn.Dropout(dropout)  # Dropout层
        )

        # 相对位置偏差

        # 计算相对位置偏差的数量
        num_rel_pos_bias = (2 * window_size - 1) ** 2

        # Embedding层,用于存储相对位置偏差
        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        # 生成相对位置偏差的索引
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # 对相对位置偏差索引进行填充
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        # 获取设备信息、头的数量、相对位置偏差索引
        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

        # LayerNorm层
        x = self.norm(x)

        # 为查询、键、值进行投影
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # 分割头
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 缩放
        q = q * self.scale

        # 计算相似度
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 添加位置偏差
        bias = self.rel_pos_bias(bias_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # 注意力机制
        attn = self.attend(sim)

        # 聚合
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # 合并头部输出
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class MaxViT(Module):
    # 定义一个名为 MaxViT 的类,继承自 Module 类
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3,
        num_register_tokens = 4
    ):
        # 初始化函数,接受一系列参数
        super().__init__()
        # 调用父类的初始化函数

        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
        assert num_register_tokens > 0
        # 断言语句,确保 depth 是元组类型,num_register_tokens 大于 0

        # convolutional stem

        dim_conv_stem = default(dim_conv_stem, dim)
        # 如果 dim_conv_stem 为 None,则设置为 dim

        self.conv_stem = Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )
        # 创建一个包含两个卷积层的 Sequential 对象,作为卷积部分的网络结构

        # variables

        num_stages = len(depth)
        # 计算 depth 的长度,作为阶段数

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        # 计算每个阶段的维度

        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        # 将维度组成成对

        self.layers = nn.ModuleList([])
        # 创建一个空的 ModuleList 对象用于存储网络层

        # window size

        self.window_size = window_size
        # 设置窗口大小

        self.register_tokens = nn.ParameterList([])
        # 创建一个空的 ParameterList 对象用于存储注册令牌

        # iterate through stages

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            # 遍历每个阶段
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim
                # 判断是否为当前阶段的第一个块

                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )
                # 创建一个 MBConv 对象

                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                block_ff = FeedForward(dim = layer_dim, dropout = dropout)
                # 创建注意力和前馈网络对象

                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
                # 创建注意力和前馈网络对象

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
                # 创建一个随机初始化的注册令牌

                self.layers.append(ModuleList([
                    conv,
                    ModuleList([block_attn, block_ff]),
                    ModuleList([grid_attn, grid_ff])
                ]))
                # 将卷积层、注意力和前馈网络组成的模块列表添加到网络层中

                self.register_tokens.append(register_tokens)
                # 将注册令牌添加到注册令牌列表中

        # mlp head out

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )
        # 创建一个线性层用于分类
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        # 获取输入张量 x 的批量大小 b 和窗口大小 w
        b, w = x.shape[0], self.window_size

        # 对输入张量 x 进行卷积操作
        x = self.conv_stem(x)

        # 遍历每个层的操作,包括卷积、注意力机制和前馈网络
        for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
            # 对输入张量 x 进行卷积操作
            x = conv(x)

            # block-like attention

            # 对输入张量 x 进行重新排列操作,将其转换为多维矩阵
            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # 准备注册令牌

            # 将注册令牌进行重复操作,以匹配输入张量 x 的形状
            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # 对输入张量 x 进行块状注意力操作,并与原始输入相加
            x = block_attn(x) + x
            # 对输入张量 x 进行块状前馈网络操作,并与原始输入相加
            x = block_ff(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention

            # 对输入张量 x 进行重新排列操作,将其转换为多维矩阵
            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # 准备注册令牌

            # 对注册令牌进行降维操作,计算均值
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # 对输入张量 x 进行网格状注意力操作,并与原始输入相加
            x = grid_attn(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            # 对输入张量 x 进行网格状前馈网络操作,并与原始输入相加
            x = grid_ff(x) + x

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        # 返回经过 MLP 头部处理后的结果
        return self.mlp_head(x)
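
使用示意(演示性参数),与上一个文件的 MaxViT 相比仅多出 num_register_tokens 一个参数:

import torch
from vit_pytorch.max_vit_with_registers import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    num_register_tokens = 4,       # 每个 block 附带 4 个寄存器标记
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)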