• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
pytorch(10.3) 多头注意

 

 

 

10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation (d2l.ai)

 

 

Multi-Head Attention | 算法 + 代码_哔哩哔哩_bilibili

 

 

 

 

 代码实现

 

x[1,4,2] 1几个样本(句子) 4 预测步长(4个单词) 2每个单词的编码后特征长度

 

 

 

from math import sqrt
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):

    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()

        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads

        # 维度必须能被num_head 整除
        assert d_model % num_heads == 0, "d_model must be multiple of num_heads"

        # 定义线性变换矩阵
        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)
        self.scale = 1 / sqrt(d_model // num_heads)

        # 最后的线性层
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.d_model // nh  # dim_k of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)

        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # batch, nh, n, n
        dist = torch.softmax(dist, dim=-1)  # batch, nh, n, n

        att = torch.matmul(dist, v)  # batch, nh, n, dv
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)  # batch, n, dim_v

        # 最后通过一个线性层进行变换
        output = self.fc(att)

        return output


x = torch.rand((1, 4, 2))
multi_head_att = MultiHeadSelfAttention(x.shape[2], 6, 3)  # (6, 3)
output = multi_head_att(x)

  

 

 101. 101 - 101 Multi-head的作用_哔哩哔哩_bilibili

 

 

 

 

 

最后一层处理下 压缩下维度

 

 

 

 

 

 

posted on 2023-10-23 16:02  MKT-porter  阅读(42)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3