Transformer-Embedding
导包
#导入包
import torch
from torch import nn
import torch.nn.functional as f 
import math 

TokenEmbedding
#首先定义token embadding
from torch import Tensor
"""
    将输入词汇表的索引转换成指定维度的Embedding
"""
class TokenEmbedding(nn.Embedding):
    def __init__(self,vocab_size,d_model):
        """  
        初始化TokenEmbedding类。  
  
        参数:  
            vocab_size (int): 词汇表的大小。  
            d_model (int): Embedding的维度。  
  
        注意:  
            此类自动将索引为1的词汇视为填充词,并将其嵌入向量初始化为全零。  
            如果你不希望这样,可以手动设置padding_idx参数,或者将其设置为None。  
        """  
        super(TokenEmbedding,self).__init__(vocab_size,d_model,padding_idx=1)
PositionalEmbedding

class PositionalEmbedding(nn.Module):
    def __init__(self,d_model,max_len,device):
        """
            初始化位置矩阵
        """
        super(PositionalEmbedding,self).__init__()
        #初始化0矩阵
        self.encoding = torch.zeros(max_len,d_model,device=device)
        #位置编码不需要优化,就不需要梯度更新
        self.encoding.requires_grad = False
        #定义pos,生成位置索引
        pos = torch.arange(0,max_len)
        pos = pos.to(device)
        #类型转换为浮点型便于计算,在进行维度拓展为二维张量,利用广播机制自动对其
        pos = pos.float().unsqueeze(dim=1)
        #根据公式计算
        frequencies_indices = torch.arange(0, d_model, step=2, device=device).float()
        frequencies = 1.0/torch.pow(10000.0,frequencies_indices//d_model).unsqueeze(dim=0)
        self.encoding[:,0::2] = torch.sin(pos*frequencies)
        self.encoding[:,1::2] = torch.cos(pos*frequencies)
    def forward(self,x):
        #获取批量大小和序列长度
        batch_size,seq_len = x.size()
        return self.encoding[:seq_len,:]
TransformerEmbedding
class TransformerEmbedding(nn.Module):
    def __init__(self,vocab_size,d_model,max_len,drop_prob,device):
        super(TransformerEmbedding,self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size-vocab_size,d_model=d_model)
        self.pos_emb = PositionalEmbedding(d_model=d_model,max_len=max_len,device=device)
        self.drop_out=nn.Dropout(p=drop_prob)
    
    def forward(self,x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb+pos_emb)

posted on 2024-07-23 16:45  凯申物流——  阅读(82)  评论(0)    收藏  举报