transformer中每个阶段的张量形状

通用部分

批次数量: batch_size
句子长度:seq_len
模型维度:d_model
多头头数:n_head
词典总数:vocab_num

输入时:
input_ids : [batch_size,seq_len]
input_mask: [batch_size,seq_len]
target_ids: [batch_size,seq_len]
target_mask: [batch_size,seq_len]

词向量 + 位置编码
input_ids : [batch_size,seq_len,d_model]
target_ids: [batch_size,seq_len, d_model]

拆分多头:
input_ids : [batch_size,seq_len,n_head,d_k]
target_ids: [batch_size,seq_len, n_head,d_k]

d_k = d_model/n_head
计算注意力之前的位置交换:
input_ids : [batch_size,n_head,seq_len,d_k]
target_ids: [batch_size,n_head,seq_len,d_k]

编码器部分

编码器做相似度计算:
Q = input_ids
K = input_ids
V = input_ids

相似度矩阵A
A = Q@KT : [batch_size,n_head,seq_len,seq_len]
第一个为Q的seq_len,第二个为K的seq_len

对A做掩码
此时 input_mask: [batch_size,seq_len]
对 input_mask做 unsqueeze 得到
input_mask: [batch_size,1,1,seq_len]
掩码完成后
A: [batch_size,n_head,seq_len,seq_len]

A乘V计算注意力
此时V: [batch_size,n_head,seq_len,d_k]
注意力矩阵O
O=A@V : [batch_size,n_head,seq_len,d_k]
seq_len为相似度矩阵A的seq_len

将多头和句子序列长度互换
O = [batch_size,seq_len,n_head,d_k]
合并多头:
O = [batch_size,seq_len,d_model]
所以,编码器计算出的最终注意力形状为
O = [batch_size,seq_len,d_model]

解码器部分

此时解码器的
target_ids: [batch_size,n_head,seq_len,d_k]
target_mask: [batch_size,seq_len]

解码器做相似度计算
Q = target_ids
K = target_ids
V = target_ids

相似度矩阵A
A = Q@KT : [batch_size,n_head,seq_len,seq_len]
第一个为Q的seq_len,第二个为K的seq_len

对A做因果掩码
生成下三角掩码矩阵:
M: [seq_len,seq_len]

对掩码矩阵做扩充 unsqueeze
M:[1,seq_len,seq_len]

将target_mask进行扩充
target_mask: [batch_size,1,seq_len]

将target_mask 与 M 逐乘得到最终掩码
M=target_mask * M
M: [batch_size,seq_len,seq_len]

扩充M
M: [batch_size,1,seq_len,seq_len]

对相似度矩阵A做掩码
A:[batch_size,n_head,seq_len,seq_len]

A和V相乘得到目标句子的注意力
O=A@V : [batch_size,n_head,seq_len,d_k]

将多头和句子序列长度互换
O = [batch_size,seq_len,n_head,d_k]
合并多头:
O = [batch_size,seq_len,d_model]
所以,解码器多头计算出的最终注意力形状为
O = [batch_size,seq_len,d_model]

交叉注意力

定义:
编码器的注意为 O1 : [batch_size,seq_len,d_model]
解码器的注意为 O2 : [batch_size,seq_len,d_model]

则:
Q = O2
K = O1
V = O1

计算相似度
A = Q@KT: [batch_size,seq_len,seq_len]
第一个seq_len 来自Q, 即来自目标句子, 第二个seq_len 来自K, 即原句子

对A用解码器侧输入的input_mask做掩码
input_mask: [batch_size,seq_len]
对input_mask做填充
input_mask: [batch_size,1,seq_len]
掩码后A的形状为
A: [batch_size,seq_len,seq_len]

最终的交叉注意力:
O = A@V: [batch_size,seq_len,d_model]

即得到目标句子中每个token在d_model上的概率分布数据

出口

再通过 一个线性层将 O: [batch_size,seq_len,d_model] 转换为 [batch_size,seq_len,vocab_num ]的形状出去

posted @ 2025-03-09 23:22  xiezhengcai  阅读(76)  评论(0)    收藏  举报