Attention code (PyTorch)

 

import torch
import torch.nn as nn
from math import sqrt

class SelfAttention(nn.Module):
    def __init__(self, input_dim, dim_q_k, dim_v):
        super().__init__()
        # linear projections for queries, keys and values
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)

        # scaling factor sqrt(d_k) from scaled dot-product attention
        self.norm = sqrt(dim_q_k)
    
    def forward(self, x):
        # (b, s, input_dim) -> (b, s, dim_q_k)
        q = self.q(x)
        k = self.k(x)
        # (b, s, input_dim) -> (b, s, dim_v)
        v = self.v(x)

        # scaled dot-product scores: (b, s, s)
        d = torch.bmm(q, k.transpose(1, 2)) / self.norm
        softmax_out = torch.softmax(d, dim=-1)

        # (b, s, dim_v)
        attn_out = torch.bmm(softmax_out, v)
        return attn_out
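
A minimal usage sketch of the layer above; the batch size, sequence length, and dimensions here are illustrative assumptions:

x = torch.randn(2, 5, 16)                                    # (b=2, s=5, input_dim=16)
self_attn = SelfAttention(input_dim=16, dim_q_k=8, dim_v=8)
print(self_attn(x).shape)                                    # torch.Size([2, 5, 8])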




This is effectively target attention (a form of cross-attention): the queries are computed from one input, while the keys and values come from the other input.

import torch
import torch.nn as nn
from math import sqrt

class CrossAttention(nn.Module):
    def __init__(self, input_dim, dim_q_k, dim_v):
        super().__init__()
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)

        self.norm = sqrt(dim_q_k)
        
    def forward(self, encoder_input, decoder_input):
        # queries come from the first input: (b, s_q, input_dim) -> (b, s_q, dim_q_k)
        q = self.q(encoder_input)
        # keys and values come from the second input:
        # (b, s_kv, input_dim) -> (b, s_kv, dim_q_k) and (b, s_kv, dim_v)
        k = self.k(decoder_input)
        v = self.v(decoder_input)

        # scaled dot-product scores: (b, s_q, s_kv)
        a = torch.bmm(q, k.transpose(1, 2)) / self.norm
        softmax_out = torch.softmax(a, dim=-1)

        # (b, s_q, dim_v)
        cross_attn_out = torch.bmm(softmax_out, v)
        return cross_attn_out
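
A minimal usage sketch of the layer above in the target-attention reading (the first argument supplies the query, the second the keys/values); all shapes are illustrative assumptions:

target = torch.randn(2, 1, 16)       # one target item per sample: (b=2, s_q=1, input_dim=16)
behaviors = torch.randn(2, 10, 16)   # behavior sequence: (b=2, s_kv=10, input_dim=16)
cross_attn = CrossAttention(input_dim=16, dim_q_k=8, dim_v=8)
print(cross_attn(target, behaviors).shape)   # torch.Size([2, 1, 8])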





class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, head, dim_q_k, dim_v):
        super().__init__()
        self.head = head
        # dim_q_k and dim_v are the total sizes across all heads
        self.dim_per_head_q_k = dim_q_k // head
        self.dim_per_head_v = dim_v // head

        # project to the full dim_q_k / dim_v; the heads are split out in forward()
        self.q = nn.Linear(input_dim, dim_q_k)
        self.k = nn.Linear(input_dim, dim_q_k)
        self.v = nn.Linear(input_dim, dim_v)

        self.norm = sqrt(self.dim_per_head_q_k)

        self.out = nn.Linear(dim_v, input_dim)
    
    def forward(self, x):
        # x: (b, s, input_dim)
        b = x.size(0)
        s = x.size(1)

        # (b, s, input_dim) -> (b, head, s, dim_per_head_q_k)
        q = self.q(x).view(b, -1, self.head, self.dim_per_head_q_k).transpose(1, 2)
        k = self.k(x).view(b, -1, self.head, self.dim_per_head_q_k).transpose(1, 2)
        # (b, s, input_dim) -> (b, head, s, dim_per_head_v)
        v = self.v(x).view(b, -1, self.head, self.dim_per_head_v).transpose(1, 2)

        # scaled dot-product scores per head: (b, head, s, s)
        d = torch.matmul(q, k.transpose(2, 3)) / self.norm
        softmax_out = torch.softmax(d, dim=-1)

        # (b, head, s, dim_per_head_v)
        attn_out = torch.matmul(softmax_out, v)
        
        # (b, head, s, dim_per_head_v) -> (b, s, head, dim_per_head_v) -> (b, s, dim_v)
        attn_out = attn_out.transpose(1, 2).contiguous().view(b, s, -1)

        # output projection: (b, s, dim_v) -> (b, s, input_dim)
        out = self.out(attn_out)
        return out
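
A minimal usage sketch of the multi-head layer; the head count and dimensions are illustrative assumptions (dim_q_k and dim_v must be divisible by head):

x = torch.randn(2, 5, 16)                                                  # (b=2, s=5, input_dim=16)
mhsa = MultiHeadSelfAttention(input_dim=16, head=4, dim_q_k=32, dim_v=32)
print(mhsa(x).shape)                                                       # torch.Size([2, 5, 16])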

  
