Learning Multi-Head Attention (MHA)

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads  # dimension of each attention head

        # Linear transformations for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        # Output linear layer
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape  # B: batch size, T: sequence length, D: feature dimension (should equal d_model)

        # Linear projections
        Q = self.q_linear(x)  # (B, T, d_model)
        K = self.k_linear(x)
        V = self.v_linear(x)

        # Split Q, K, V into multiple heads (heads are computed in parallel)
        Q = Q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)  # (B, num_heads, T, d_head)
        K = K.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention
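        # Each head computes softmax(Q K^T / sqrt(d_head)) V; scaling by sqrt(d_head)
        # keeps the dot products from growing with the head dimension, which would
        # otherwise push the softmax into regions with very small gradients.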
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)  # (B, num_heads, T, T)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (B, num_heads, T, d_head)

        # Concatenate all heads back together
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)  # (B, T, d_model)

        # Final linear output layer
        out = self.out_proj(context)
        return out
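
A minimal usage sketch for a quick shape sanity check; the values d_model=512, num_heads=8, batch size 2, and sequence length 10 below are just example choices:

# Sanity check: the output shape should match the input shape (B, T, d_model).
if __name__ == "__main__":
    torch.manual_seed(0)
    mha = MultiHeadAttention(d_model=512, num_heads=8)
    x = torch.randn(2, 10, 512)  # (B, T, d_model)
    out = mha(x)
    print(out.shape)  # torch.Size([2, 10, 512])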