实用指南:Attention计算代码详解

Softmax函数

Softmax函数用于将值变成一个概率分布(和为1)。

softmax 的核心作用可以概括为三个方面:

1. 把一组实数转换成概率分布

softmax 会把任意向量转成非负且总和为 1 的结果,常用来表示概率。
这样模型输出可以被解释为不同类别的概率。

2. 放大差异

softmax 对大的值更敏感,小的值会被压得更小,大的值会更突出。
这让模型更容易做出明确选择。

3. 在训练中提供可微分的概率输出

分类任务通常用交叉熵损失,而交叉熵需要概率分布,softmax 刚好提供了连续可微的概率。
这让模型可以通过梯度下降训练。

# numpy生成矩阵
np.random.randn(2, 3, 4)
# softmax函数示例
# Generate 100 dimensional vector
np.random.seed(0)
x = np.random.randn(100)
# Compute softmax
x_max = np.max(x)
e_x = np.exp(x - x_max)
softmax_x = e_x / np.sum(e_x)

实际使用时的函数:

def softmax(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x - x_max)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

Attention的示意图

假设有一个query Q,维度是8,

5个key:维度也是8,其对应的5个value,维度是10

首先Q会和每个key算内积,得到的值经过softmax就是attention的概率分布(注意力权重,是概率分布)

经过softmax之后,就是这个query对5个key的注意力权重分布。

根据这个权重,将其分配到每个value上,就得到了query最终的结果。

最后,这个计算的可视化展示如下:

对于多个query(比如有三个),其和五个keys计算attention时(QK^T)的计算方式如下:

由上面的过程可以看出,广义的attention计算中:

  • KV的个数必须一样(实际体现在KV都是由同一组token得到)
  • Q的个数可以和KV不一样,但是Q和K的维度要一样(计算内积,用于Q对于每个key的计算注意力权重)
  • V的维度是自由的,其维度为最后的到的token的维度。同时,最后的到的token的数量为Q的数量

单头注意力机制的numpy代码实现:

numpy函数用法:

np.where

np.matmul()

批量矩阵乘法。当输入是三维或更高维,会对前面的维度广播(batch dim 不参与乘法,只逐个对应计算)

np.transpose()

import numpy as np
def softmax(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x - x_max)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, seq_q, d_k)
    K: (batch, seq_k, d_k)
    V: (batch, seq_k, d_v)
    mask: (batch, seq_q, seq_k) or None. mask entries True means masked (ignore)
    returns: output (batch, seq_q, d_v), attention_weights (batch, seq_q, seq_k)
    """
    d_k = Q.shape[-1]
    # 1) Q @ K^T
    scores = np.matmul(Q, K.transpose(0, 2, 1))  # (batch, seq_q, seq_k)
    # 2) scale
    scores = scores / np.sqrt(float(d_k))
    # 3) apply mask if present
    if mask is not None:
        # set masked positions to large negative value so softmax ~ 0
        scores = np.where(mask, -1e9, scores)
    # 4) softmax to get attention weights
    attn = softmax(scores, axis=-1) # 在seq_k的维度上做softmax,(batch, seq_q, seq_k)
    # 5) weighted sum with V
    output = np.matmul(attn, V)  # (batch, seq_q, d_v)
    return output, attn
# example
if __name__ == "__main__":
    np.random.seed(0)
    B, Tq, Tk, d = 2, 3, 4, 8 # d是dimension
    Q = np.random.randn(B, Tq, d)
    K = np.random.randn(B, Tk, d)
    V = np.random.randn(B, Tk, d)
    out, weights = scaled_dot_product_attention(Q, K, V)
    print("out shape:", out.shape)         # (2, 3, 8)
    print("attn shape:", weights.shape)    # (2, 3, 4)

多头注意力的pytorch实现:

tensor.transpose():用于交换两个维度

tensor.masked_fill()

tensor.view()

tensor.unsqueeze()

.contiguous()

如果一个tensor在transpose之后还需要进一步合并或者分裂维度时,就需要用contiguous

nn.Dropout()

代码实现

单个头的注意力计算代码,基于Q,K,V

import torch
import torch.nn.functional as F
def scaled_dot_product_attention_torch(Q, K, V, mask=None, dropout=None):
    """
    Q: (batch, heads, seq_q, d_k)
    K: (batch, heads, seq_k, d_k)
    V: (batch, heads, seq_k, d_v)
    mask: (batch, 1, seq_q, seq_k) or (batch, heads, seq_q, seq_k) or None
    returns: output (batch, heads, seq_q, d_v), attn (batch, heads, seq_q, seq_k)
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=Q.dtype, device=Q.device))
    if mask is not None:
        # mask entries should be True for positions to mask
        scores = scores.masked_fill(mask, float("-1e9"))
    attn = F.softmax(scores, dim=-1) # softmax 从 torch.nn.functional.F中得到
    if dropout is not None:
        attn = dropout(attn)
    output = torch.matmul(attn, V)
    return output, attn
# quick test
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, Tq, Tk, d_k = 2, 2, 3, 4, 16
    Q = torch.randn(B, H, Tq, d_k)
    K = torch.randn(B, H, Tk, d_k)
    V = torch.randn(B, H, Tk, d_k)
    # no mask
    out, att = scaled_dot_product_attention_torch(Q, K, V)
    print(out.shape, att.shape)  # (2, 2, 3, 16), (2, 2, 3, 4)

完整代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    '''
    multihead,对不同的维度,有不同的head
    '''
    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
    def _split_heads(self, x):
        # x: (batch, seq, d_model) -> (batch, heads, seq, d_k)
        B, T, _ = x.size()
        x = x.view(B, T, self.num_heads, self.d_k) # d_k 是每一小份的维度
        return x.transpose(1, 2)
    def _combine_heads(self, x):
        # x: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous()
        B, T, _, _ = x.size()
        return x.view(B, T, self.d_model)
    def forward(self, query, key, value, mask=None):
        """
        query/key/value: (batch, seq, d_model)
        mask: (batch, seq_q, seq_k) boolean where True means masked
        returns: outputs (batch, seq_q, d_model), attn_weights (batch, heads, seq_q, seq_k)
        """
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        Q = self._split_heads(Q)
        K = self._split_heads(K)
        V = self._split_heads(V)
        # adapt mask shape: (batch, seq_q, seq_k) -> (batch, 1, seq_q, seq_k)
        if mask is not None:
            mask = mask.unsqueeze(1)
        out, attn = scaled_dot_product_attention_torch(Q, K, V, mask=mask, dropout=self.dropout)
        # out: (batch, heads, seq_q, d_k)
        out = self._combine_heads(out)  # (batch, seq_q, d_model)
        out = self.w_o(out) # 对拼接在一起之后的token再进行一个线性变换
        return out, attn
# example usage
if __name__ == "__main__":
    B, T, S, d_model, heads = 2, 5, 6, 64, 8
    mha = MultiHeadAttention(d_model, heads, dropout=0.1)
    x_q = torch.randn(B, T, d_model)  # query length T
    x_kv = torch.randn(B, S, d_model) # key/value length S
    # optional mask: mask padded positions in key (True = mask)
    mask = torch.zeros(B, T, S, dtype=torch.bool)  # no mask here
    out, attn = mha(x_q, x_kv, x_kv, mask=mask)
    print("out", out.shape)   # (B, T, d_model)
    print("attn", attn.shape) # (B, heads, T, S)

posted on 2026-01-02 08:50  ljbguanli  阅读(0)  评论(0)    收藏  举报