LLaMA 2

0 Introduction

What's new compared with the original Transformer

  • Rotary Position Embedding (RoPE)
  • RMS Norm
  • Grouped Query Attention + KV Cache
  • SwiGLU

[Figure: LLaMA 2 architecture diagram]

1 Model Architecture

1.1 Rotary Position Embedding

Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding

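RoPE looks for position encodings \(f_q, f_k\) whose inner product depends only on the contents and the relative position \(m-n\):
\(\langle f_q(q,m), f_k(k,n)\rangle = g(q,k,m-n)\)
In 2D it achieves this by rotating \(q\) by the angle \(m\theta\) and \(k\) by the angle \(n\theta\):
\( f_q(q,m) = \begin{bmatrix} \cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta \end{bmatrix}q, \quad f_k(k,n) = \begin{bmatrix} \cos n\theta & -\sin n\theta\\ \sin n\theta & \cos n\theta \end{bmatrix}k \)
Writing \(R(\alpha)\) for the rotation by \(\alpha\), and using \(R(\alpha)^T = R(-\alpha)\) and \(R(\alpha)R(\beta) = R(\alpha+\beta)\):
\(\langle f_q(q,m), f_k(k,n)\rangle = q^T R(m\theta)^T R(n\theta)\,k = q^T R((n-m)\theta)\,k\),
which depends on \(m\) and \(n\) only through \(m-n\).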
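A quick numerical check of the 2D case, as a sketch: it builds the rotation matrices above directly and confirms that the score is unchanged when both positions shift by the same amount. The vectors q, k and the angle θ = 0.3 are arbitrary illustrative values, and the helper names `rot` and `rope_score` are hypothetical.

```python
import math
import torch

def rot(angle: float) -> torch.Tensor:
    """2x2 rotation matrix R(angle) = [[cos, -sin], [sin, cos]]."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float64)

def rope_score(q: torch.Tensor, k: torch.Tensor, m: int, n: int, theta: float) -> float:
    """<f_q(q, m), f_k(k, n)> with f_q(q, m) = R(m*theta) q and f_k(k, n) = R(n*theta) k."""
    return float((rot(m * theta) @ q) @ (rot(n * theta) @ k))

torch.manual_seed(0)
q = torch.randn(2, dtype=torch.float64)
k = torch.randn(2, dtype=torch.float64)
theta = 0.3  # illustrative angle, not a value from the paper

# Same offset m - n = 2 at two different absolute positions
s1 = rope_score(q, k, 5, 3, theta)
s2 = rope_score(q, k, 102, 100, theta)
# Closed form q^T R((n - m) theta) k
s3 = float(q @ rot((3 - 5) * theta) @ k)
print(s1, s2, s3)  # all three agree up to floating-point rounding
```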

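Euler's formula
\(e^{ix} = \cos x + i\sin x\)
\(e^{im\theta} = \cos m\theta + i\sin m\theta\)
Treating a 2D vector \((q_1, q_2)\) as the complex number \(q_1 + iq_2\), multiplying it by \(e^{im\theta}\) rotates it by the angle \(m\theta\), so \(f_q(q,m) = q\,e^{im\theta}\): the rotation matrix above and complex multiplication are two views of the same operation.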

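In matrix form, write the input of token \(i\) as the row vector \(x_i = e_i + p_i\) (token embedding plus absolute position embedding; RoPE-only models such as LLaMA have no absolute position embedding, i.e. \(p_i = 0\)), and let \(R(i\theta)\) be the rotation matrix for position \(i\), applied block-diagonally to consecutive pairs of dimensions:
\(Q_iR(i\theta) = x_iW_Q^TR(i\theta) = (e_i+p_i)W_Q^TR(i\theta)\)
\(K_jR(j\theta) = x_jW_K^TR(j\theta) = (e_j+p_j)W_K^TR(j\theta)\)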

\[
\begin{aligned}
f_q(x_i,i)\,f_k(x_j,j)^T & = [Q_iR(i\theta)][K_jR(j\theta)]^T\\
& = [x_iW_Q^TR(i\theta)][x_jW_K^TR(j\theta)]^T\\
& = (e_i+p_i)W_Q^TR(i\theta)R(j\theta)^TW_K(e_j+p_j)^T\\
& = (e_i+p_i)W_Q^TR(i\theta)R(-j\theta)W_K(e_j+p_j)^T\\
& = (e_i+p_i)W_Q^TR[(i-j)\theta]W_K(e_j+p_j)^T\\
& = g(x_i, x_j, i-j)
\end{aligned}
\]
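The derivation can be checked numerically for a d-dimensional vector. This sketch uses the complex-number view from Euler's formula: each consecutive pair of dimensions is treated as one complex number and multiplied by \(e^{i\,\mathrm{pos}\,\theta_j}\), with \(\theta_j = 10000^{-2j/d}\) as in RoFormer. The pairing of adjacent dimensions is an assumption (implementations differ in how they pair dimensions), the helper names `rope_freqs`, `apply_rope` and `score` are hypothetical, and the random \(W_Q, W_K\) only stand in for trained projections.

```python
import torch

def rope_freqs(dim: int, max_pos: int, base: float = 10000.0) -> torch.Tensor:
    """Unit complex numbers e^{i * pos * theta_j}, shape (max_pos, dim // 2),
    with theta_j = base ** (-2j / dim)."""
    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    pos = torch.arange(max_pos, dtype=torch.float64)
    angles = torch.outer(pos, theta)                     # angle = pos * theta_j
    return torch.polar(torch.ones_like(angles), angles)  # magnitude 1, phase = angle

def apply_rope(x: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
    """View each pair (x[2j], x[2j+1]) as x[2j] + i*x[2j+1] and multiply by rot[j],
    i.e. rotate the pair by the phase of rot[j] (Euler's formula)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2).contiguous()
    rotated = torch.view_as_complex(pairs) * rot
    return torch.view_as_real(rotated).reshape(x.shape)

# Hypothetical sizes and random weights, only to check the derivation
torch.manual_seed(0)
d = 8
freqs = rope_freqs(d, max_pos=32)
W_Q = torch.randn(d, d, dtype=torch.float64)
W_K = torch.randn(d, d, dtype=torch.float64)
a = torch.randn(d, dtype=torch.float64)  # input x_i of the query token
b = torch.randn(d, dtype=torch.float64)  # input x_j of the key token

def score(i: int, j: int) -> float:
    """Attention logit between token a at position i and token b at position j."""
    q = apply_rope(a @ W_Q.T, freqs[i])
    k = apply_rope(b @ W_K.T, freqs[j])
    return float(q @ k)

print(score(5, 2), score(20, 17))  # same offset i - j = 3, so the scores agree up to rounding
```

Because the rotation is applied after the projections and only the angle difference survives the dot product, no position information has to be added to the embeddings themselves.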

1.2 RMS Norm
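RMS Norm (Zhang and Sennrich, 2019) rescales each vector by its root mean square instead of subtracting the mean and dividing by the standard deviation as LayerNorm does: \(\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g\), with a learnable gain \(g\) and no bias. LLaMA applies it before each attention and feed-forward sublayer (pre-normalization). A minimal PyTorch sketch; the default eps = 1e-6 is an assumption, and the exact value differs between model releases.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * weight over the last dimension.
    Unlike LayerNorm there is no mean subtraction and no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):  # eps value is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g, initialised to 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / RMS of each vector, kept as (..., 1) so it broadcasts over the last dim
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight

torch.manual_seed(0)
norm = RMSNorm(8)
x = torch.randn(2, 5, 8)  # (batch, seq, hidden)
y = norm(x)               # each y[b, t] has RMS close to 1
```

Dropping the mean makes it slightly cheaper than LayerNorm while keeping the re-scaling invariance.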

1.3 Grouped Query Attention + KV Cache

<1> Grouped Query Attention
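GQA is a trade-off between efficiency and accuracy. The query heads are divided into groups, and all query heads in a group share one key/value head; multi-head attention (MHA, one KV head per query head) and multi-query attention (MQA, a single KV head for all query heads) are the two extremes.

  • Efficiency (inference speed, KV-cache memory): MHA < GQA < MQA
  • Accuracy: MHA > GQA > MQA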

Figures from GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
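A minimal sketch of causal GQA in PyTorch, assuming q already has `n_heads` heads, k and v have `n_kv_heads` heads, and queries and keys have the same length (no KV cache). The function name is hypothetical. Each KV head is repeated for the query heads of its group, so `n_kv_heads == n_heads` gives MHA and `n_kv_heads == 1` gives MQA.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Causal GQA.
    q:    (batch, n_heads,    seq, head_dim)
    k, v: (batch, n_kv_heads, seq, head_dim), with n_heads divisible by n_kv_heads.
    """
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    group_size = n_heads // n_kv_heads
    # Query head h reads KV head h // group_size: repeat each KV head group_size times
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    seq_len = q.shape[-2]
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

# Example: 8 query heads sharing 2 KV heads (4 query heads per group)
torch.manual_seed(0)
q = torch.randn(1, 8, 5, 16)
k = torch.randn(1, 2, 5, 16)
v = torch.randn(1, 2, 5, 16)
out = grouped_query_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 5, 16])
```

Only the `n_kv_heads` key/value heads need to be stored during inference, so the KV cache shrinks by a factor of `n_heads / n_kv_heads` compared with MHA; the quality loss comes from query heads in a group no longer having their own keys and values.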

<2> KV Cache
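During autoregressive decoding, the keys and values of earlier tokens do not change, so they can be stored and reused instead of being recomputed at every step; each step then only computes q, k, v for the newest token. A minimal sketch under that assumption, for a single attention layer, with random tensors standing in for the projected q, k, v. The class and function names are hypothetical.

```python
import torch
import torch.nn.functional as F

class KVCache:
    """Hypothetical minimal per-layer cache of past keys/values,
    each of shape (batch, heads, cached_len, head_dim)."""

    def __init__(self):
        self.k = None
        self.v = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """Append the new token's k, v along the sequence axis and return the full cache."""
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def decode_step(q_new, k_new, v_new, cache: KVCache) -> torch.Tensor:
    """Attention output for the newest token; q_new, k_new, v_new: (batch, heads, 1, head_dim).
    No causal mask is needed: the newest token may attend to every cached position."""
    k, v = cache.update(k_new, v_new)
    scores = q_new @ k.transpose(-2, -1) / q_new.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

# Random stand-ins for the projected q, k, v of 6 tokens (batch 1, 2 heads, head_dim 4)
torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 4) for _ in range(3))
cache = KVCache()
outs = [decode_step(q[:, :, t:t+1], k[:, :, t:t+1], v[:, :, t:t+1], cache) for t in range(6)]
incremental = torch.cat(outs, dim=2)  # should match full causal attention over all 6 tokens
```

The cache grows by one position per generated token, so its memory is proportional to sequence length × number of KV heads, which is why it pairs naturally with GQA.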

1.4 SwiGLU

SwiGLU combines the Swish activation (with β = 1, Swish is the same as SiLU) with a Gated Linear Unit (GLU). It is used in the feed-forward network of LLaMA 2, Mistral 7B and Mixtral 8×7B.
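In formula form, with \(\sigma\) the sigmoid function and \(W_{\text{gate}}, W_{\text{up}}, W_{\text{down}}\) named after the layers in the code that follows:

\[
\mathrm{Swish}_\beta(x) = x\,\sigma(\beta x), \qquad \mathrm{SiLU}(x) = \mathrm{Swish}_1(x) = x\,\sigma(x)
\]
\[
\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = \big(\mathrm{SiLU}(xW_{\text{gate}}) \odot xW_{\text{up}}\big)\,W_{\text{down}}
\]

The SiLU branch acts as a soft, input-dependent gate on the up projection before the result is projected back to the hidden size.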

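import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(SiLU(gate_proj(x)) * up_proj(x))."""
    def __init__(self, config):
        super().__init__()  # required before registering submodules
        # LLaMA's feed-forward projections have no bias terms
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # F.silu is element-wise (no dim argument); it gates the up branch element-wise
        hidden_states = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
        return hidden_states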


posted @ 2024-01-14 09:42  ForHHeart