transformer中每个阶段的张量形状
通用部分
批次数量: batch_size
句子长度:seq_len
模型维度:d_model
多头头数:n_head
词典总数:vocab_num
输入时:
input_ids : [batch_size,seq_len]
input_mask: [batch_size,seq_len]
target_ids: [batch_size,seq_len]
target_mask: [batch_size,seq_len]
词向量 + 位置编码
input_ids : [batch_size,seq_len,d_model]
target_ids: [batch_size,seq_len, d_model]
拆分多头:
input_ids : [batch_size,seq_len,n_head,d_k]
target_ids: [batch_size,seq_len, n_head,d_k]
d_k = d_model/n_head
计算注意力之前的位置交换:
input_ids : [batch_size,n_head,seq_len,d_k]
target_ids: [batch_size,n_head,seq_len,d_k]
编码器部分
编码器做相似度计算:
Q = input_ids
K = input_ids
V = input_ids
相似度矩阵A
A = Q@KT : [batch_size,n_head,seq_len,seq_len]
第一个为Q的seq_len,第二个为K的seq_len
对A做掩码
此时 input_mask: [batch_size,seq_len]
对 input_mask做 unsqueeze 得到
input_mask: [batch_size,1,1,seq_len]
掩码完成后
A: [batch_size,n_head,seq_len,seq_len]
A乘V计算注意力
此时V: [batch_size,n_head,seq_len,d_k]
注意力矩阵O
O=A@V : [batch_size,n_head,seq_len,d_k]
seq_len为相似度矩阵A的seq_len
将多头和句子序列长度互换
O = [batch_size,seq_len,n_head,d_k]
合并多头:
O = [batch_size,seq_len,d_model]
所以,编码器计算出的最终注意力形状为
O = [batch_size,seq_len,d_model]
解码器部分
此时解码器的
target_ids: [batch_size,n_head,seq_len,d_k]
target_mask: [batch_size,seq_len]
解码器做相似度计算
Q = target_ids
K = target_ids
V = target_ids
相似度矩阵A
A = Q@KT : [batch_size,n_head,seq_len,seq_len]
第一个为Q的seq_len,第二个为K的seq_len
对A做因果掩码
生成下三角掩码矩阵:
M: [seq_len,seq_len]
对掩码矩阵做扩充 unsqueeze
M:[1,seq_len,seq_len]
将target_mask进行扩充
target_mask: [batch_size,1,seq_len]
将target_mask 与 M 逐乘得到最终掩码
M=target_mask * M
M: [batch_size,seq_len,seq_len]
扩充M
M: [batch_size,1,seq_len,seq_len]
对相似度矩阵A做掩码
A:[batch_size,n_head,seq_len,seq_len]
A和V相乘得到目标句子的注意力
O=A@V : [batch_size,n_head,seq_len,d_k]
将多头和句子序列长度互换
O = [batch_size,seq_len,n_head,d_k]
合并多头:
O = [batch_size,seq_len,d_model]
所以,解码器多头计算出的最终注意力形状为
O = [batch_size,seq_len,d_model]
交叉注意力
定义:
编码器的注意为 O1 : [batch_size,seq_len,d_model]
解码器的注意为 O2 : [batch_size,seq_len,d_model]
则:
Q = O2
K = O1
V = O1
计算相似度
A = Q@KT: [batch_size,seq_len,seq_len]
第一个seq_len 来自Q, 即来自目标句子, 第二个seq_len 来自K, 即原句子
对A用解码器侧输入的input_mask做掩码
input_mask: [batch_size,seq_len]
对input_mask做填充
input_mask: [batch_size,1,seq_len]
掩码后A的形状为
A: [batch_size,seq_len,seq_len]
最终的交叉注意力:
O = A@V: [batch_size,seq_len,d_model]
即得到目标句子中每个token在d_model上的概率分布数据
出口
再通过 一个线性层将 O: [batch_size,seq_len,d_model] 转换为 [batch_size,seq_len,vocab_num ]的形状出去

浙公网安备 33010602011771号