transfomer推理 + 训练过程

1.1 理解transfomer中的反向传播

1. 反向传播的核心目标

  • 计算梯度:通过链式法则,计算损失函数 L 对所有可训练参数(权重、偏置)的梯度。
  • 参数更新:使用优化器(如 Adam)根据梯度更新参数。

2. Transformer 的损失计算

  • 典型损失函数:交叉熵损失(Cross-Entropy Loss),用于衡量预测分布与真实分布的差异。
  • 计算过程:
    1. 解码器输出 logits(未归一化概率)。
    2. 应用 softmax 得到概率分布。
    3. 计算预测概率与真实标签的交叉熵。

3. 反向传播的关键步骤

步骤 1:解码器的反向传播
  • 目标:计算损失对解码器参数的梯度。
  • 过程:
    1. 输出层梯度:
      • 损失对 logits 的梯度:∂logits∂L=softmax_output−ground_truth。
    2. 反向传播通过前馈神经网络(FFN):
      • 梯度通过两层全连接网络反向传播,应用链式法则。
    3. 反向传播通过自注意力机制:
      • 计算损失对查询(Q)、键(K)、值(V)的梯度。
      • 涉及矩阵乘法和 softmax 的梯度计算。
    4. 反向传播通过编码器-解码器注意力:
      • 计算损失对解码器查询和编码器输出的梯度。
步骤 2:编码器的反向传播
  • 目标:计算损失对编码器参数的梯度。
  • 过程:
    1. 通过解码器-编码器注意力传递梯度:
      • 解码器传递的梯度用于计算编码器输出的梯度。
    2. 反向传播通过自注意力机制:
      • 与解码器类似,计算 Q、K、V 的梯度。
    3. 反向传播通过前馈神经网络:
      • 梯度通过两层全连接网络反向传播。
步骤 3:梯度累积与参数更新
  • 梯度累积:
    • 将所有层的梯度累加,得到每个参数的总梯度。
  • 参数更新:
    • 使用优化器(如 Adam)更新参数:θθη⋅∂θL,其中 η 为学习率。

4. 关键组件的梯度计算

4.1 自注意力机制的梯度
  • 公式

Attention(Q,K,V)=softmax(dkQKT)V

  • 梯度计算:
    • Softmax 梯度:使用雅可比矩阵计算。
    • 矩阵乘法梯度:应用链式法则,分别计算 QKV 的梯度。
4.2 前馈神经网络的梯度
  • 公式

FFN(x)=ReLU(x**W1+b1)W2+b2

  • 梯度计算:
    • ReLU 梯度:导数为 1(正区间)或 0(负区间)。
    • 矩阵乘法梯度:应用链式法则,分别计算 W1、b1、W2、b2 的梯度。
4.3 位置编码的梯度
  • 特性:位置编码通常固定,不参与梯度计算。
  • 例外:若位置编码可训练,则计算其梯度并更新。

5. 反向传播的挑战与优化

挑战 1:内存消耗
  • 原因:Transformer 参数众多,激活值需全部存储以计算梯度。
  • 解决方案:梯度检查点(Gradient Checkpointing),在反向传播时重新计算部分激活值。
挑战 2:梯度消失/爆炸
  • 原因:深层网络中的梯度可能过小或过大。
  • 解决方案:
    • 层归一化:稳定梯度,加速训练。
    • 残差连接:缓解梯度消失。
挑战 3:计算效率
  • 原因:自注意力机制的计算复杂度为 O(n2)。
  • 解决方案:使用混合精度训练、分布式计算等。

6. 示例:简单 Transformer 的反向传播

假设

  • 单层编码器,输入为 2 个单词的嵌入向量。
  • 损失函数为均方误差(MSE)。

步骤

  1. 前向传播:
    • 输入嵌入 → 自注意力 → FFN → 输出。
  2. 损失计算:
    • 输出与真实标签计算 MSE。
  3. 反向传播:
    • 计算损失对输出的梯度。
    • 反向传播通过 FFN 和自注意力,计算各参数的梯度。
  4. 参数更新:
    • 使用梯度下降更新参数。

7. 总结

  • 核心流程:前向传播 → 损失计算 → 反向传播 → 参数更新。
  • 关键组件:自注意力、FFN、层归一化、残差连接的梯度计算。
  • 优化策略:梯度检查点、层归一化、混合精度训练。

直接结论:Transformer 的反向传播通过链式法则和自动微分,高效计算梯度并更新参数,尽管面临内存和计算挑战,但通过优化策略可有效解决。

1.2 代码实践

学习地址:https://www.bilibili.com/video/BV1nyyoYLEyL/?spm_id_from=333.337.search-card.all.click&vd_source=824d8f61906b474c0974b8dce18a69fd

TokenEmbedding

class TokenEmbedding(nn.Embedding):

    # 将输入的词汇索引转换为指定维度d_model embedding向量
    
    def __init__(self,vocb_size,d_model):
        super(TokenEmbedding,self).__init__(vocb_size,d_model,padding_index=1)

PositionEmbedding

class PositialEmbedding(nn.Module):
    def __init__(slef,d_model,max_len,device):
        super(PositialEmbedding,self).__init__()
        self.encoding = torch.zeros(max_len,d_model,device=device)
        self.encoding.requires_grad = False
        pos = torch.arange(0,max_len,device=device)
        pos = pos.float().unsqueeze(dim=1)
        # torch.arange(0, 5, device='cuda') 生成一个张量:tensor([0, 1, 2, 3, 4], device='cuda')
        # pos.float() 将其转换为:tensor([0., 1., 2., 3., 4.], device='cuda')
        # pos.unsqueeze(dim=1) 将其重塑为:tensor([[0.], [1.], [2.], [3.], [4.]], device='cuda')
        _2i = torch.arrage(0,d_model,step=2,device=device).float()
        self.encoding[:,0::2] = torch.sin(pos/(10000*(_2i/d_model)))
        self.encoding[:,1::2] = torch.cos(pos/(10000*(_2i/d_model)))
    def forward(self,x):
        batch_size,seq_len = x.size()
        return self.encoding[:seq_len,:]

TransformerEmbedding

class TransformerEmbedding(nn.Module):
    def __init__(self,vocab_size,d_model,max_len,drop_prod,device):
        super(TransformerEmbedding,self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size,d_model)
        self.pos_emb = PositialEmbedding(d_model,max_len,device)
        self.drop_out = nn.Dropout(p=drop_prod)  # 随机丢弃神经元,防止过拟合

    def forward(self,x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)

MultiHeadAttention

class MultiHeadAttention(nn.Module):
    def __init__(self,d_model,n_head):
        super(MultiHeadAttention,self).__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model,d_model)  # 注意看2是不是应该是d_model
        self.w_k = nn.Linear(d_model,d_model)
        self.w_v = nn.Linear(d_model,d_model)
        self.w_combine = nn.Linear(d_model,d_model)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self,q,k,v,mask=None):
        batch,time,dimension = q.shape
        n_d = self.d_model // self.n_head
        q,k,v = self.w_q(q),self.w_k(k),self.w_v(v)

        # 将嵌入向量分割成多个头
        q = q.view(batch,time,self.n_head,n_d).permute(0,2,1,3)
        k = k.view(batch,time,self.n_head,n_d).permute(0,2,1,3)
        v = v.view(batch,time,self.n_head,n_d).permute(0,2,1,3)

        # 计算注意力分数
        # # 缩放点积注意力
        score = q@k.transpose(2,3)/math.sqrt(n_d)
        if mask is not None:
            score = score.masked_fill(mask==0,-10000)
        score = self.softmax(score)@v
        socre = score.permute(0,2,1,3).contiguous().view(batch,time,dimension)
        output = self.w_combine(score)

        return output

LayerNorm

# batch norm 和 layer norm的不同:batch norm是对一个批次内的多个样本数据进行归一化, layer norm 是对一个样本数据类,不同词向量进行归一化
# 归一化,帮助模型训练加速,提高性能
class LayerNorm(nn.Module):
    def __init__(self,d_model,eps=1e-12):
        super(LayerNorm,self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    def forward(self,x):
        # 减去均值,除以 标准差+小常数,再通过gamma,beta进行缩放

        mean = x.mean(-1,keepdim=True)
        var = x.var(-1,unbiased=False,keepdim=True)
        output = (x - mean) / torch.sqrt(var+self.eps)
        output = self.gama*output + self.beta
        return output

前馈神经网络

class PostisionwiseFeedForward(nn.Module):
    def __init__(self,d_model,d_hidden,dropout_prod=0.1):
        super(PostisionwiseFeedForward,self).__init__()
        # 定义全连接层
        self.fc1 = nn.Linear(d_model,d_hidden)
        # 将d_hidden映射为d_model
        self.fc2 = nn.Linear(d_hidden,d_model)
        self.dropout = nn.Dropout(dropout_prod)
    def feedfoward(self,x):
        # 线性变化1
        x = self.fc1(x)
        # 激活函数
        x = F.Relu(x)
        # dropout
        x.self.dropout(x)
        # 线性变化2
        x = self.fc2(x)
        return x

Ecoder

class EncoderLayer(nn.Module):
    def __init__(self,d_model,ffn_hidden,n_head,dropout=0.1):
        super(EncoderLayer,self).__init__()
        # 定义注意力层
        self.attention = MultiHeadAttention(d_model,n_head)
        # 定义归一化层
        self.norml1 = LayerNorm(d_model)
        # 定义dropout
        self.dropout1 = nn.Dropout(dropout)
        # 定义前馈神经网络
        self.ffn = PostisionwiseFeedForward(d_model,ffn_hidden,dropout)
        # 定义第二个归一化层
        self.norml2 = LayerNorm(d_model)
        # 定义dropout2
        self.dropout2 = nn.Dropout(dropout)
    def feedfoward(self,x,mask=None):
        _x = x
        x =self.attention(x,x,x,mask)
        x = self.dropout1(x)
        # 将计算注意力,dropout后的x与原来的x做残差链接,再经过归一化
        x = self.norml1(x + _x)

        # 再计算一次
        _x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norml2(x + _x)
        return x        
class Encoder(nn.Module):
    def __init__(self,encode_vocab_size,max_len,d_model,ffn_hidden,n_head,n_layer,device,dropout=0.1):
        super(Encoder,self).__init__()
        self.embedding = TransformerEmbedding(encode_vocab_size,d_model,max_len,dropout,device)
        self.layers = nn.ModlueList(
            [
                EncoderLayer(d_model,ffn_hidden,n_head)
                for _ in range(n_layer)
            ]
        )

    def feedfoward(self,x,s_mask):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x,s_mask)
        return x    

Decoder

class DecoderLayer(nn.Module):
    def __init__(self,d_model,ffn_hidden,n_head,dropout=0.1):
        super(Decoder,self).__init__()
        # 创建attention1 用于解码注意力分数
        self.attention1 = MultiHeadAttention(d_model,n_head)
        # 定义归一化层
        self.norm1 = LayerNorm(d_model)
        # 定义dropout1
        self.dropout1 = nn.Dropout(dropout)

        # 创建跨模块的多头注意力
        self.cross_attention = MultiHeadAttention(d_model,n_head)
        # 定义归一化层2
        self.norm2 = LayerNorm(d_model)
        # 定义dropout2
        self.dropout2 = nn.Dropout(dropout)

        # 定义全连接层
        self.ffn = PostisionwiseFeedForward(d_model,ffn_hidden,dropout)
        
        # 定义归一化层3
        self.norm3 = LayerNorm(d_model)
        # 定义dropout3
        self.dropout3 = nn.Dropout(dropout)

    # t_mask 表示目标mask, 用于解码decode自注意力mask
    # s_mask 表示源mask 用于跨模态的mask
    # dec 解码输出
    # enc 编码输出
    def feedfoward(self,dec,enc,t_mask,s_mask):
        # 第一层
        _x = dec 
        x = self.attention1(dec,dec,dec,t_mask)
        x = self.dropout1(x)
        x =  self.norm1(x + _x)

        # 第二层,跨模态
        _x = x
        x = self.cross_attention(x,enc,enc,s_mask)
        x = self.dropout2(x)
        x =  self.norm2(x + _x)

        # 第三层,经过前馈神经网咯
        _x = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x =  self.norm3(x + _x)

        return x
class Decoder(nn.Module):
    def __init__(self,decode_vocab_size,max_len,d_model,ffn_hidden,n_head,n_layer,device,dropout=0.1):
        super(Decode,self).__init__()
        self.embedding = TransformerEmbedding(decode_vocab_size,d_model,max_len,dropout,device)
        self.layers = nn.ModlueList(
            [
                DecoderLayer(d_model,ffn_hidden,n_head)
                for _ in range(n_layer)
            ]
        )

        # 多了一个全连接层
        self.fc = nn.Linear(d_model,decode_vocab_size)
        
    def feedfoward(self,dec,enc, t_mask,s_mask):
        dec  = self.embedding(dec)
        for layer in self.layers:
            dec  = layer(dec,enc,t_mask,s_mask)
        dec  = self.fc(dec)
        return   

TransFomer

class TransFomer(nn.Module):
    def __init__(self,src_pad_index,trg_pad_index,enc_vocab_size,dec_vocab_size,max_len,d_model,n_heads,ffn_hidden,n_layers,device,dropout=0.1):
        super(TransFomer,self).__init__()
        self.encoder = Encoder(enc_vocab_size,max_len,d_model,ffn_hidden,n_heads,n_layers,device,dropout)
        self.deconder = Decoder(dec_vocab_size,max_len,d_model,ffn_hidden,n_heads,n_layers,device,dropout)
        self.src_pad_idx = src_pad_index
        self.trg_pad_idx = trg_pad_index
        self.device = device 

    def make_pad_mask(self,q,k,pad_idx_q,pad_idx_k):
        len_q,len_k =  q.size(1), k.size(1)
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)
        q = q.repeat(1,1,1,len_k)

        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1,1,len_q,1)

        mask = q & k
        return mask


    def make_casual_mask(self,q,k):
        len_q,len_k =  q.size(1), k.size(1)
        mask = torch.trill(torch.ones(len_q,len_k)).type(torch.BoolTensor).to(self.device)
        return mask 

    def feedfoward(self,src,trg):
        src_mask = make_pad_mask(src,src,self.src_pad_idx,self.src_pad_idx)
        trg_mask = make_pad_mask(trg,trg,self.trg_pad_idx,self.trg_pad_idx) * self.make_casual_mask(trg,trg)

        enc = self.encoder(src,src_mask)
        output = self.decoder(trg,enc,trg_mask,src_mask)
        return output   
posted @ 2025-04-25 10:03  付十一。  阅读(154)  评论(0)    收藏  举报