多头注意力机制解读
由于许多教程在讲解多头注意力机制的时候,只是单独地讲了将通道数分开的操作,并没有非常明确的讲出多头注意力机制与单头注意力机制的区别,这里通过一个简单的例子说明一下:
这里假设输入为3个token,每个token被编码为4维的向量,得到:
| a1,1 | a1,2 | a1,3 | a1,4 |
| a2,1 | a2,2 | a2,3 | a2,4 |
| a3,1 | a3,2 | a3,3 | a3,4 |
单头注意力机制:
由于QKV矩阵的运算机制一样,这里就以Q为例,通常WQ的维度与编码为度一直,这里方便理解,就假设WQ是一个4*4的矩阵,得到Q:
| q1,1 | q1,2 | q1,3 | q1,4 |
| q2,1 | q2,2 | q2,3 | q2,4 |
| q3,1 | q3,2 | q3,3 | q3,4 |
K同理,也是得到一个3*4的矩阵,所以注意力矩阵计算得到S=QKT:
| s1,1 | s1,2 | s1,3 |
| s2,1 | s2,2 | s2,3 |
| s3,1 | s3,2 | s3,3 |
其中,s1,1=q1,1*k1,1+q1,2*k1,2+q1,3*k1,3+q1,4*k1,4
计算最终的矩阵Z=softmax(S)*V(这里可以先忽略除以维度的平方根,实际不影响结果),由3*3的矩阵softmax(S)和3*4的矩阵V相乘,得到最终的3*4矩阵Z,以α来表示softmax(S):
|
|
最终得到的z1,1=α1,1*v1,2+α1,2*v2,1+α1,3*v3,1
多头注意力机制:
这里为了方便理解,以2头的注意力机制为例,通过两个WQ得到两个Q的矩阵,为了方便对比,下表与单头注意力机制的下表对齐:
|
|
K也是同理,考虑第一个头,由于第一个头的Q和K是3*2的矩阵,所以得到的注意力矩阵S还是3*3的矩阵,这里直接列出两个头的S矩阵:
|
|
其中,s1,11=q11*k1,1+q1,2*k1,2
s1,12=q1,3*k1,3+q1,4*k1,4 ,这里已经能够直观的看出多头注意力在计算的过程中,不同的注意力头之间的独立的了;
继续推导出最终的Z,这里列出第一个头的Z,得到的是3*2的矩阵:
|
|
最终得到的z1,1=α11,1*v1,1+α11,2*v2,1+α11,3*v3,1
多头注意力得到的所有Z最后会乘上一个权重矩阵Wo再拼接起来,才能得到和单头注意力维度相同的结果矩阵Z,也就是把不同分类头的信息通过矩阵Wo融合,所以多头注意力会比单头注意力多一个权重矩阵Wo的参数,网络的表达也更加丰富。但是多头注意力和单头注意力的结果对比,最大的区别在于α不同,也就是S矩阵的计算方法不同(个人理解),单头注意力的注意力矩阵S会考虑token中所有通道的Q和K,而多头注意力的注意力矩阵S只会考虑这个头所对应的通道的Q和K,这样就能够实现从多个低维度的角度来实现信息融合和特征提取。

浙公网安备 33010602011771号