多头注意力机制解读

  由于许多教程在讲解多头注意力机制的时候,只是单独地讲了将通道数分开的操作,并没有非常明确的讲出多头注意力机制与单头注意力机制的区别,这里通过一个简单的例子说明一下:

  这里假设输入为3个token,每个token被编码为4维的向量,得到:

a1,1 a1,2 a1,3 a1,4
a2,1 a2,2 a2,3 a2,4
a3,1 a3,2 a3,3 a3,4

单头注意力机制:

  由于QKV矩阵的运算机制一样,这里就以Q为例,通常WQ的维度与编码为度一直,这里方便理解,就假设WQ是一个4*4的矩阵,得到Q:

q1,1 q1,2 q1,3 q1,4
q2,1 q2,2 q2,3 q2,4
q3,1 q3,2 q3,3 q3,4

 

  K同理,也是得到一个3*4的矩阵,所以注意力矩阵计算得到S=QKT:

s1,1 s1,2 s1,3
s2,1 s2,2 s2,3
s3,1 s3,2 s3,3

  其中,s1,1=q1,1*k1,1+q1,2*k1,2+q1,3*k1,3+q1,4*k1,4

  计算最终的矩阵Z=softmax(S)*V(这里可以先忽略除以维度的平方根,实际不影响结果),由3*3的矩阵softmax(S)和3*4的矩阵V相乘,得到最终的3*4矩阵Z,以α来表示softmax(S):

  

α1,1 α1,2 α1,3
α2,1 α2,2 α2,3
α3,1 α3,2 α3,3
v1,1 v1,2 v1,3 v1,4
v2,1 v2,2 v2,3 v2,4
v3,1 v3,2 v3,3 v3,4

  最终得到的z1,11,1*v1,21,2*v2,11,3*v3,1

多头注意力机制:

  这里为了方便理解,以2头的注意力机制为例,通过两个WQ得到两个Q的矩阵,为了方便对比,下表与单头注意力机制的下表对齐:

q1,1 q1,2
q2,1 q2,2
q3,1 q3,2
q1,3 q1,4
q2,3 q2,4
q3,3 q3,4

 

  K也是同理,考虑第一个头,由于第一个头的Q和K是3*2的矩阵,所以得到的注意力矩阵S还是3*3的矩阵,这里直接列出两个头的S矩阵:

s1,11 s1,21 s1,31
s2,11 s2,21 s2,31
s3,11 s3,21 s3,31
s1,12 s1,12 s1,12
s2,12 s2,22 s2,32
s3,32 s3,32 s3,32

  其中,s1,11=q11*k1,1+q1,2*k1,2

       s1,12=q1,3*k1,3+q1,4*k1,4 ,这里已经能够直观的看出多头注意力在计算的过程中,不同的注意力头之间的独立的了;

  继续推导出最终的Z,这里列出第一个头的Z,得到的是3*2的矩阵:

α11,1 α11,2 α11,3
α12,1 α12,2 α12,3
α13,1 α13,2 α13,3
v1,1 v1,2
v2,1 v2,2
v3,1 v3,2

  最终得到的z1,111,1*v1,111,2*v2,111,3*v3,1

  多头注意力得到的所有Z最后会乘上一个权重矩阵Wo再拼接起来,才能得到和单头注意力维度相同的结果矩阵Z,也就是把不同分类头的信息通过矩阵Wo融合,所以多头注意力会比单头注意力多一个权重矩阵Wo的参数,网络的表达也更加丰富。但是多头注意力和单头注意力的结果对比,最大的区别在于α不同,也就是S矩阵的计算方法不同(个人理解),单头注意力的注意力矩阵S会考虑token中所有通道的Q和K,而多头注意力的注意力矩阵S只会考虑这个头所对应的通道的Q和K,这样就能够实现从多个低维度的角度来实现信息融合和特征提取。

posted @ 2025-07-02 11:50  爱露查  阅读(79)  评论(0)    收藏  举报