线性注意力(Linear Attention, LA)学习

定义:采用矩阵乘法结合律的特点,所设计的一种\(\mathcal{O}(n)\)时间复杂度的注意力机制

一、softmax注意力机制

设输入特征\(x\)大小为\(N×F\),其是由\(N\)个维度为\(F\)的特征向量构成的序列(往往\(N\gg F\)

Transformer的一般表示形式为:

\[T(x) = f(A(x) + x) \tag{1} \]

其中,\(A(\cdot)\)表示注意力机制,\(f(\cdot)\)表示前馈处理。

针对\(A(\cdot)\),首先,将\(W_Q \in \mathbb{R}^{F \times D}\)\(W_K \in \mathbb{R}^{F \times D}\)\(W_V \in \mathbb{R}^{F \times M}\)作用于\(x\)投影得到对应的\(QKV\),此处的\(QK\)相乘是计算二者之间的相似性,并通过softmax得到相似性权重矩阵作用于\(V\)来修正比例,公式如下:

\[A(x)=V_i'=softmax(\frac{xW_Q(xW_K)^T}{\sqrt{D}})xW_V=softmax(\frac{QK^T}{\sqrt{D}})V \tag{2} \]

二、线性注意力机制

1.基础解释

根据\(QK\)计算相似性的特点,在不考虑因果性的前提下,广义上可表示为:

\[V_i' = \frac{\sum_{j=1}^{N} \operatorname{sim}(Q_i, K_j) V_j}{\sum_{j=1}^{N} \operatorname{sim}(Q_i, K_j)} \tag{3} \]

\(\operatorname{sim}(q, k)=\exp \left( \frac{q^T k}{\sqrt{D}} \right)\)时,公式(3)等价于公式(2)

softmax的一个特点是满足“输出非负”,因为需要的是一个相似性权重矩阵(像是通过打分来调整\(V\)中数据的分配比例)

因此通过某种非负相似度映射函数即可将\(QK\)拆分开(许多相似度函数可以表示为高维空间的内积),论文中采用的公式如下

\[\phi (x) = \text{elu}(x) + 1 \tag{4} \]

更新后的注意力公式如下(采用矩阵乘法交换律):

\[V_{i}^{\prime} = \frac{\sum_{j=1}^{N} \phi(Q_{i})^{T} \phi(K_{j}) V_{j}}{\sum_{j=1}^{N} \phi(Q_{i})^{T} \phi(K_{j})}=\frac{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j)} \tag{5} \]

公式(2)的时间复杂度为\(\mathcal{O}(N^2max(D,M))\),而优化后的公式(5),首先计算维度为\(C\)的特征映射,最终时间复杂度为\(\mathcal{O}(NCM)\)

2.因果掩码

在考虑因果性的情况下,公式(5)可化简为:

\[V_{i}^{\prime} = \frac{\phi(Q_{i})^{T} \sum_{j=1}^{i} \phi(K_{j}) V_{j}^{T}}{\phi(Q_{i})^{T} \sum_{j=1}^{i} \phi(K_{j})} \tag{6} \]

\(S_{i} = \sum_{j=1}^{i} \phi(K_{j}) V_{j}^{T}\)\(Z_{i} = \sum_{j=1}^{i} \phi(K_{j})\),进一步化简为:

\[V_{i}^{\prime} = \frac{\phi (Q_{i})^{T} S_{i}}{\phi (Q_{i})^{T} Z_{i}} \tag{7} \]

其中,\(S_{i} = S_{i-1} + \phi(K_{i}) V_{i}^{T}\)\(Z_{i} = Z_{i-1} + \phi(K_{i})\),由此可见其与传统RNN之间的相似之处,通过这种方式,能在\(S_{i-1}\)\(Z_{i-1}\)的基础上通过常数时间计算出\(S_{i}\)\(Z_{i}\)

3.梯度计算

在进行梯度计算时,要存储所有的中间值\(S_{i}\),这使得内存消耗增加为原来的\(max(D,M)\)倍,为此本文通过累积和的方式计算公式(6)给定分子\(\bar{V}_i\)和标量损失函数关于该分子的梯度 \(\nabla_{\bar{V}_i} \mathcal{L}\),以通过线性时间和恒定内存计算因果性序列的前向传播(做题)与反向传播(纠错),公式如下:

\[\nabla_{\phi(Q_i)} \mathcal{L} = \frac{\partial \mathcal{L}}{\partial \bar{V}_i} \cdot \frac{\partial \bar{V}_i}{\partial \phi(Q_i)} = \nabla_{\bar{V_i}} \mathcal{L} \left( \sum_{j=1}^i \phi(K_j) V_j^T \right)^T \tag{8} \]

同理:

\[\nabla_{\phi(K_i)} \mathcal{L} = \left( \sum_{j=i}^{N} \phi(Q_j) \left( \nabla_{\bar{v}_j} \mathcal{L} \right)^T \right) V_i \tag{9} \]

\[\nabla_{V_i} \mathcal{L} = \left( \sum_{j=i}^{N} \phi(Q_j) \left( \nabla_{V_j} \mathcal{L} \right)^T \right)^T \phi(K_i) \tag{10} \]

公式(9)和公式(10)计算时累加为\(i \rightarrow N\),可以理解为送快递,所有的\(j>i\)的位置都会收到\(i\)的影响,因为其的包裹都是来自\(i\),故损失\(\mathcal{L}\)\(i\)的敏感度(梯度)必须考虑它对所有\(j>i\)的影响。
综上,其具有线性时间\(\mathcal{O}(NCM)\)和恒定内存\(\mathcal{O}(Nmax(C,M)\)

三、不足

\(S_{i}\)\(Z_{i}\)是无衰减的直接累加,所有信息平等叠加,早期的信息容易被后期噪声淹没,因此,需要通过门控、非线性增强、位置编码等方式来弥补此问题。后续的Mamba一定程度上也可以说是线性注意力的一种改进。

原论文

posted @ 2026-01-21 12:11  O_obk  阅读(3)  评论(0)    收藏  举报