Paper: https://arxiv.org/abs/2110.04627

Reference code: https://github.com/thuanz123/enhancing-transformers

Background reading on the method: https://zhuanlan.zhihu.com/p/611689477 , https://zhuanlan.zhihu.com/p/1937234761945949010 , https://juejin.cn/post/7416899377878630427

1. Introduction to ViT-VQGAN

Seen from today's perspective the ViT-VQGAN framework is already fairly general-purpose, but in practice its relatively high computational cost often pushes people toward lighter alternatives. Architecturally, the biggest change in ViT-VQGAN is the innovative use of Transformers for feature extraction in both the encoder and the decoder, together with corresponding modifications to the losses.

The main contributions of ViT-VQGAN are as follows:

1. Stage 1 (image quantization, ViT-VQGAN):

A ViT-based VQGAN encoder, with multiple improvements over VQGAN ranging from the architecture to the codebook-learning scheme, which improve both efficiency and reconstruction fidelity.

The training objective includes a logit-laplace loss, an L2 loss, an adversarial loss, and a perceptual loss.

2. Stage 2 (vector-quantized image modeling, VIM):

An autoregressive Transformer is learned over the image tokens, covering unconditional generation, class-conditioned generation, and unsupervised representation learning.

To evaluate the quality of unsupervised learning, intermediate Transformer features are averaged and a linear head (linear probe) is learned to predict the class logits.
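To make the linear-probe protocol concrete, here is a minimal sketch (my own illustration, not the paper's code; the shapes are only illustrative): freeze the features, average over the token dimension, and train nothing but a linear head.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical shapes: B images, T image tokens, D-dim features taken from an
# intermediate Transformer block of the frozen Stage-2 model.
B, T, D, num_classes = 8, 1024, 768, 1000

features = torch.randn(B, T, D)          # frozen intermediate features (no grad needed)
pooled = features.mean(dim=1)            # average over the token dimension -> [B, D]

linear_head = nn.Linear(D, num_classes)  # the only trainable part in linear probing
logits = linear_head(pooled)             # [B, num_classes]

labels = torch.randint(0, num_classes, (B,))
loss = F.cross_entropy(logits, labels)   # train the head with a standard CE loss
loss.backward()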

The overall ViT-VQGAN framework is shown in the figure below.

image

Let's unpack it:

Stage 1 VQGAN with Vision Transformers

The CNN parts of VQGAN, i.e. the encoder and decoder, are replaced with a ViT. The core architecture of VQ-VAE and VQGAN is a CNN; VQGAN already introduces a Transformer flavour in the form of non-local attention, which lets it capture long-range interactions with fewer layers.

First, a quick refresher on CNNs (convolutional neural networks). Their design rests on a few key observations:

Locality: neighbouring pixels in an image are more strongly related.

Translation invariance: the same object should be recognized as the same class no matter where it appears in the image.

Hierarchical feature extraction: abstraction proceeds layer by layer from low-level features (edges, corners) to high-level ones (faces, cars).

image

At the same time, these "priors" about images (also called inductive biases) constrain the model's capability in some ways:

Modelling long-range dependencies requires deep stacks or specialised designs (dilated convolutions, ASPP), which is relatively inefficient.

The computation pattern is irregular, which works against extreme parallelisation on modern accelerators (GPU/TPU). Convolution involves local-correlation computations built around spatial locality and translation equivariance, which makes the compute flow less regular: for example, gradients of the feature maps depend on the specific arrangement of the convolution kernels, so it is hard to express everything as simple, fully parallel matrix operations.

It is not hard to see that CNN architectures scale less gracefully: unlike Transformers, they do not keep improving simply by stacking more identical blocks.

The figure below summarises the inductive bias behind the main neural-network architectures in machine learning. Fully connected networks carry the weakest bias: they simply assume every unit may interact with every other. Convolutions assume locality and translation invariance of the features; recurrent networks assume sequential dependence and time invariance; graph neural networks assume that node features are aggregated in a consistent way. In short, the network structure itself encodes the designer's assumptions and preferences, and that is what inductive bias means.

image

To address this, several modifications to CNN designs try to reduce the influence of these biases:

1. Mixing convolution with self-attention, e.g. extracting a feature map with a CNN and then computing self-attention on top of that feature map.

2. Pure self-attention (stand-alone self-attention), computing attention over local windows of the image, taken block by block much like a CNN.

But these variants not only add structural complexity to the CNN, they also make it harder to optimise the whole pipeline further.

Transformers have shown striking potential in NLP. The most important property is that performance keeps improving with model size and data, with no obvious saturation; in addition, the computation is highly parallel and well suited to large-scale GPU/TPU training, and the architecture is simple, uniform, and easy to scale and transfer. It is therefore natural to try carrying Transformers over to vision.

To apply them successfully, a few issues with Transformers have to be dealt with:

Problem 1: self-attention has \(O(N^2)\) complexity, so pixel-level detail cannot be modelled directly. We can borrow the tokenisation recipe from NLP: text is not split into single characters but segmented and embedded before being fed in. For images, the natural analogue is to cut the image into patches, so that one patch is one token and the NLP paradigm carries over. This does not truly solve the quadratic-complexity problem, though. While reproducing ViT I noticed that earlier Transformer-based vision work mostly stayed around the 256 resolution scale, which comes down to the patch size: every patch attends to every other patch, so the cost grows quadratically with the token count, and at 256 resolution with small (e.g. 8x8) patches the attention matrix already has \(1024^2\) entries. So the Transformer itself still needs surgery.

(1) Hierarchical structure: Swin Transformer computes attention inside local (shifted) windows to cut the global cost. In short, the patches are regrouped into windows, the attention heads are modified to attend only among the patches inside each window, and shifting the windows restores interaction across window boundaries. This reduces the complexity from \(O(L^2)\) to \(O(L)\). A further trick borrows the CNN idea of hierarchical feature extraction and applies Swin blocks at four resolution scales. (A minimal sketch of window-restricted attention appears after this list.)

(2) Sparse attention: attend only to a local neighbourhood or to a fixed set of positions (e.g. axial attention, which attends along rows and columns separately rather than over the whole grid, so each token interacts with one row and one column instead of all other tokens). See https://arxiv.org/pdf/1904.10509 for concrete variants; not expanded on here.

(3) Linear / efficient attention: mathematical approximations that bring \(O(L^2)\) down to \(O(L)\), such as Linformer (a low-rank projection of the attention matrix) and Performer (random feature maps). (FlashAttention, by contrast, only reorganises how attention is computed on the GPU to improve efficiency; it counts as an engineering optimisation rather than a reduction in complexity.)
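To make the cost argument concrete, the sketch below compares the number of query-key pairs in full global attention with attention restricted to non-overlapping local windows (my own snippet; the token counts and the window size are illustrative, not Swin's exact configuration):

import torch
from einops import rearrange

B, H, W, D = 1, 64, 64, 96          # 64x64 = 4096 tokens of dim 96 (assumed numbers)
win = 8                             # 8x8 windows -> 64 tokens per window

x = torch.randn(B, H * W, D)

# Global attention: one (N x N) score matrix, N = H*W
N = H * W
global_pairs = N * N                                      # 4096^2 ≈ 16.8M query-key pairs

# Window attention: regroup tokens into (H/win)*(W/win) windows of win*win tokens each
xw = rearrange(x, 'b (h w) d -> b h w d', h=H)
xw = rearrange(xw, 'b (nh wh) (nw ww) d -> (b nh nw) (wh ww) d', wh=win, ww=win)
scores = torch.einsum('b i d, b j d -> b i j', xw, xw)    # attention only inside each window
window_pairs = scores.shape[0] * scores.shape[1] * scores.shape[2]   # 64 * 64^2 ≈ 0.26M

print(global_pairs, window_pairs)   # for a fixed window size the cost is linear in N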

Problem 2: the original ViT outputs features at a single, relatively fixed resolution, so it cannot support dense-prediction tasks such as object detection and semantic segmentation, which need multi-scale feature maps. Swin Transformer therefore combines downsampling with multi-scale features, and methods such as PVT and T2T-ViT follow a similar idea of shortening the token sequence.

With the first problem addressed, ViT arrived on the scene. The results look excellent, but one point in the paper's experiments deserves attention: on small datasets, ViT < ResNet:

image

The figure above compares dataset size (left) with the models' intrinsic performance (right). The grey band is the range covered by BiT models (ResNet-based), from ResNet-50 at the bottom to ResNet-152 at the top.
With ImageNet alone, ViT does not match BiT; once the dataset grows to ImageNet-21k, ViT lands inside BiT's range; and with JFT-300M it clearly surpasses BiT. So to train a ViT that is actually useful, the dataset should be at least on the scale of ImageNet-21k; below that, a comparable CNN is the better choice.

Because training also used dropout, weight decay, label smoothing, and similar tricks, the authors wanted to show that the gains come from the Transformer itself rather than from these strategies: they pre-trained on JFT-300M subsets of different sizes, evaluated 5-shot on ImageNet, and fit logistic regression on the extracted features. The results again show that with little data ViT is worse than ResNet, but as the pre-training set grows ViT gradually overtakes it, supporting the claim that Transformer-based models really do deliver. The authors also argue that a well pre-trained ViT is well suited to few-shot tasks.

image

The figure above compares models at matched compute. The left panel averages over five datasets; the right panel uses ImageNet only; all models are pre-trained on JFT-300M. At the same computational budget the Transformers outperform the ResNets, so training ViT really is the economical choice. The hybrid models (the orange points), which feed CNN feature maps into the Transformer, are interesting: at low compute the hybrids are best, but as compute increases they gradually converge to plain ViT.

Which brings us back to the earlier question: what exactly is ViT's inductive bias?

We touched on this when motivating ViT. If ViT shows no clear advantage over convolutions on image tasks, it is most likely because it does not preserve those two priors as well as a CNN does. Look at ViT's architecture:

image

The two regions the arrows point at belong to the same building. A convolution with a suitably sized kernel can cover both at once, but in ViT they end up far apart in the token sequence, and the finer the patches, the further apart they get. Attention can still learn relations between tokens, but ViT clearly preserves spatial locality less well than convolution does. As for translation equivariance, ViT learns explicit patch positions, so when a patch moves, its output changes. In short, the ViT architecture does not maintain the inductive biases of image problems particularly well.

But Transformer models and their experiments have already taught us that scale works wonders: given enough data, the model learns the relations between patches well enough that the missing inductive bias stops being a problem.

Having dug this deep into the architecture, we can at least see that replacing the CNN with a Transformer is well motivated. So let's first walk through ViT's data flow (refer to the ViT architecture figure above):

step 1: a 224x224x3 input image ([B, 224, 224, 3] with a batch dimension) is split into 14x14 patches of size 16x16x3 each, so the embedding size is 768. An FC (linear) projection flattens each patch into a token, giving a sequence of length 14x14 = 196, i.e. [B, 196, 768].

step 2: the patches are ordered, so as in NLP a position embedding is added. As in BERT, an extra learnable embedding is introduced as the classification token: a learnable [CLS] vector c ∈ \(R^{768}\) is prepended to the sequence, giving [B, 197, 768]. A learnable position embedding learnable_pos_emb ∈ \(R^{1×197×768}\) is then added element-wise, x = x + learnable_pos_emb, with \(x ∈ R^{B×197×768}\).

step 3: the sequence enters the first Transformer block. Each block has two sub-modules, multi-head self-attention (MSA) and an MLP (feed-forward network), in a Pre-LN arrangement (LayerNorm + MSA, then LayerNorm + MLP, each with a residual connection), and outputs [B, 197, 768]. Repeating this through the last block, the final classification output is [B, 1000]. That completes the forward pass.

step 4: the backward pass uses a cross-entropy loss, loss = F.cross_entropy(logits, labels), where logits is the classification output of the MLP head.
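The shape bookkeeping in steps 1-4 can be checked with a short, self-contained sketch (generic ViT-style code built from torch.nn blocks, not any particular repo's implementation; only 2 encoder layers here instead of ViT-Base's 12):

import torch
import torch.nn as nn

B, C, H, W, P, D, num_classes = 2, 3, 224, 224, 16, 768, 1000
num_patches = (H // P) * (W // P)                       # 14 * 14 = 196

img = torch.randn(B, C, H, W)

patchify = nn.Conv2d(C, D, kernel_size=P, stride=P)     # patch embedding as a strided conv
x = patchify(img).flatten(2).transpose(1, 2)            # [B, 196, 768]

cls = nn.Parameter(torch.zeros(1, 1, D)).expand(B, -1, -1)
x = torch.cat([cls, x], dim=1)                          # prepend [CLS] -> [B, 197, 768]

pos = nn.Parameter(torch.zeros(1, num_patches + 1, D))  # learnable position embedding
x = x + pos                                             # [B, 197, 768]

layer = nn.TransformerEncoderLayer(d_model=D, nhead=12, dim_feedforward=4 * D,
                                   batch_first=True, norm_first=True)   # Pre-LN block
encoder = nn.TransformerEncoder(layer, num_layers=2)
x = encoder(x)                                          # [B, 197, 768]

head = nn.Linear(D, num_classes)
logits = head(x[:, 0])                                  # classify from the [CLS] token -> [B, 1000]
print(logits.shape)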

Here is another data-flow diagram for reference:

image

Following the VQGAN paradigm, the data flow of ViT-VQGAN's Stage 1 image encoding and quantization is as follows (keeping the 224/16 numbers from the ViT walkthrough above; the paper itself uses 256x256 inputs with 8x8 patches, i.e. 32x32 = 1024 tokens):

step 1: the 224x224x3 input is split into 16x16 patches, giving 14x14 = 196 patches; each patch is flattened and mapped by a linear projection to a 768-dim embedding.

step 2: the 196 patch embeddings are summed with position encodings and fed through the multi-layer Vision Transformer encoder, whose self-attention captures global dependencies; it outputs the same number of feature vectors, [B, 196, 768].

step 3: the Transformer outputs are projected into a space matching the codebook dimension. If the codebook dimension were also 768 the shape stays [B, 196, 768]; otherwise a linear projection (pre_quant in the code below) maps the features down to the code dimension. Vector quantization then uses a codebook of size 8192 (a setting carried over from VQ-VAE): for each feature vector, compute the Euclidean distance to every codebook vector and pick the nearest one, \(z_q(x) = e_k,\; k = \arg\min_j \|z(x) - e_j\|_2\), where the \(e_j\) are the codebook vectors. The result is a grid of discrete integer indices, 196 tokens in total, i.e. [B, 196].

step 4: the quantized vectors \(z_q(x)\) go through the decoder, which reconstructs the image at the original resolution; in ViT-VQGAN the decoder is itself a Transformer whose output tokens are folded back to a 224x224 RGB image by a transposed convolution. The training objective combines a pixel-level reconstruction loss, a perceptual loss, the codebook loss, and an adversarial loss; adversarial training improves the visual quality of the reconstructions. Shapes: [B,196] -> [B,196,768] -> [B,196,768] -> [B,224,224,3].

Note that step 2 uses no [CLS] token: Stage 1 aims at reconstruction, i.e. extracting a spatially structured latent representation, so training a classification token is unnecessary.
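As a concrete illustration of the quantization in step 3, here is a small self-contained sketch of the nearest-neighbour lookup (the codebook size 8192 follows the text; the 32-dim latent is just an example of a code dimension smaller than 768):

import torch

B, T, D, K = 2, 196, 32, 8192                  # 196 tokens, 32-dim codes, 8192 codewords
z = torch.randn(B, T, D)                       # projected encoder features
codebook = torch.randn(K, D)                   # codebook embedding table

z_flat = z.view(-1, D)                         # [B*T, D]
# squared distances via ||z||^2 + ||e||^2 - 2 z.e  (the same identity the repo uses)
d = (z_flat ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z_flat @ codebook.t()

indices = d.argmin(dim=1).view(B, T)           # [B, 196] discrete token ids
z_q = codebook[indices]                        # [B, 196, 32] quantized vectors

# straight-through: forward uses z_q, backward treats the step as identity w.r.t. z
z_q_ste = z + (z_q - z).detach()
print(indices.shape, z_q_ste.shape)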

In particular, let's look at the implementation of VQ, the key component of VQGAN:

BaseQuantizer is a generic base class for vector-quantization modules. Its parent class, nn.Module, is one of the most fundamental classes in PyTorch: the base class of every neural-network layer, model, and learnable component, and the usual place where such building blocks are defined. The class supports: 1. vector quantization: nearest-neighbour lookup of the best-matching codebook vector; 2. normalisation: optional L2 normalisation (for a spherical codebook); 3. straight-through estimation: letting gradients flow through the non-differentiable quantization step; 4. residual quantization: multi-stage quantization for more expressive codes (as in RVQ); 5. extensibility: subclasses implement quantize() (plain VQ, Gumbel VQ, RQ, and so on).

Click to view the BaseQuantizer code
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import math
from functools import partial
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseQuantizer(nn.Module):
    def __init__(self, embed_dim: int, n_embed: int, straight_through: bool = True, use_norm: bool = True,
                 use_residual: bool = False, num_quantizers: Optional[int] = None) -> None:
        # embed_dim: dimension of each embedding vector (the feature size).
        # n_embed: number of vectors in the codebook (the "dictionary" size).
        # straight_through: whether to use the straight-through estimator so gradients pass through quantization.
        # use_norm: whether to L2-normalize inputs and embeddings (common in VQ-VAE / RQ variants).
        # use_residual: whether to use residual quantization (Residual Vector Quantization, RVQ).
        # num_quantizers: number of stacked quantizer stages when residual quantization is used.
        super().__init__()
        self.straight_through = straight_through
        self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x
        # norm: a callable that L2-normalizes a tensor along the last dimension when use_norm is set; this helps stabilise training, especially when using cosine-style similarity.

        self.use_residual = use_residual
        self.num_quantizers = num_quantizers

        self.embed_dim = embed_dim
        self.n_embed = n_embed

        self.embedding = nn.Embedding(self.n_embed, self.embed_dim)
        self.embedding.weight.data.normal_()
        # Initialise the codebook weights from a standard normal N(0, 1); uniform or Xavier initialisation is also seen in practice.
        
    # An abstract method (not enforced via the abc module): subclasses must implement it. Given continuous vectors z, it finds the nearest codebook vectors and returns: z_q (quantized vectors looked up from the codebook), loss (the quantization loss, e.g. commitment loss), and encoding_indices (integer indices of the nearest codebook entries).
    def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        pass
    
    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        if not self.use_residual:
            z_q, loss, encoding_indices = self.quantize(z)
        else:
            # Residual quantization (RVQ): successively approximate the original input.
            z_q = torch.zeros_like(z)
            residual = z.detach().clone()

            losses = []
            encoding_indices = []

            for _ in range(self.num_quantizers):
                z_qi, loss, indices = self.quantize(residual.clone())
                residual.sub_(z_qi)
                z_q.add_(z_qi)

                encoding_indices.append(indices)
                losses.append(loss)

            losses, encoding_indices = map(partial(torch.stack, dim = -1), (losses, encoding_indices)) # stack the per-stage losses and indices along a new last dimension, giving tensors of shape [B, ..., num_quantizers]
            loss = losses.mean()

        # preserve gradients with the straight-through estimator, so the network can train despite the non-differentiable argmin / nearest-neighbour step
        if self.straight_through:
            z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices
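The last line above is the straight-through estimator. A tiny self-contained demonstration of what z + (z_q - z).detach() does (torch.round stands in for the non-differentiable quantizer here):

import torch

z = torch.randn(3, requires_grad=True)        # continuous encoder output
z_q = torch.round(z)                          # stand-in for a non-differentiable quantizer

out = z + (z_q - z).detach()                  # forward value equals z_q ...
assert torch.equal(out, z_q)

out.sum().backward()
print(z.grad)                                 # ... but the gradient flows to z as if out were z (all ones)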

VectorQuantizer mainly defines the concrete quantize() method for plain VQ.
In particular, the distance matrix uses the identity \(\| \mathbf{z}_i - \mathbf{e}_j \|^2 = \| \mathbf{z}_i \|^2 + \| \mathbf{e}_j \|^2 - 2\, \mathbf{z}_i^\top \mathbf{e}_j\). The cross term is computed by torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm): z_reshaped_norm has shape [b, d] (the B flattened, normalised input vectors \(z_i \in R^D\)) and embedding_norm has shape [n, d] (the codebook vectors \(e_j \in R^D\)); the output spec '-> b n' keeps the b and n axes and contracts (sums over) d, which is exactly the pairwise dot product \(z_i^\top e_j\). The same result could be obtained with similarity = torch.matmul(z_reshaped_norm, embedding_norm.t()) (i.e. [B, D] @ [D, N] -> [B, N]) or similarity = z_reshaped_norm @ embedding_norm.t(). The advantage of einsum is that it states explicitly which axes participate and extends cleanly to higher-dimensional cases (such as 'b h q d, b h k d -> b h q k' in attention). The result is a (BHW, n_embed) distance matrix with \(d[i,j] = \mathrm{dist}(z_i, e_j)\), computed fully vectorised, with no explicit loops.
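A quick standalone check of that identity and of the einsum term (my own snippet, not from the repo):

import torch

B, N, D = 5, 7, 16
z = torch.nn.functional.normalize(torch.randn(B, D), dim=-1)   # B flattened input vectors
e = torch.nn.functional.normalize(torch.randn(N, D), dim=-1)   # N codebook vectors

# The einsum term is just the pairwise dot product z_i^T e_j
sim_einsum = torch.einsum('b d, n d -> b n', z, e)
sim_matmul = z @ e.t()
assert torch.allclose(sim_einsum, sim_matmul)

# Expanded squared distance vs. the direct computation
d_expanded = (z ** 2).sum(1, keepdim=True) + (e ** 2).sum(1) - 2 * sim_einsum
d_direct = torch.cdist(z, e) ** 2
assert torch.allclose(d_expanded, d_direct, atol=1e-5)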

Click to view the VectorQuantizer code

# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

class VectorQuantizer(BaseQuantizer):
    def __init__(self, embed_dim: int, n_embed: int, beta: float = 0.25, use_norm: bool = True,
                 use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
        # beta: hyper-parameter balancing the codebook loss and the commitment loss (as in VQ-VAE); **kwargs allows extra parameters for future compatibility.
        # Call the parent constructor with straight_through fixed to True (the straight-through estimator is always used).
        super().__init__(embed_dim, n_embed, True,
                         use_norm, use_residual, num_quantizers)
        
        self.beta = beta

    def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        # Step 1: flatten + normalise. Reshape the multi-dimensional input to 2D, (batch_size * height * width, embed_dim), so distances can be computed in one batch. With cosine/inner-product matching, L2 normalisation makes the comparison depend on direction rather than magnitude, which improves stability.
        z_reshaped_norm = self.norm(z.view(-1, self.embed_dim))
        embedding_norm = self.norm(self.embedding.weight)
        
        # Step 2: compute the distance matrix:
        d = torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding_norm ** 2, dim=1) - 2 * \
            torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm)
 
        # Step 3: find the nearest neighbour (encoding index). .unsqueeze(1) inserts a size-1 dimension, mainly to mark this as a sequence of indices rather than bare scalars; one view is that it also eases batching several codebooks (as in RVQ), even though only one is used here.
        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encoding_indices = encoding_indices.view(*z.shape[:-1])
        
        z_q = self.embedding(encoding_indices).view(z.shape)
        z_qnorm, z_norm = self.norm(z_q), self.norm(z)
        
        # compute loss for embedding: a two-way learning scheme in which the encoder learns to commit to suitable codewords while the codebook learns to represent them better
        loss = self.beta * torch.mean((z_qnorm.detach() - z_norm)**2) +  \
               torch.mean((z_qnorm - z_norm.detach())**2)

        return z_qnorm, loss, encoding_indices

GumbelQuantizer is a differentiable vector quantizer based on the Gumbel-Softmax trick: during training, gradients can flow through the discrete selection. It does not rely on the straight-through estimator; instead, differentiability comes from a soft distribution plus differentiable sampling.

Click to view the GumbelQuantizer code
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------
class GumbelQuantizer(BaseQuantizer):
    def __init__(self, embed_dim: int, n_embed: int, temp_init: float = 1.0,
                 use_norm: bool = True, use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
        super().__init__(embed_dim, n_embed, False,
                         use_norm, use_residual, num_quantizers)
        # Store the initial temperature. Higher temperature -> output closer to a uniform distribution; lower temperature -> closer to one-hot (approaching argmax).
        self.temperature = temp_init
        
    def quantize(self, z: torch.FloatTensor, temp: Optional[float] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        # force hard = True when we are in eval mode, as we must quantize
        # hard: during training hard=False, a soft (differentiable) distribution; during eval/inference hard=True, a true one-hot (fully discrete)
        hard = not self.training
        # A common schedule uses a high temperature early in training (more random) and anneals it down so codeword selection becomes more deterministic.
        temp = self.temperature if temp is None else temp
        
        z_reshaped_norm = self.norm(z.view(-1, self.embed_dim))
        embedding_norm = self.norm(self.embedding.weight)

        # The logits are negative distances; the result is an (N, n_embed) logit matrix whose rows score each codeword for each position.
        logits = - torch.sum(z_reshaped_norm ** 2, dim=1, keepdim=True) - \
                 torch.sum(embedding_norm ** 2, dim=1) + 2 * \
                 torch.einsum('b d, n d -> b n', z_reshaped_norm, embedding_norm)
        logits =  logits.view(*z.shape[:-1], -1)
        
        # Core step: produce an "approximately one-hot" distribution vector, so the forward pass behaves like selecting a codeword while the backward pass lets gradients flow through that selection.
        # With hard=True a single codebook vector is effectively selected; PyTorch's gumbel_softmax still keeps the gradient path (straight-through-style behaviour).
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=-1, hard=hard)
        z_qnorm = torch.matmul(soft_one_hot, embedding_norm) # weighted sum over codebook vectors gives the quantized vector
        
        # KL divergence to a uniform prior; this loss only affects the codebook and the encoder, not the decoder.
        logits =  F.log_softmax(logits, dim=-1) # use log_softmax because it is more numerically stable
        loss = torch.sum(logits.exp() * (logits+math.log(self.n_embed)), dim=-1).mean()
               
        # get encoding via argmax: return the index of the highest-scoring codeword at each position
        encoding_indices = soft_one_hot.argmax(dim=-1)
        
        return z_qnorm, loss, encoding_indices
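A tiny standalone demonstration of the F.gumbel_softmax behaviour the quantizer relies on (the logits and the 3x2 "codebook" are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)

soft = F.gumbel_softmax(logits, tau=1.0, hard=False)   # a noisy probability vector, sums to 1
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)    # exact one-hot in the forward pass
print(soft, hard)

codebook = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # 3 codewords of dim 2 (made up)
z_q = hard @ codebook        # selecting a codeword via the (near) one-hot weights
z_q.sum().backward()
print(logits.grad)           # gradients still reach the logits despite the discrete choice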

This part shows how the forward pass is assembled; the key pieces to understand are the 2D/1D sinusoidal position embeddings, the Transformer implementation, and the concrete encoder-decoder architecture.

Click to view the layers code
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch)
# Copyright (c) 2020 Phil Wang. All Rights Reserved.
# ------------------------------------------------------------------------------------

import math
import numpy as np
from typing import Union, Tuple, List
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """
    grid_size: int or (int, int) of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in newer NumPy; use float64 explicitly
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out) # (M, D/2)
    emb_cos = np.cos(out) # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def init_weights(m):
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        w = m.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))


class PreNorm(nn.Module):
    def __init__(self, dim: int, fn: nn.Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Note: besides Tanh, which keeps outputs in the continuous range [-1, 1], activations such as GELU, ReLU, or Swish (SiLU) could be used here; ReLU can suffer from dead neurons and should be used with care.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5  # scaling factor; keeps the dot products from growing so large that the softmax saturates and gradients vanish

        self.attend = nn.Softmax(dim = -1) # softmax over the last (key) dimension, giving the attention weights as a probability distribution
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity() # nn.Identity() is the identity (pass-through) mapping

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # .chunk(3, dim=-1) splits the last dimension into three chunks of size [B, N, inner_dim], returned as the tuple (q, k, v)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) #q.shape = [B, heads, N, dim_head]

        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale #q @ k^T /sqrt(dim_head)
        attn = self.attend(attn)
 
        out = torch.matmul(attn, v) # each position's output is a weighted average of the values, with weights from the attention distribution
        out = rearrange(out, 'b h n d -> b n (h d)') # concatenate the multi-head outputs back into a single vector

        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([])
        for idx in range(depth):
            layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)),
                                   PreNorm(dim, FeedForward(dim, mlp_dim))])
            self.layers.append(layer)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for attn, ff in self.layers: # unpack the Attention and FeedForward modules from each layer
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


class ViTEncoder(nn.Module):
    def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
        super().__init__()
        image_height, image_width = image_size if isinstance(image_size, tuple) \
                                    else (image_size, image_size)
        patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
                                    else (patch_size, patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))

        self.num_patches = (image_height // patch_height) * (image_width // patch_width)
        self.patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.apply(init_weights)

    def forward(self, img: torch.FloatTensor) -> torch.FloatTensor:
        x = self.to_patch_embedding(img)
        x = x + self.en_pos_embedding
        x = self.transformer(x)

        return x


class ViTDecoder(nn.Module):
    def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
        super().__init__()
        image_height, image_width = image_size if isinstance(image_size, tuple) \
                                    else (image_size, image_size)
        patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
                                    else (patch_size, patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))

        self.num_patches = (image_height // patch_height) * (image_width // patch_width)
        self.patch_dim = channels * patch_height * patch_width

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False)
        self.to_pixel = nn.Sequential(
            Rearrange('b (h w) c -> b c h w', h=image_height // patch_height),
            nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size)
        )

        self.apply(init_weights)

    def forward(self, token: torch.FloatTensor) -> torch.FloatTensor:
        x = token + self.de_pos_embedding
        x = self.transformer(x)
        x = self.to_pixel(x)

        return x

    def get_last_layer(self) -> nn.Parameter:
        return self.to_pixel[-1].weight
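With the quantizers and layers in hand, here is a minimal end-to-end usage sketch of the Stage 1 pipeline (assuming the ViTEncoder, ViTDecoder, and VectorQuantizer classes quoted above are in scope; depth/heads/dim are placeholders far smaller than the real configs, while 256x256 inputs with 8x8 patches give the 32x32 = 1024 tokens that Stage 2 models later):

import torch
import torch.nn as nn

# Assumes ViTEncoder, ViTDecoder and VectorQuantizer from the code above have been defined/imported.
encoder = ViTEncoder(image_size=256, patch_size=8, dim=512, depth=2, heads=8, mlp_dim=2048)
decoder = ViTDecoder(image_size=256, patch_size=8, dim=512, depth=2, heads=8, mlp_dim=2048)
quantizer = VectorQuantizer(embed_dim=32, n_embed=8192)      # low-dim codes, 8192 codebook entries
pre_quant, post_quant = nn.Linear(512, 32), nn.Linear(32, 512)

img = torch.randn(1, 3, 256, 256)
h = encoder(img)                                # [1, 1024, 512]  (32x32 patch tokens)
z_q, vq_loss, codes = quantizer(pre_quant(h))   # [1, 1024, 32], scalar loss, [1, 1024] token ids
rec = decoder(post_quant(z_q))                  # [1, 3, 256, 256]
print(codes.shape, rec.shape)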

The remaining training/deployment code is covered together later; for now we only need to understand the loss.
In the model code below, xrec and qloss are the dec and diff returned by the forward pass, where diff is the emb_loss returned by encode(); they correspond to the reconstructed image and the vector-quantization loss respectively. These two pieces feed the total_loss computation in the loss file:
nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss
where the overall Stage 1 objective from the paper is
\begin{align} L = L_{VQ} + L_{\mathrm{logit-laplace}} + L_{\mathrm{adv}} + L_{\mathrm{perceptual}} + L_{\mathrm{2}} \tag{2} \end{align}

Click to view the stage1 model code
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:    
        quant, diff = self.encode(x)
        dec = self.decode(quant)
        
        return dec, diff

def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        h = self.encoder(x)
        h = self.pre_quant(h)
        quant, emb_loss, _ = self.quantizer(h)
        
        return quant, emb_loss

def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                            last_layer=self.decoder.get_last_layer(), split="train")  #self.loss = initialize_from_config(loss)
            # loss:
            # target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
            # params:
            #    loglaplace_weight: 0.0
            #    loggaussian_weight: 1.0
            #    perceptual_weight: 0.1
            #    adversarial_weight: 0.1

            self.log("train/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_ae["train/total_loss"]
            
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)

            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                                last_layer=self.decoder.get_last_layer(), split="train")
            
            self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_disc["train/disc_loss"]
            
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            
            return discloss

## from vqperceptual_loss file
class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start: int = 0,
                 disc_loss: str = 'vanilla',
                 disc_params: Optional[OmegaConf] = dict(),
                 codebook_weight: float = 1.0,
                 loglaplace_weight: float = 1.0,
                 loggaussian_weight: float = 1.0,
                 perceptual_weight: float = 1.0,
                 adversarial_weight: float = 1.0,
                 use_adaptive_adv: bool = False,
                 r1_gamma: float = 10,
                 do_r1_every: int = 16) -> None:
        
        super().__init__()
        assert disc_loss in ["hinge", "vanilla", "least_square"], f"Unknown GAN loss '{disc_loss}'."
        self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False)

        self.codebook_weight = codebook_weight 
        self.loglaplace_weight = loglaplace_weight 
        self.loggaussian_weight = loggaussian_weight
        self.perceptual_weight = perceptual_weight 

        self.discriminator = StyleDiscriminator(**disc_params)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "least_square":
            self.disc_loss = least_square_d_loss

        self.adversarial_weight = adversarial_weight
        self.use_adaptive_adv = use_adaptive_adv
        self.r1_gamma = r1_gamma
        self.do_r1_every = do_r1_every

    def calculate_adaptive_factor(self, nll_loss: torch.FloatTensor,
                                  g_loss: torch.FloatTensor, last_layer: nn.Module) -> torch.FloatTensor:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        
        adapt_factor = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        adapt_factor = adapt_factor.clamp(0.0, 1e4).detach()

        return adapt_factor

    def forward(self, codebook_loss: torch.FloatTensor, inputs: torch.FloatTensor, reconstructions: torch.FloatTensor, optimizer_idx: int,
                global_step: int, batch_idx: int, last_layer: Optional[nn.Module] = None, split: Optional[str] = "train") -> Tuple:
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()       
        
        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            loglaplace_loss = (reconstructions - inputs).abs().mean()
            loggaussian_loss = (reconstructions - inputs).pow(2).mean()
            perceptual_loss = self.perceptual_loss(inputs*2-1, reconstructions*2-1).mean()

            nll_loss = self.loglaplace_weight * loglaplace_loss + self.loggaussian_weight * loggaussian_loss + self.perceptual_weight * perceptual_loss #total_loss
        
            logits_fake = self.discriminator(reconstructions)
            g_loss = self.disc_loss(logits_fake)
            
            try:
                d_weight = self.adversarial_weight
                
                if self.use_adaptive_adv:
                    d_weight *= self.calculate_adaptive_factor(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = 1 if global_step >= self.discriminator_iter_start else 0
            loss = nll_loss + disc_factor * d_weight * g_loss + self.codebook_weight * codebook_loss

            log = {"{}/total_loss".format(split): loss.clone().detach(),
                   "{}/quant_loss".format(split): codebook_loss.detach(),
                   "{}/rec_loss".format(split): nll_loss.detach(),
                   "{}/loglaplace_loss".format(split): loglaplace_loss.detach(),
                   "{}/loggaussian_loss".format(split): loggaussian_loss.detach(),
                   "{}/perceptual_loss".format(split): perceptual_loss.detach(),
                   "{}/g_loss".format(split): g_loss.detach(),
                   }

            if self.use_adaptive_adv:
                log["{}/d_weight".format(split)] = d_weight.detach()
            
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            disc_factor = 1 if global_step >= self.discriminator_iter_start else 0
            do_r1 = self.training and bool(disc_factor) and batch_idx % self.do_r1_every == 0

            logits_real = self.discriminator(inputs.requires_grad_(do_r1))
            logits_fake = self.discriminator(reconstructions.detach())
            
            d_loss = disc_factor * self.disc_loss(logits_fake, logits_real)
            if do_r1:
                with conv2d_gradfix.no_weight_gradients():
                    gradients, = torch.autograd.grad(outputs=logits_real.sum(), inputs=inputs, create_graph=True)

                gradients_norm = gradients.square().sum([1,2,3]).mean()
                d_loss += self.r1_gamma * self.do_r1_every * gradients_norm/2

            log = {"{}/disc_loss".format(split): d_loss.detach(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean(),
                   }

            if do_r1:
                log["{}/r1_reg".format(split)] = gradients_norm.detach()
            
            return d_loss, log
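The hinge_d_loss / vanilla_d_loss / least_square_d_loss helpers referenced above are not quoted in this post. For orientation, here is a sketch of the standard hinge and vanilla (non-saturating) formulations, written to match the two call sites above (generator pass with fake logits only, discriminator pass with both); the repo's actual implementation may differ in its details:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_fake, logits_real=None):
    if logits_real is None:
        # generator pass (called as self.disc_loss(logits_fake) above): push D(fake) up
        return -logits_fake.mean()
    # discriminator pass: hinge loss on real and fake logits
    loss_real = F.relu(1.0 - logits_real).mean()
    loss_fake = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_fake, logits_real=None):
    if logits_real is None:
        # generator pass: non-saturating GAN loss
        return F.softplus(-logits_fake).mean()
    # discriminator pass: classify real as real, fake as fake
    return 0.5 * (F.softplus(-logits_real).mean() + F.softplus(logits_fake).mean())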

Stage 2 Vector-quantized Image Modeling

A Transformer model is trained to autoregressively predict the rasterised 32x32 = 1024 image tokens, where the tokens are produced by the learned Stage 1 ViT-VQGAN. For unconditional image synthesis and unsupervised learning, a decoder-only Transformer is pre-trained to predict the next token. To evaluate the quality of unsupervised learning, intermediate Transformer features are averaged and a linear head is trained to predict the class logits (the linear probe).

Note that in this code base the Transformer follows minGPT and additionally uses a time-shift-and-mix trick, which speeds up training; see this RWKV write-up for background: https://zhuanlan.zhihu.com/p/636551276 . In short, two priors, the current token and the previous token, are combined by shifting the sequence one step and blending the shifted and unshifted features with a learnable per-channel weight (a minimal sketch follows).
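A minimal standalone illustration of that token-shift mixing (the padding trick and the mixing formula mirror the code below; the tensor contents are made up):

import torch
import torch.nn as nn

B, T, C = 1, 4, 6
x = torch.arange(B * T * C, dtype=torch.float32).view(B, T, C)

# nn.ZeroPad2d((left, right, top, bottom)) on a (B, T, C) tensor pads the T axis:
# every token is shifted down one step, so position t sees the features of position t-1.
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x_prev = time_shift(x)                 # row 0 becomes zeros, row t becomes old row t-1

# Per-channel learnable mixing weight between "current token" and "previous token".
time_mix = torch.linspace(0.0, 1.0, C).view(1, 1, C)
mixed = x * time_mix + x_prev * (1 - time_mix)
print(mixed.shape)                     # [1, 4, 6]: same shape, now carrying a causal local prior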

Click to view the Transformer code
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from minDALL-E (https://github.com/kakaobrain/minDALL-E)
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from minGPT (https://github.com/karpathy/minGPT)
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
# ------------------------------------------------------------------------------------

import math
from omegaconf import OmegaConf
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast


class MultiHeadSelfAttention(nn.Module):
    def __init__(self,
                 ctx_len: int,         # maximum context length
                 cond_len: int,        # length of the conditioning sequence (e.g. a prefix / prompt)
                 embed_dim: int,       # embedding dimension C
                 n_heads: int,         # number of attention heads
                 attn_bias: bool,      # whether to use bias in the QKV and output projections
                 use_mask: bool = True):
        super().__init__()
        assert embed_dim % n_heads == 0

        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)

        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)

        self.n_heads = n_heads
        self.ctx_len = ctx_len
        self.use_mask = use_mask
        # Build a lower-triangular causal mask (the past must not see the future). Special case: the first cond_len positions can all attend to each other (the prompt is fully visible internally), while the image tokens after it can attend to everything before them, including the prompt.
        if self.use_mask:
            self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
            self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
            self.mask[:, :cond_len, :cond_len] = 1

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():
            ww = torch.zeros(1, 1, embed_dim)
            for i in range(embed_dim):
                ww[0, 0, i] = i / (embed_dim - 1)
        self.time_mix = nn.Parameter(ww)

    def forward(self, x, use_cache=False, layer_past=None):
        B, T, C = x.shape

        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)
        x = x.transpose(0, 1).contiguous()  # (B, T, C) -> (T, B, C)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        
        # During generation, cache this step's K/V so history does not have to be recomputed.
        if use_cache:
            present = torch.stack([k, v])

        # Return this layer's KV cache; if layer_past is given, concatenate the previously cached keys/values.
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        # With use_cache enabled and an existing past: the query covers only the latest time step (only one new token is generated), the attention has shape (B*nh, 1, T), and the output corresponds to that newest token only.
        if use_cache and layer_past is not None:
            # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) # torch.bmm: batched matrix multiplication
            att = F.softmax(att, dim=-1)
            y = torch.bmm(att, v)  # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
        else:
            # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
            if self.use_mask:
                mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
                att = att.masked_fill(mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = torch.bmm(att, v)  # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
        y = y.transpose(0, 1).contiguous().view(T, B, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.proj(y)
        
        if use_cache:
            return y.transpose(0, 1).contiguous(), present  # (T, B, C) -> (B, T, C)
        else:
            return y.transpose(0, 1).contiguous()  # (T, B, C) -> (B, T, C)

# ReLU^2 way
class FFN(nn.Module):
    def __init__(self, embed_dim, mlp_bias):
        super().__init__()
        self.p0 = nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias)
        self.p1 = nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias)

    def forward(self, x):
        x = self.p0(x)
        # x = F.gelu(x)
        x = torch.square(torch.relu(x))
        x = self.p1(x)
        return x

class Block(nn.Module):
    def __init__(self,
                 ctx_len: int,
                 cond_len: int,
                 embed_dim: int,
                 n_heads: int,
                 mlp_bias: bool,
                 attn_bias: bool):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
                                           cond_len=cond_len,
                                           embed_dim=embed_dim,
                                           n_heads=n_heads,
                                           attn_bias=attn_bias,
                                           use_mask=True)
        self.mlp = FFN(embed_dim=embed_dim, mlp_bias=mlp_bias)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))

        return x

    def sample(self, x, layer_past=None):
        attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))

        return x, present


class GPT(nn.Module):
    def __init__(self,
                 vocab_cond_size: int,
                 vocab_img_size: int,
                 embed_dim: int,
                 cond_num_tokens: int,
                 img_num_tokens: int,
                 n_heads: int,
                 n_layers: int,
                 mlp_bias: bool = True,
                 attn_bias: bool = True) -> None:
        super().__init__()
        self.img_num_tokens = img_num_tokens 
        self.vocab_cond_size = vocab_cond_size
        
        # condition token and position embedding 
        self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim)
        self.pos_emb_cond = nn.Parameter(torch.zeros(1, cond_num_tokens, embed_dim))
        
        # input token and position embedding
        self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim)
        self.pos_emb_code = nn.Parameter(torch.zeros(1, img_num_tokens, embed_dim))

        # transformer blocks
        self.blocks = [Block(ctx_len=cond_num_tokens + img_num_tokens,
                             cond_len=cond_num_tokens,
                             embed_dim=embed_dim,
                             n_heads=n_heads,
                             mlp_bias=mlp_bias,
                             attn_bias=attn_bias) for i in range(1, n_layers+1)]
        self.blocks = nn.Sequential(*self.blocks)

        # head
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_img_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self,
                codes: torch.LongTensor,
                conds: torch.LongTensor) -> torch.FloatTensor:
        
        codes = codes.view(codes.shape[0], -1)
        codes = self.tok_emb_code(codes)
        conds = self.tok_emb_cond(conds)
        
        codes = codes + self.pos_emb_code
        conds = conds + self.pos_emb_cond

        x = torch.cat([conds, codes], axis=1).contiguous()
        x = self.blocks(x)
        x = self.layer_norm(x)

        x = x[:, conds.shape[1]-1:-1].contiguous()
        logits = self.head(x)
        
        return logits

    def sample(self,
               conds: torch.LongTensor,
               top_k: Optional[float] = None,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        
        past = codes = logits = None
            
        for i in range(self.img_num_tokens):
            if codes is None:
                codes_ = None
                pos_code = None
            else:
                codes_ = codes.clone().detach()
                codes_ = codes_[:, -1:]
                pos_code = self.pos_emb_code[:, i-1:i, :]
                
            logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past)
            
            logits_ = logits_.to(dtype=torch.float32)
            logits_ = logits_ / softmax_temperature

            presents = torch.stack(presents).clone().detach()
            if past is None:
                past = [presents]
            else:
                past.append(presents)

            if top_k is not None:
                v, ix = torch.topk(logits_, top_k)
                logits_[logits_ < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits_, dim=-1)
            
            if top_p is not None:
                sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)

                sorted_idx_remove_cond = cum_probs >= top_p

                sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
                sorted_idx_remove_cond[..., 0] = 0

                indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
                probs = probs.masked_fill(indices_to_remove, 0.0)
                probs = probs / torch.sum(probs, dim=-1, keepdim=True)

            idx = torch.multinomial(probs, num_samples=1).clone().detach()
            codes = idx if codes is None else torch.cat([codes, idx], axis=1)
            logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1)

        del past

        return logits, codes

    def sample_step(self,
                    codes: torch.LongTensor,
                    conds: torch.LongTensor,
                    pos_code: torch.LongTensor,
                    use_fp16: bool = True,
                    past: Optional[torch.FloatTensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        
        with autocast(enabled=use_fp16):
            presents = []
            
            if codes is None:
                assert past is None
                conds = self.tok_emb_cond(conds)
                x = conds + self.pos_emb_cond
                
                for i, block in enumerate(self.blocks):
                    x, present = block.sample(x, layer_past=None)
                    presents.append(present)
                x = self.layer_norm(x)
                x = x[:, conds.shape[1]-1].contiguous()
            else:
                assert past is not None
                codes = self.tok_emb_code(codes)
                x = codes + pos_code
                
                past = torch.cat(past, dim=-2)
                for i, block in enumerate(self.blocks):
                    x, present = block.sample(x, layer_past=past[i])
                    presents.append(present)

                x = self.layer_norm(x)
                x = x[:, -1].contiguous()

            logits = self.head(x)
            
            return logits, presents


class RQTransformer(nn.Module):
    def __init__(self,
                 vocab_cond_size: int,
                 vocab_img_size: int,
                 embed_dim: int,
                 cond_num_tokens: int,
                 img_num_tokens: int,
                 depth_num_tokens: int,
                 spatial_n_heads: int,
                 depth_n_heads: int,
                 spatial_n_layers: int,
                 depth_n_layers: int,
                 mlp_bias: bool = True,
                 attn_bias: bool = True) -> None:
        super().__init__()
        self.img_num_tokens = img_num_tokens
        self.depth_num_tokens = depth_num_tokens
        self.vocab_img_size = vocab_img_size
        
        # condition token and position embedding 
        self.tok_emb_cond = nn.Embedding(vocab_cond_size, embed_dim)
        self.pos_emb_cond = nn.Parameter(torch.rand(1, cond_num_tokens, embed_dim))
        
        # spatial token and position embedding
        self.tok_emb_code = nn.Embedding(vocab_img_size, embed_dim)
        self.pos_emb_code = nn.Parameter(torch.rand(1, img_num_tokens, embed_dim))

        # depth position embedding
        self.pos_emb_depth = nn.Parameter(torch.rand(1, depth_num_tokens-1, embed_dim))

        # spatial transformer
        self.spatial_transformer = [Block(ctx_len=cond_num_tokens + img_num_tokens,
                                          cond_len=cond_num_tokens,
                                          embed_dim=embed_dim,
                                          n_heads=spatial_n_heads,
                                          mlp_bias=mlp_bias,
                                          attn_bias=attn_bias) for i in range(1, spatial_n_layers+1)]
        self.spatial_transformer = nn.Sequential(*self.spatial_transformer)

        # depth transformer
        self.depth_transformer = [Block(ctx_len=depth_num_tokens,
                                        cond_len=0,
                                        embed_dim=embed_dim,
                                        n_heads=depth_n_heads,
                                        mlp_bias=mlp_bias,
                                        attn_bias=attn_bias) for i in range(1, depth_n_layers+1)]
        self.depth_transformer = nn.Sequential(*self.depth_transformer)

        # head
        self.ln_spatial = nn.LayerNorm(embed_dim)
        self.ln_depth = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_img_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self,
                codes: torch.LongTensor,
                conds: torch.LongTensor) -> torch.FloatTensor:
        
        codes = codes.view(codes.shape[0], -1, codes.shape[-1])
        codes = self.tok_emb_code(codes)
        conds = self.tok_emb_cond(conds)

        codes_cumsum = codes.cumsum(-1)
        codes_sum = codes_cumsum[..., -1, :]
        
        codes = codes_sum + self.pos_emb_code
        conds = conds + self.pos_emb_cond

        h = torch.cat([conds, codes], axis=1).contiguous()
        h = self.ln_spatial(self.spatial_transformer(h))
        h = h[:, conds.shape[1]-1:-1].contiguous()

        v = codes_cumsum[..., :-1, :] + self.pos_emb_depth
        v = torch.cat([h.unsqueeze(2), v], axis=2).contiguous()

        v = v.view(-1, *v.shape[2:])
        v = self.depth_transformer(v)                  
        logits = self.head(self.ln_depth(v))
        
        return logits

    def sample(self,
               conds: torch.LongTensor,
               top_k: Optional[float] = None,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               use_fp16: bool = True) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        
        past = codes = logits = None
        B, T, D, S = conds.shape[0], self.img_num_tokens, self.depth_num_tokens, self.vocab_img_size
            
        for i in range(self.img_num_tokens):
            depth_past = None
            
            if codes is None:
                codes_ = None
                pos_code = None
            else:
                codes_ = codes.clone().detach()
                codes_ = codes_[:, -self.depth_num_tokens:]
                pos_code = self.pos_emb_code[:, i-1:i, :]
                
            hidden, presents = self.sample_spatial_step(codes_, conds, pos_code, use_fp16, past)

            presents = torch.stack(presents).clone().detach()
            if past is None:
                past = [presents]
            else:
                past.append(presents)

            last_len = 0 if codes is None else codes.shape[-1]

            for d in range(self.depth_num_tokens):
                if depth_past is None:
                    codes_ = None
                    pos_depth = None
                else:
                    codes_ = codes.clone().detach()
                    codes_ = codes_[:, last_len:]
                    pos_depth = self.pos_emb_depth[:, d-1:d, :]
                
                logits_, depth_presents = self.sample_depth_step(codes_, hidden, pos_depth, use_fp16, depth_past)

                logits_ = logits_.to(dtype=torch.float32)
                logits_ = logits_ / softmax_temperature

                depth_presents = torch.stack(depth_presents).clone().detach()
                if depth_past is None:
                    depth_past = [depth_presents]
                else:
                    depth_past.append(depth_presents)

                if top_k is not None:
                    v, ix = torch.topk(logits_, top_k)
                    logits_[logits_ < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits_, dim=-1)
                
                if top_p is not None:
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cum_probs = torch.cumsum(sorted_probs, dim=-1)

                    sorted_idx_remove_cond = cum_probs >= top_p

                    sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
                    sorted_idx_remove_cond[..., 0] = 0

                    indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
                    probs = probs.masked_fill(indices_to_remove, 0.0)
                    probs = probs / torch.sum(probs, dim=-1, keepdim=True)

                idx = torch.multinomial(probs, num_samples=1).clone().detach()
                codes = idx if codes is None else torch.cat([codes, idx], axis=1)
                logits = logits_ if logits is None else torch.cat([logits, logits_], axis=1)

            del depth_past

        del past

        codes = codes.view(B, T, D)
        logits = logits.view(B * T, D, S)
        
        return logits, codes

    def sample_spatial_step(self,
                            codes: torch.LongTensor,
                            conds: torch.LongTensor,
                            pos_code: torch.LongTensor,
                            use_fp16: bool = True,
                            past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        
        with autocast(enabled=use_fp16):
            presents = []

            if codes is None:
                assert past is None
                conds = self.tok_emb_cond(conds)
                x = conds + self.pos_emb_cond
                
                for i, block in enumerate(self.spatial_transformer):
                    x, present = block.sample(x, layer_past=None)
                    presents.append(present)
                x = self.ln_spatial(x)
                x = x[:, conds.shape[1]-1:conds.shape[1]].contiguous()
            else:
                assert past is not None
                codes = self.tok_emb_code(codes)
                x = codes.sum(1, keepdim=True) + pos_code
                
                past = torch.cat(past, dim=-2)
                for i, block in enumerate(self.spatial_transformer):
                    x, present = block.sample(x, layer_past=past[i])
                    presents.append(present)

                x = self.ln_spatial(x)
                x = x[:, -1:].contiguous()
                
            return x, presents

    def sample_depth_step(self,
                          codes: torch.LongTensor,
                          hidden: torch.FloatTensor,
                          pos_depth: torch.LongTensor,
                          use_fp16: bool = True,
                          past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        
        with autocast(enabled=use_fp16):
            presents = []

            if codes is None:
                assert past is None
                x = hidden
                
                for i, block in enumerate(self.depth_transformer):
                    x, present = block.sample(x, layer_past=None)
                    presents.append(present)
                x = self.ln_depth(x)
            else:
                assert past is not None
                codes = self.tok_emb_code(codes)
                x = codes.sum(1, keepdim=True) + pos_depth
                
                past = torch.cat(past, dim=-2) 
                for i, block in enumerate(self.depth_transformer):
                    x, present = block.sample(x, layer_past=past[i])
                    presents.append(present)

            x = self.ln_depth(x)    
            x = x[:, -1].contiguous()
            
            logits = self.head(x)   

            return logits, presents

2. Training Pipeline

BasicSR provides a well-structured training framework for super-resolution, but unfortunately ViT-VQGAN ships its own training pipeline (built on PyTorch Lightning), so some porting work is needed. Let's take this as an opportunity to look at how the training is organised.
A complete training flow looks like this:

Click to view the ViT-VQGAN model code
# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

from typing import List, Tuple, Dict, Any, Optional
from omegaconf import OmegaConf

import PIL
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision import transforms as T
import pytorch_lightning as pl

from .layers import ViTEncoder as Encoder, ViTDecoder as Decoder
from .quantizers import VectorQuantizer, GumbelQuantizer
from ...utils.general import initialize_from_config


class ViTVQ(pl.LightningModule):
    def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf,
                 loss: OmegaConf, path: Optional[str] = None, ignore_keys: List[str] = list(), scheduler: Optional[OmegaConf] = None) -> None:
        super().__init__()
        self.path = path
        self.ignore_keys = ignore_keys 
        self.image_key = image_key
        self.scheduler = scheduler 
        
        self.loss = initialize_from_config(loss)
        self.encoder = Encoder(image_size=image_size, patch_size=patch_size, **encoder)
        self.decoder = Decoder(image_size=image_size, patch_size=patch_size, **decoder)
        self.quantizer = VectorQuantizer(**quantizer)
        self.pre_quant = nn.Linear(encoder.dim, quantizer.embed_dim)
        self.post_quant = nn.Linear(quantizer.embed_dim, decoder.dim)

        if path is not None:
            self.init_from_ckpt(path, ignore_keys)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:    
        quant, diff = self.encode(x)
        dec = self.decode(quant)
        
        return dec, diff

    def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")
        
    def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        h = self.encoder(x)
        h = self.pre_quant(h)
        quant, emb_loss, _ = self.quantizer(h)
        
        return quant, emb_loss

    def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
        quant = self.post_quant(quant)
        dec = self.decoder(quant)
        
        return dec

    def encode_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
        h = self.encoder(x)
        h = self.pre_quant(h)
        _, _, codes = self.quantizer(h)
        
        return codes

    def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor:
        quant = self.quantizer.embedding(code)
        quant = self.quantizer.norm(quant)
        
        if self.quantizer.use_residual:
            quant = quant.sum(-2)  
            
        dec = self.decode(quant)
        
        return dec

    def get_input(self, batch: Tuple[Any, Any], key: str = 'image') -> Any:
        x = batch[key]
        if len(x.shape) == 3:
            x = x[..., None]
        if x.dtype == torch.double:
            x = x.float()

        return x.contiguous()

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                            last_layer=self.decoder.get_last_layer(), split="train")

            self.log("train/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_ae["train/total_loss"]
            
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)

            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                                last_layer=self.decoder.get_last_layer(), split="train")
            
            self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_disc["train/disc_loss"]
            
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            
            return discloss

    def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> Dict:
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, batch_idx,
                                        last_layer=self.decoder.get_last_layer(), split="val")

        rec_loss = log_dict_ae["val/rec_loss"]

        self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        del log_dict_ae["val/rec_loss"]
        del log_dict_ae["val/total_loss"]

        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        if hasattr(self.loss, 'discriminator'):
            discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, batch_idx,
                                                last_layer=self.decoder.get_last_layer(), split="val")
            
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        
        return self.log_dict

    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate
        optim_groups = list(self.encoder.parameters()) + \
                       list(self.decoder.parameters()) + \
                       list(self.pre_quant.parameters()) + \
                       list(self.post_quant.parameters()) + \
                       list(self.quantizer.parameters())
        
        optimizers = [torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
        schedulers = []
        
        if hasattr(self.loss, 'discriminator'):
            optimizers.append(torch.optim.AdamW(self.loss.discriminator.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4))

        if self.scheduler is not None:
            self.scheduler.params.start = lr
            scheduler = initialize_from_config(self.scheduler)
            
            schedulers = [
                {
                    'scheduler': lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                } for optimizer in optimizers
            ]
   
        return optimizers, schedulers
        
    def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
        log = dict()
        x = self.get_input(batch, self.image_key).to(self.device)
        quant, _ = self.encode(x)
        
        log["originals"] = x
        log["reconstructions"] = self.decode(quant)
        
        return log


class ViTVQGumbel(ViTVQ):
    def __init__(self, image_key: str, image_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf, loss: OmegaConf,
                 path: Optional[str] = None, ignore_keys: List[str] = list(), temperature_scheduler: OmegaConf = None, scheduler: Optional[OmegaConf] = None) -> None:
        super().__init__(image_key, image_size, patch_size, encoder, decoder, quantizer, loss, None, None, scheduler)

        self.temperature_scheduler = initialize_from_config(temperature_scheduler) \
                                     if temperature_scheduler else None
        self.quantizer = GumbelQuantizer(**quantizer)

        if path is not None:
            self.init_from_ckpt(path, ignore_keys)

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        if self.temperature_scheduler:
            self.quantizer.temperature = self.temperature_scheduler(self.global_step)

        loss = super().training_step(batch, batch_idx, optimizer_idx)
        
        if optimizer_idx == 0:
            self.log("temperature", self.quantizer.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        return loss
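
As a usage note: since ViTVQ is a pl.LightningModule, the training script mainly has to build the sub-module configs, attach a learning_rate attribute (configure_optimizers reads self.learning_rate, which is never set in __init__), and hand the model to a pytorch_lightning Trainer. The sketch below shows that wiring under stated assumptions: the config field names other than dim/embed_dim, the loss target path, and the build_*_loader helpers are hypothetical placeholders, not the repo's actual settings.

from omegaconf import OmegaConf
import pytorch_lightning as pl

# Hypothetical sub-configs; the real values live in the repo's yaml files.
# Only `dim` (encoder/decoder) and `embed_dim` (quantizer) are known to be read above.
encoder_cfg   = OmegaConf.create({"dim": 768, "depth": 12, "heads": 12, "mlp_dim": 3072})
decoder_cfg   = OmegaConf.create({"dim": 768, "depth": 12, "heads": 12, "mlp_dim": 3072})
quantizer_cfg = OmegaConf.create({"embed_dim": 32, "n_embed": 8192})
loss_cfg      = OmegaConf.create({"target": "path.to.VQGANLoss", "params": {}})  # taming-style target/params layout assumed

model = ViTVQ(image_key="image", image_size=256, patch_size=8,
              encoder=encoder_cfg, decoder=decoder_cfg,
              quantizer=quantizer_cfg, loss=loss_cfg)
model.learning_rate = 1e-4              # read later by configure_optimizers

trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, max_steps=100_000)
trainer.fit(model,
            train_dataloaders=build_train_loader(),   # hypothetical DataLoader helpers
            val_dataloaders=build_val_loader())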

On hold for now; to be updated later.
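
In the meantime, as a rough orientation for the BasicSR migration raised at the start of this section: the Lightning training_step above essentially multiplexes two updates through optimizer_idx, and in a BasicSR-style model those would live inside a single optimize_parameters call, with feed_data stashing the batch beforehand. The skeleton below is only a hypothetical sketch of that shape, reusing the loss-call signature from the code above; it does not subclass BasicSR's BaseModel and omits registration, logging, EMA and validation.

import torch

class ViTVQGANModelSketch:
    """Hypothetical BasicSR-style wrapper around the ViTVQ network above."""

    def __init__(self, net_g, loss_fn, lr=1e-4):
        self.net_g = net_g              # the ViTVQ autoencoder
        self.loss_fn = loss_fn          # VQGAN loss exposing .discriminator
        self.opt_g = torch.optim.AdamW(net_g.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)
        self.opt_d = torch.optim.AdamW(loss_fn.discriminator.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)

    def feed_data(self, data):
        self.gt = data["image"]         # mirrors get_input / image_key

    def optimize_parameters(self, current_iter):
        xrec, qloss = self.net_g(self.gt)

        # autoencoder / generator update (optimizer_idx == 0 in the Lightning code);
        # current_iter plays the role of global_step, batch_idx is unused here
        self.opt_g.zero_grad()
        aeloss, _ = self.loss_fn(qloss, self.gt, xrec, 0, current_iter, 0,
                                 last_layer=self.net_g.decoder.get_last_layer(), split="train")
        aeloss.backward()
        self.opt_g.step()

        # discriminator update (optimizer_idx == 1); xrec is detached so gradients
        # only flow into the discriminator, assuming the loss's discriminator
        # branch ignores qloss (as in taming-style VQGAN losses)
        self.opt_d.zero_grad()
        discloss, _ = self.loss_fn(qloss, self.gt, xrec.detach(), 1, current_iter, 0,
                                   last_layer=self.net_g.decoder.get_last_layer(), split="train")
        discloss.backward()
        self.opt_d.step()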
