transformer代码笔记----transformer.py

import torch.nn as nn

from .decoder import Decoder
from .encoder import Encoder


class Transformer(nn.Module):  #定义类,继承父类nn.Module
    """An encoder-decoder framework only includes attention.
    """

    def __init__(self, encoder=None, decoder=None):  #参数encoder和decoder设置默认值None
        super(Transformer, self).__init__()          #继承父类__init__()

        if encoder is not None and decoder is not None:   #判断decoder和encoder是否被重新赋值
            self.encoder = encoder
            self.decoder = decoder

            for p in self.parameters():  #获取网络参数
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)  #参数初始化,torch.nn.init.xavier_uniform_是一个服从均匀分布的Glorot初始化器
        # else:
        #     self.encoder = Encoder()    #对全局变量赋值
        #     self.decoder = Decoder()

    def forward(self, padded_input, input_lengths, padded_target):  #编码器中的前向传播
        """
        Args:
            padded_input: B x Ti x D   表示编码器输入时数据结构
            其中B(一维向量):批量中每个音频的具体长度;Ti:该批量中音频的最大长度;
            input_lengths: B   每个音频的具体长度,假设批量大小为32,则B可表示为[2,3,45,6....],维度32
            padded_targets: B x To   表示解码器的输入数据结构,这里的B和上面的B不同,因为编码器中是音频的输入,解码器中的输入是字符
        """
        encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
        # pred is score before softmax
        pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs,
                                      input_lengths)
        return pred, gold

    def recognize(self, input, input_length, char_list, args):   #解码器中的识别过程
        """Sequence-to-Sequence beam search, decode one utterence now.
        Args:
            input: T x D
            char_list: list of characters
            args: args.beam
        Returns:
            nbest_hyps:
        """
        encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length)
        nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0],
                                                 char_list,
                                                 args)
        return nbest_hyps

 

posted @ 2021-10-19 19:46  Uriel-w  阅读(122)  评论(0编辑  收藏  举报