import torch.nn as nn
from .decoder import Decoder
from .encoder import Encoder
class Transformer(nn.Module):  # subclass of nn.Module
"""An encoder-decoder framework only includes attention.
"""
    def __init__(self, encoder=None, decoder=None):  # encoder and decoder default to None
        super(Transformer, self).__init__()  # initialize the parent nn.Module
        if encoder is not None and decoder is not None:  # only wire up submodules when both are supplied
self.encoder = encoder
self.decoder = decoder
            for p in self.parameters():  # iterate over all model parameters
                if p.dim() > 1:  # skip biases and other 1-D parameters
                    nn.init.xavier_uniform_(p)  # Xavier (Glorot) uniform initialization
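                    # Xavier uniform samples from U(-a, a) with
                    # a = sqrt(6 / (fan_in + fan_out)), which keeps the variance
                    # of activations roughly constant across layers at init.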
        # else:
        #     self.encoder = Encoder()  # fall back to default-configured modules
        #     self.decoder = Decoder()
    def forward(self, padded_input, input_lengths, padded_target):  # forward pass through encoder and decoder
"""
Args:
padded_input: B x Ti x D 表示编码器输入时数据结构
其中B(一维向量):批量中每个音频的具体长度;Ti:该批量中音频的最大长度;
input_lengths: B 每个音频的具体长度,假设批量大小为32,则B可表示为[2,3,45,6....],维度32
padded_targets: B x To 表示解码器的输入数据结构,这里的B和上面的B不同,因为编码器中是音频的输入,解码器中的输入是字符
"""
encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
        # pred: decoder scores before softmax (logits);
        # gold: the target ids the scores are evaluated against
pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs,
input_lengths)
return pred, gold
    def recognize(self, input, input_length, char_list, args):  # beam-search decoding
"""Sequence-to-Sequence beam search, decode one utterence now.
        Args:
            input: T x D, a single unpadded utterance.
            input_length: length of the input utterance.
            char_list: list of output characters.
            args: decoding options (e.g. args.beam, the beam width).
        Returns:
            nbest_hyps: the n best decoding hypotheses.
        """
        encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length)  # add batch dim: 1 x T x D
nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0],
char_list,
args)
return nbest_hyps
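

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of assembling and calling the model, assuming Encoder()
# and Decoder() accept a no-argument default configuration (hypothetical;
# check their actual signatures). Tensor shapes follow the forward() docstring.
#
#   import torch
#   model = Transformer(Encoder(), Decoder())
#   padded_input = torch.randn(32, 100, 80)           # B x Ti x D
#   input_lengths = torch.randint(1, 101, (32,))      # B, true lengths
#   padded_target = torch.randint(0, 1000, (32, 20))  # B x To, token ids
#   pred, gold = model(padded_input, input_lengths, padded_target)
#
#   # Inference on one utterance (char_list and args come from the
#   # training/decoding setup; args must provide fields such as args.beam):
#   nbest_hyps = model.recognize(padded_input[0], input_lengths[:1],
#                                char_list, args)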