旋转位置编码

参考自RoPE旋转位置编码深度解析:理论推导、代码实现、长度外推 - 知乎 (zhihu.com)

位置编码: 1.绝对, 直接加到输入中. 2.相对,加在Attn的内积之前, 外推性能强。

 

ROPE:对Attn的K和V矩阵做ROPE

二维场景: 

对于一个二维向量 :

 

偶数维的可以用拆成若干个2维的向量, 对这些向量分别用ROPE再拼接回原始维度(内积满足线性叠加性)

 theta满足远程衰减性质,着相对距离的增加,向量内积结果在减小

import torch

def apply_rope(input_tensor):
    bs, seq_len, channels = input_tensor.shape  # [bs, 77, 4]
    
    # 假设通道为偶数维度,这里c=4
    assert channels % 2 == 0

    # 对通道进行划分 (c=4 -> 两个二维向量)
    c1 = input_tensor[:, :, 0::2]  # shape [bs, 77, 2] -> (c1, c3)
    c2 = input_tensor[:, :, 1::2]  # shape [bs, 77, 2] -> (c2, c4)

    # 生成旋转角度 \theta_p
    positions = torch.arange(seq_len).unsqueeze(1)  # [77, 1]
    theta = positions / (10000 ** (torch.arange(2) / channels))  # 根据RoPE公式计算theta

    cos_theta = torch.cos(theta).unsqueeze(0)  # shape [1, 77, 2]
    sin_theta = torch.sin(theta).unsqueeze(0)  # shape [1, 77, 2]

    # 对每对通道应用旋转矩阵
    c1_prime = c1 * cos_theta - c2 * sin_theta  # 旋转后的第一维
    c2_prime = c1 * sin_theta + c2 * cos_theta  # 旋转后的第二维

    # 拼接回原来的通道维度
    output_tensor = torch.stack([c1_prime, c2_prime], dim=-1).reshape(bs, seq_len, channels)
    
    return output_tensor

# 假设输入是 [bs, 77, 4] 的张量
input_tensor = torch.randn(bs, 77, 4)
output_tensor = apply_rope(input_tensor)

 

posted @ 2024-09-25 11:41  老八蜜汁小憨包  阅读(96)  评论(0)    收藏  举报