线性注意力(Linear Attention, LA)学习
定义:采用矩阵乘法结合律的特点,所设计的一种\(\mathcal{O}(n)\)时间复杂度的注意力机制
一、softmax注意力机制
设输入特征\(x\)大小为\(N×F\),其是由\(N\)个维度为\(F\)的特征向量构成的序列(往往\(N\gg F\))
Transformer的一般表示形式为:
其中,\(A(\cdot)\)表示注意力机制,\(f(\cdot)\)表示前馈处理。
针对\(A(\cdot)\),首先,将\(W_Q \in \mathbb{R}^{F \times D}\)、\(W_K \in \mathbb{R}^{F \times D}\)和\(W_V \in \mathbb{R}^{F \times M}\)作用于\(x\)投影得到对应的\(QKV\),此处的\(QK\)相乘是计算二者之间的相似性,并通过softmax得到相似性权重矩阵作用于\(V\)来修正比例,公式如下:
二、线性注意力机制
1.基础解释
根据\(QK\)计算相似性的特点,在不考虑因果性的前提下,广义上可表示为:
当\(\operatorname{sim}(q, k)=\exp \left( \frac{q^T k}{\sqrt{D}} \right)\)时,公式(3)等价于公式(2)
softmax的一个特点是满足“输出非负”,因为需要的是一个相似性权重矩阵(像是通过打分来调整\(V\)中数据的分配比例)
因此通过某种非负相似度映射函数即可将\(QK\)拆分开(许多相似度函数可以表示为高维空间的内积),论文中采用的公式如下
更新后的注意力公式如下(采用矩阵乘法交换律):
公式(2)的时间复杂度为\(\mathcal{O}(N^2max(D,M))\),而优化后的公式(5),首先计算维度为\(C\)的特征映射,最终时间复杂度为\(\mathcal{O}(NCM)\)。
2.因果掩码
在考虑因果性的情况下,公式(5)可化简为:
令\(S_{i} = \sum_{j=1}^{i} \phi(K_{j}) V_{j}^{T}\),\(Z_{i} = \sum_{j=1}^{i} \phi(K_{j})\),进一步化简为:
其中,\(S_{i} = S_{i-1} + \phi(K_{i}) V_{i}^{T}\)、\(Z_{i} = Z_{i-1} + \phi(K_{i})\),由此可见其与传统RNN之间的相似之处,通过这种方式,能在\(S_{i-1}\)和\(Z_{i-1}\)的基础上通过常数时间计算出\(S_{i}\)和\(Z_{i}\)
3.梯度计算
在进行梯度计算时,要存储所有的中间值\(S_{i}\),这使得内存消耗增加为原来的\(max(D,M)\)倍,为此本文通过累积和的方式计算公式(6)给定分子\(\bar{V}_i\)和标量损失函数关于该分子的梯度 \(\nabla_{\bar{V}_i} \mathcal{L}\),以通过线性时间和恒定内存计算因果性序列的前向传播(做题)与反向传播(纠错),公式如下:
同理:
公式(9)和公式(10)计算时累加为\(i \rightarrow N\),可以理解为送快递,所有的\(j>i\)的位置都会收到\(i\)的影响,因为其的包裹都是来自\(i\),故损失\(\mathcal{L}\)对\(i\)的敏感度(梯度)必须考虑它对所有\(j>i\)的影响。
综上,其具有线性时间\(\mathcal{O}(NCM)\)和恒定内存\(\mathcal{O}(Nmax(C,M)\)
三、不足
\(S_{i}\)和\(Z_{i}\)是无衰减的直接累加,所有信息平等叠加,早期的信息容易被后期噪声淹没,因此,需要通过门控、非线性增强、位置编码等方式来弥补此问题。后续的Mamba一定程度上也可以说是线性注意力的一种改进。

浙公网安备 33010602011771号