import torch
import torch.nn as nn
import torch.nn.functional as F
from config import IGNORE_ID
from .attention import MultiHeadAttention
from .module import PositionalEncoding, PositionwiseFeedForward
from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list
# filename = 'bigram_freq.pkl'
# print('loading {}...'.format(filename))
# with open(filename, 'rb') as file:
# bigram_freq = pickle.load(file)
class Decoder(nn.Module):
    ''' A decoder model with a self-attention mechanism. '''
def __init__(
self, sos_id=0, eos_id=1,
n_tgt_vocab=4335, d_word_vec=512,
n_layers=6, n_head=8, d_k=64, d_v=64,
d_model=512, d_inner=2048, dropout=0.1,
tgt_emb_prj_weight_sharing=True,
pe_maxlen=5000):
super(Decoder, self).__init__()
        # parameters (store hyper-parameters as attributes)
self.sos_id = sos_id # Start of Sentence
self.eos_id = eos_id # End of Sentence
self.n_tgt_vocab = n_tgt_vocab
self.d_word_vec = d_word_vec
self.n_layers = n_layers
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.d_model = d_model
self.d_inner = d_inner
        self.dropout_rate = dropout  # keep the rate; self.dropout below holds the nn.Dropout module
self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
self.pe_maxlen = pe_maxlen
self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
self.dropout = nn.Dropout(dropout)
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])  # stack of n_layers decoder layers
        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)  # linear projection to the target vocabulary
        nn.init.xavier_normal_(self.tgt_word_prj.weight)  # Xavier initialization
        if tgt_emb_prj_weight_sharing:  # defaults to True
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.tgt_word_prj.weight = self.tgt_word_emb.weight  # tie the projection weights to the target embedding
            self.x_logit_scale = (d_model ** -0.5)  # scale the shared embedding by 1/sqrt(d_model)
else:
self.x_logit_scale = 1.
    def preprocess(self, padded_input):  # pre-process the padded target sequences
"""Generate decoder input and output label from padded_input
Add <sos> to decoder input, and add <eos> to decoder output label
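        Illustrative example (a sketch assuming the default sos_id=0 / eos_id=1 and
        IGNORE_ID == -1 from config):
            padded_input  = [[5, 6, 7, -1],
                             [8, 9, -1, -1]]
            ys_in_pad  -> [[0, 5, 6, 7],        # <sos> prepended, padded with eos_id
                           [0, 8, 9, 1]]
            ys_out_pad -> [[5, 6, 7, 1],        # <eos> appended, padded with IGNORE_ID
                           [8, 9, 1, -1]]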
"""
        ys = [y[y != IGNORE_ID] for y in padded_input]  # strip the IGNORE_ID (-1) padding from each target
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos_id])  # 1-element tensor with the same dtype/device as ys[0]
        # tensor.new(data) creates a tensor with the same dtype and device as the source tensor
        sos = ys[0].new([self.sos_id])
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]  # prepend the <sos> tag
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]  # append the <eos> tag
        # pad the sequences (ys_in with eos_id, ys_out with IGNORE_ID)
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos_id)  # ys_in: sequences to pad; self.eos_id: padding value
        ys_out_pad = pad_list(ys_out, IGNORE_ID)
        assert ys_in_pad.size() == ys_out_pad.size()  # both must end up with the same shape
        return ys_in_pad, ys_out_pad  # tagged and padded decoder input / output labels
def forward(self, padded_input, encoder_padded_outputs,
encoder_input_lengths, return_attns=False):
"""
Args:
padded_input: N x To
encoder_padded_outputs: N x Ti x H
        Returns:
            pred: N x To' x n_tgt_vocab, pre-softmax logits (To' is the padded target length incl. <sos>/<eos>)
            gold: N x To', target labels padded with IGNORE_ID
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []  # lists for decoder self-attention and encoder-decoder attention weights
        # Get Decoder Input and Output
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)  # tagged and padded decoder input / output labels
# Prepare masks
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)  # mask for padded input positions, N x To' x 1
        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)  # mask future positions in the target sequence
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)  # mask padded keys
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)  # combined self-attention mask
        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)  # encoder-decoder attention mask
        # Forward
        dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
                                  self.positional_encoding(ys_in_pad))  # scaled word embedding plus positional encoding
        for dec_layer in self.layer_stack:  # run the stacked decoder layers
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)
            if return_attns:  # defaults to False
dec_slf_attn_list += [dec_slf_attn]
dec_enc_attn_list += [dec_enc_attn]
# before softmax
        seq_logit = self.tgt_word_prj(dec_output)  # project the decoder output to vocabulary logits
        # Return
        pred, gold = seq_logit, ys_out_pad  # predictions (logits) and gold target labels
if return_attns:
return pred, gold, dec_slf_attn_list, dec_enc_attn_list
return pred, gold
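    # Illustrative training-loss sketch: the (pred, gold) pair returned above is
    # typically fed to a cross-entropy loss that skips the IGNORE_ID padding in gold
    # (`decoder` below is a hypothetical instance of this class):
    #     pred, gold = decoder(padded_input, encoder_padded_outputs, encoder_input_lengths)
    #     loss = F.cross_entropy(pred.view(-1, pred.size(-1)), gold.view(-1),
    #                            ignore_index=IGNORE_ID)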
def recognize_beam(self, encoder_outputs, char_list, args):
"""Beam search, decode one utterence now.
Args:
encoder_outputs: T x H
char_list: list of character
args: args.beam
Returns:
            nbest_hyps: list of the n-best hypotheses, each a dict with 'score' and 'yseq'
"""
# search params
beam = args.beam_size
nbest = args.nbest
if args.decode_max_len == 0:
maxlen = encoder_outputs.size(0)
else:
maxlen = args.decode_max_len
        encoder_outputs = encoder_outputs.unsqueeze(0)  # add a batch dimension: 1 x T x H
        # prepare sos
        # start every hypothesis with the <sos> tag
        ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
        # torch.ones(size) builds an all-ones tensor; fill_(v) overwrites its values with v;
        # type_as(b) matches b's dtype/device, and long() casts back to int64 token ids
# yseq: 1xT
hyp = {'score': 0.0, 'yseq': ys}
hyps = [hyp]
ended_hyps = []
for i in range(maxlen):
hyps_best_kept = []
for hyp in hyps:
ys = hyp['yseq'] # 1 x i
# last_id = ys.cpu().numpy()[0][-1]
# freq = bigram_freq[last_id]
# freq = torch.log(torch.from_numpy(freq))
# # print(freq.dtype)
# freq = freq.type(torch.float).to(device)
# print(freq.dtype)
# print('freq.size(): ' + str(freq.size()))
# print('freq: ' + str(freq))
# -- Prepare masks
non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1
slf_attn_mask = get_subsequent_mask(ys)
# -- Forward
dec_output = self.dropout(
self.tgt_word_emb(ys) * self.x_logit_scale +
self.positional_encoding(ys))
for dec_layer in self.layer_stack:
dec_output, _, _ = dec_layer(
dec_output, encoder_outputs,
non_pad_mask=non_pad_mask,
slf_attn_mask=slf_attn_mask,
dec_enc_attn_mask=None)
seq_logit = self.tgt_word_prj(dec_output[:, -1])
local_scores = F.log_softmax(seq_logit, dim=1)
# print('local_scores.size(): ' + str(local_scores.size()))
# local_scores += freq
# print('local_scores: ' + str(local_scores))
# topk scores
local_best_scores, local_best_ids = torch.topk(
local_scores, beam, dim=1)
for j in range(beam):
new_hyp = {}
new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(hyps_best_kept,
key=lambda x: x['score'],
reverse=True)[:beam]
# end for hyp in hyps
hyps = hyps_best_kept
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
for hyp in hyps:
hyp['yseq'] = torch.cat([hyp['yseq'],
torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()],
dim=1)
            # add ended hypotheses to a final list, and remove them from the current hypotheses
            # (this can be a problem: the number of hyps may drop below beam)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][0, -1] == self.eos_id:
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
hyps = remained_hyps
# if len(hyps) > 0:
            #     print('remaining hypotheses: ' + str(len(hyps)))
# else:
# print('no hypothesis. Finish decoding.')
# break
#
# for hyp in hyps:
# print('hypo: ' + ''.join([char_list[int(x)]
# for x in hyp['yseq'][0, 1:]]))
# end for i in range(maxlen)
nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
:min(len(ended_hyps), nbest)]
        # compatible with the LAS implementation
for hyp in nbest_hyps:
hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
return nbest_hyps
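    # Illustrative call (a sketch; `decoder`, `encoder_outputs` and `char_list` are
    # hypothetical, and `args` only needs the fields read above):
    #     args = argparse.Namespace(beam_size=5, nbest=1, decode_max_len=100)
    #     nbest_hyps = decoder.recognize_beam(encoder_outputs, char_list, args)
    #     best = ''.join(char_list[idx] for idx in nbest_hyps[0]['yseq'][1:-1])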
class DecoderLayer(nn.Module):
    ''' Composed of three sub-layers: masked self-attention, encoder-decoder attention and a position-wise feed-forward network. '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(DecoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
dec_output *= non_pad_mask
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output *= non_pad_mask
dec_output = self.pos_ffn(dec_output)
dec_output *= non_pad_mask
return dec_output, dec_slf_attn, dec_enc_attn
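

# Minimal smoke test (an illustrative sketch, assuming IGNORE_ID from config is the
# padding label, commonly -1). Run it as a module, e.g. `python -m <package>.decoder`,
# so that the relative imports above resolve.
if __name__ == '__main__':
    N, Ti, To, vocab = 2, 30, 10, 4335
    decoder = Decoder(n_tgt_vocab=vocab)
    encoder_padded_outputs = torch.randn(N, Ti, decoder.d_model)  # N x Ti x H
    encoder_input_lengths = [Ti, Ti - 5]                          # true encoder output lengths
    padded_input = torch.randint(2, vocab, (N, To))               # N x To target token ids
    padded_input[1, -3:] = IGNORE_ID                              # simulate a shorter target
    pred, gold = decoder(padded_input, encoder_padded_outputs, encoder_input_lengths)
    print(pred.size(), gold.size())  # expect N x (To + 1) x vocab and N x (To + 1)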