基于 PyTorch 动手实现 LLM

本文是 Stanford CS336 课程学习记录的一部分。

撰写本文时,模型尚未经过训练,尽管模块可以通过测试,但可能仍然存在一些潜在问题。本文仅供参考。

概述

在 CS336 课程的作业 1 中,要求我们基于 PyTorch 框架自己实现一个 Transformer 语言模型并训练它。本文记录模型架构的实现。对着作业指导加上问大模型连滚带爬算是写了个大概,能通过所有测试了。

我们需要实现的是一个 Pre-Norm 的 Transformer 语言模型,带有现代 LLM 的常见优化:LayerNorm 前置、SwiGLU 激活、RMSNorm、无 bias 的线性层、旋转位置嵌入(RoPE)等。架构图如下:

img

具体来说,此模型的输入是经过分词器(Tokenizer,通常是 BPE)分词的 token ID 序列,输出是针对输入序列中每个 token,预测其下一个 token 的概率。

在我们的实际实现中,我们不会进行最后的 softmax 操作,而是直接返回 softmax 之前的张量,我们通常把这个张量称为 logits。

所有模块的实现

Linear 线性层

作业指导不允许我们使用 torch.nn.Linear,因此我们需要自己实现它。我们实现的是无 bias 的线性层,所以其实就是一个简单的矩阵乘法,没什么可说的,但是要注意一下参数初始化。

作业指导的 3.4.1 节指出,对于线性层的权重,我们使用以下的正态分布进行初始化:

\[\mathcal{N}\left(\mu = 0, \sigma^2 = \frac{2}{d_{\text{in}}+d_{\text{out}}}\right),\text{截断到}[-3\sigma,3\sigma] \]

这可以用以下代码实现:

sigma = math.sqrt(2 / (in_features + out_features))
w = torch.normal(mean=0, std=sigma, size=(out_features, in_features), device=device, dtype=dtype)
w = torch.nn.init.trunc_normal_(w, a=-3*sigma, b=3*sigma)

注意这里将权重矩阵初始化为 (out_features, in_features),因为作业指导中要求我们将矩阵存储为 \(W\) 而非 \(W^T\)。完成权重初始化后,只需只用 nn.Parameter 封装即可。

在 forward 函数中,我们使用 einx 的 dot 函数来完成矩阵乘法。在后面我们还会使用几次 einx,einx 可以视为 einops 和 einsum 的统一接口(可参考我之前写的一篇 einops 学习笔记),是终极张量操作库。

einx.dot("d_out d_in, ... d_in -> ... d_out", self.weight, x)

我们的权重的形状是 d_out d_in,输入张量 x 的形状是 ... d_in... 在 einx 语法中表示 0 个或多个维度。最终输出的张量形状是 ... d_out

Linear 的完整代码如下:

class Linear(nn.Module):
    def __init__(
        self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None
    ):
        """
        Args:
            in_features (int): final dimension of the input
            out_features (int): final dimension of the output
            device (torch.device | None = None): Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        sigma = math.sqrt(2 / (in_features + out_features))
        w = torch.normal(mean=0, std=sigma, size=(out_features, in_features), device=device, dtype=dtype)
        w = torch.nn.init.trunc_normal_(w, a=-3*sigma, b=3*sigma)
        self.weight = nn.Parameter(w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return einx.dot("d_out d_in, ... d_in -> ... d_out", self.weight, x)

Embedding 嵌入层

Embedding 其实是一个查表操作,这个表是可训练的,输入长度是 vocab_size,即模型的词表大小,输出长度一般是 d_model

我们同样注意一下初始化。作业指导指出,Embedding 权重使用以下的正态分布进行初始化:

\[\mathcal{N}\left(\mu = 0, \sigma^2 = 1\right),\text{截断到}[-3,3] \]

完整代码如下:

class Embedding(nn.Module):
    def __init__(
        self, num_embeddings, embedding_dim, device: torch.device | None = None, dtype: torch.dtype | None = None
    ):
        """
        Args:
            num_embeddings (int): Size of the vocabulary
            embedding_dim (int): Dimension of the embedding vectors, i.e., d_model
            device (torch.device | None = None): Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        e = torch.normal(mean=0, std=1, size=(num_embeddings, embedding_dim), device=device, dtype=dtype)
        e = torch.nn.init.trunc_normal_(e, a=-3, b=3)
        self.weight = nn.Parameter(e)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        return self.weight[token_ids]

RMSNorm

对于向量 \(a\in \mathbb{R}^{d_\text{model}}\) 中的每个元素,RMSNorm 的公式如下:

\[\text{RMSNorm}\left(a_i \right)=\frac{a_i}{\text{RMS}(a)}g_i, \]

其中

\[\text{RMS}(a)=\sqrt{\frac{1}{d_{\text{model}}}\sum_{i=1}^{d_\text{model}}a_i^2+\epsilon}, \]

\(g\) 是一个可学习的参数向量(共有 \(d_\text{model}\) 个元素),\(\epsilon\) 是一个为了避免除以 0 的很小的常量,一般取 1e-5。

需要注意的是,在对输入取根号时,为了避免数值溢出,需要先扩大到 torch.float32 类型(考虑到实际训练中经常使用 FP16 或 BF16)。

RMSNorm 包含可学习的参数 \(g\),因此它应当被实现为 nn.Module 而不是一个无状态的函数。作业指导指出我们应当将 \(g\) 初始化为全 1 的张量。

完整代码如下:

class RMSNorm(nn.Module):
    def __init__(
        self, d_model: int, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None
    ):
        """
        Args:
            d_model (int): Hidden dimension of the model
            eps (float = 1e-5) Epsilon value for numerical stability
            device (torch.device | None = None) Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        self.epsilon = eps
        # 此处 d_model 后的逗号是必须的
        g = torch.ones(size=(d_model,), device=device, dtype=dtype)
        self.weight = nn.Parameter(g)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x 的形状是:(batch_size, seq_len, d_model)
        # upcast input to torch.float32 to prevent overflow when square the input
        in_dtype = x.dtype
        x = x.to(torch.float32)

        # RMSNorm implementation
        # x 中所有元素的平方和
        square_sum = einx.reduce("... [d_model]", x * x, op=torch.sum)
        # x.shape[-1] 即为 d_model
        rms = torch.sqrt(1 / x.shape[-1] * square_sum + self.epsilon)
        rms = einx.rearrange("... -> ... 1", rms)
        result = x / rms * self.weight

        # Return the result in the original dtype
        return result.to(in_dtype)

带有 SwiGLU 激活的 FeedForward 网络

SwiGLU 是 Swish 激活函数(也叫 SiLU)和线性门控单元 GLU 的结合。SiLU(Swish)激活函数公式如下:

\[\text{SiLU}(x)=x\cdot \sigma(x)=\frac{x}{1+e^{-x}}, \]

其中 \(\sigma(x)\) 是 Sigmoid 函数。

下图展示了 SiLU(Swish)函数的图象。它与 ReLU 类似,但在 0 处更加平滑。

img

SiLU 可以使用以下代码实现。作业指导中指出,为了数值稳定性,我们应当直接使用 torch.sigmoid 来实现 SiLU。

def silu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

GLU 公式如下:

\[\text{GLU}(x,W_1,W_2)=\sigma(W_1 x)\odot W_2 x, \]

其中 \(\sigma\) 是 Sigmoid 函数,\(\odot\) 表示逐元素乘法。\(W_1\)\(W_2\) 是两个权重矩阵。

将二者结合起来,就得到了我们需要实现的前馈网络:

\[\text{FFN}(x)=\text{SwiGLU}(x,W_1,W_2,W_3)=W_2(\text{SiLU}(W_1 x)\odot W_3 x), \]

其中 \(x\in\mathbb{R}^{d_\text{model}},W_1,W_3\in\mathbb{R}^{d_\text{ff}\times d_\text{model}},W_2\in\mathbb{R}^{d_\text{model}\times d_\text{ff}}\)。一般我们令 \(d_\text{ff}=\frac{8}{3}d_\text{model}\),在实际实现中,为了充分利用硬件,通常取 \(d_\text{ff}\) 为最接近 \(\frac{8}{3}d_\text{model}\) 的 64 的整数倍。这可以通过以下代码实现:

def _get_d_ff(self, d_model: int) -> int:
    # d_ff 应当大致等于 f_model 的 8/3。因此,我们首先将 d_model 乘以 8/3,然后将其四舍五入到最近的 64 的整数倍,从而充分利用硬件。
    return int(round(d_model * (8 / 3) / 64) * 64)

我们的前馈网络应当包含 3 个线性模块,分别对应 \(W_1,W_2\)\(W_3\)。完整代码如下:

class FeedForward(nn.Module):
    """
    Position-wise feed forward network with SwiGLU activation
    """

    def _get_d_ff(self, d_model: int) -> int:
        # d_ff 应当大致等于 f_model 的 8/3。因此,我们首先将 d_model 乘以 8/3,然后将其四舍五入到最近的 64 的整数倍,从而充分利用硬件。
        return int(round(d_model * (8 / 3) / 64) * 64)

    def __init__(
        self, d_model: int, d_ff: int = None, device: torch.device | None = None, dtype: torch.dtype | None = None
    ):
        """
        Args:
            d_model (int): Hidden dimension of the model
            device (torch.device | None = None) Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = self._get_d_ff(d_model) if d_ff is None else d_ff
        self.w1 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
        self.w2 = Linear(self.d_ff, self.d_model, device=device, dtype=dtype)
        self.w3 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 此处不再使用 einx 了,因为需要访问 Linear 模块内部的权重,这样不太好。
        w1x = self.w1(x)
        w1x_silu = silu(w1x)
        w3x = self.w3(x)
        return self.w2(w1x_silu * w3x)

RoPE 旋转位置嵌入

这部分没搞太明白,代码算是 AI 写的,大概记录一下。

RoPE 的核心思想是,对于一个向量中的每一对维度(比如 \(i_{th}\) 位置的 \(q\) 向量的第 2k 和 2k+1 个维度),我们都对其应用一个二维旋转。这个旋转的角度 \(\theta_{i,k}\) 取决于:

  1. Token 在序列中的位置 \(i\)
  2. 维度对的索引 \(k\)

公式是 \(\theta_{i,k} = i / \Theta^{2k/d_k}\),其中 \(\Theta\) 是一个大常数(通常是10000),\(d_k\) 是向量总维度。

所以,对于一个输入查询向量 \(q\),其在位置 \(i\) 且维度为 \((2k, 2k+1)\) 的分量 \((q_{2k}, q_{2k+1})\) 会被旋转矩阵 [[cos( \(\theta_{i,k}\)), -sin( \(\theta_{i,k}\))], [sin( \(\theta_{i,k}\)), cos( \(\theta_{i,k}\))]] 作用。

RoPE 不包含任何可学习的参数,这意味着它是可以复用的。为了优化性能,我们在 RoPE 模块初始化时就预先计算出所有需要的 sin 和 cos 值,这样之后可以重复使用。

完整代码如下:

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        """
        Args:
            theta (float): theta value for the RoPE
            d_k (int): dimension of query and key vectors
            max_seq_len (int): Maximum sequence length that will be inputted
            device (torch.device | None = None): Device to store the buffer on
        """
        super().__init__()
        # RoPE 本身不包含任何可学习的参数,__init__ 主要用来计算所有可能的位置 i 和所有可能的维度对 k 的 cos 和 sin 值,并存储起来,从而减少重复计算。
        self.theta = theta
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, d_k, 2).float() / d_k))  # shape: (d_k // 2,)
        position_indices = torch.arange(max_seq_len, dtype=torch.float)  # shape: (max_seq_len,)
        self.position_frequencies = einx.dot(
            "inv_freq_len, max_seq_len -> max_seq_len inv_freq_len", inv_freq, position_indices
        )
        cos_cached = torch.cos(self.position_frequencies)
        sin_cached = torch.sin(self.position_frequencies)
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (..., seq_len, d_k)
            token_positions (..., seq_len)
        """
        current_cos = self.cos_cached[token_positions]  # shape: (..., seq_len, d_k // 2)
        current_sin = self.sin_cached[token_positions]  # 同上

        x_even = x[..., ::2]  # shape: (..., seq_len, d_k // 2)
        x_odd = x[..., 1::2]  # 同上

        rotated_x_even = x_even * current_cos - x_odd * current_sin
        rotated_x_odd = x_even * current_sin + x_odd * current_cos

        return torch.stack([rotated_x_even, rotated_x_odd], dim=-1).reshape(x.shape)

softmax

softmax 的公式如下:

\[\text{softmax}(v)_i=\frac{\text{exp}(v_i)}{\sum_{j=1}^{n}\text{exp}(v_j)}. \]

注意,当 \(v_i\) 较大时,\(\text{exp}(v_i)\) 可能变成无穷大,从而分母也是无穷大,两个无穷大相除得到 NaN。由于对向量整体加上一个常量 \(c\) 不会影响 softmax 的结果,我们对向量整体减去向量中最大的元素,从而增加数值稳定性。

完整代码如下:

def softmax(x: torch.Tensor, dim: int):
    """
    对输入张量 x 的第 i 个维度进行 softmax,通过减去目标维度中最大的值来增加数值稳定性。
    Args:

    """
    # keepdim=True 表示保持其他维度的值不变
    max_val = torch.max(x, dim=dim, keepdim=True).values
    x_stable = x - max_val
    exp_x = torch.exp(x_stable)
    exp_sum = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / exp_sum

缩放点积注意力

于是我们来到 self attention 的经典公式:

\[\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{Q^T K}{\sqrt{d_k}} \right)V \]

此处与经典公式有一点不同:原公式中的 \(QK^T\) 在这里是 \(Q^T K\)。这是因为作业指导要求我们使用列优先顺序存储矩阵。如果我们使用 einx 或者 einsum,那么这通常不需要太多注意。

我们还需要支持一个额外的功能:Masking。mask 是一个 bool 类型的矩阵,对于值为 False 的位置,我们需要将对应值设为 -torch.inf

完整代码如下:

def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None):
    """
        Args:
            q (batch_size, ..., seq_len, d_k)
            k (batch_size, ..., seq_len, d_k)
            v (batch_size, ..., seq_len, d_v)
            mask (seq_len, seq_len)
        Returns a tensor with shape (batch_size, ..., d_v)
        The attention probabilities of positions with a mask value of True should collectively sum
    to 1, and the attention probabilities of positions with a mask value of False should be zero.
    """
    qk = einx.dot("... seq_len_q d_k, ... seq_len_k d_k -> ... seq_len_q seq_len_k", q, k)
    d_k = q.shape[-1]
    pre_softmax = qk / math.sqrt(d_k)
    if mask is not None:
        pre_softmax.masked_fill_(mask=~mask, value=-torch.inf)
    post_softmax = softmax(pre_softmax, dim=-1)
    return einx.dot("... seq_len_q seq_len_k, ... seq_len_k d_v -> ... seq_len_q d_v", post_softmax, v)

MultiHeadAttention 多头注意力

上面的缩放点积注意力描述了一个注意力头的计算。在多头注意力中,我们将所有注意力头的 Attention 拼接到一起,然后通过一个统一的 \(W_O\) 矩阵映射到最终的输出。

直观上每个头都拥有各自的 Q、K、V 映射矩阵。然而,为了减少计算量,我们使用统一的 Q、K、V 矩阵,然后在进行自注意力计算之前将它们拆分,最后再拼接起来。

我们还需要构建一个因果掩码矩阵,这可以通过 torch.tril 来实现。

完整代码如下:

class MultiHeadSelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        rope_max_seq_len: int = None,
        rope_theta: int = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Args:
            d_model (int) Dimensionality of the Transformer block inputs.
            num_heads (int) Number of heads to use in multi-head self-attention.
            rope_max_seq_len (int = None) max_seq_len for RoPE. If None, will not use RoPE
            rope_theta (int = None) theta for RoPE. If None, will not use RoPE
            device (torch.device | None = None) Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # 使用整数除法来确保整除
        d_k = d_model // num_heads

        self.num_heads = num_heads
        self.d_model = d_model

        # 线性投影层
        self.q_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, device=device, dtype=dtype)

        # 最终的输出层
        self.output_proj = Linear(d_model, d_model, device=device, dtype=dtype)

        if rope_max_seq_len is not None or rope_theta is not None:
            assert rope_max_seq_len is not None and rope_theta is not None, (
                "rope_max_seq_len and rope_theta must be not None to enable RoPE"
            )
            self.rope = RotaryPositionalEmbedding(
                theta=rope_theta, d_k=d_k, max_seq_len=rope_max_seq_len, device=device
            )
        else:
            self.rope = None

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor = None):
        """
        Args:
            x should be (batch_size, seq_len, d_model) in shape
            token_positions should be (..., seq_len) in shape
        """
        _, seq_len, _ = x.shape

        q_proj = self.q_proj(x)  # (batch_size, seq_len, d_model)
        k_proj = self.k_proj(x)  # 同上
        v_proj = self.v_proj(x)  # 同上

        q = einx.rearrange(
            "batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim",
            q_proj,
            num_heads=self.num_heads,
        )
        k = einx.rearrange(
            "batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim",
            k_proj,
            num_heads=self.num_heads,
        )
        v = einx.rearrange(
            "batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim",
            v_proj,
            num_heads=self.num_heads,
        )

        # apply rope
        if self.rope:
            assert token_positions is not None, "token_positions is necessary when RoPE is enabled"
            q = self.rope(q, token_positions)  # (batch_size, num_heads, seq_len, head_dim)
            k = self.rope(k, token_positions)

        # scaled_dot_product_attention 对 mask 的要求:True 表示允许注意(不填充),False 表示需要掩盖(填充 -inf)。
        # 对因果掩码而言,这意味着下三角(包括对角线)应该是 True,上三角应该是 False。
        causal_mask = torch.ones((seq_len, seq_len), dtype=torch.bool).tril(diagonal=0)
        attn_output_per_head = scaled_dot_product_attention(
            q, k, v, mask=causal_mask
        )  # (batch_size, num_heads, seq_len, head_dim)

        attn_output_combined = einx.rearrange(
            "batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)",
            attn_output_per_head,
            num_heads=self.num_heads,
        )

        return self.output_proj(attn_output_combined)

Transformer Block

Transformer Block 是 Transformer 语言模型的基本组成单位。Transformer 语言模型包含若干个层,每层是一个 Transformer Block。

Transformer Block 包含两个子层:MultiHeadAttention 子层和 FFN 子层。具体可以参考本文开头的图片。

MultiHeadAttention 子层计算可描述为以下公式:

\[y=x+\text{MultiHeadAttention}(\text{RMSNorm}(x)) \]

FFN 子层计算可描述为以下公式:

\[z=y+\text{FFN}(\text{RMSNorm}(y)) \]

注意这里的两个 RMSNorm 是不共享的,需要独立初始化。

前面的 RoPE 模块和所有用到 RoPE 的模块都需要 token_positions,经过查看作业提供的框架代码,这些模块在测试时都提供了现成的 token_positions,而 Transformer Block 在测试时却没有提供 token_positions。因此,我们需要手动构造它。

完整代码如下:

class TransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        rope_max_seq_len: int,
        rope_theta: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Args:
            d_model (int) Dimensionality of the Transformer block inputs
            num_heads (int) Number of heads to use in multi-head self-attention
            d_ff (int) Dimensionality of the position-wise feed-forward inner layer
            rope_max_seq_len (int) max_seq_len for RoPE
            rope_theta (int) theta for RoPE
            device (torch.device | None = None) Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        self.attn = MultiHeadSelfAttention(
            d_model=d_model,
            num_heads=num_heads,
            rope_max_seq_len=rope_max_seq_len,
            rope_theta=rope_theta,
            device=device,
            dtype=dtype,
        )
        # layer norm for sub-layer 1
        self.ln1 = RMSNorm(d_model=d_model, device=device, dtype=dtype)
        # layer norm for sub-layer 2
        self.ln2 = RMSNorm(d_model=d_model, device=device, dtype=dtype)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor):
        # 构造 token_positions:[0, 1, 2, ..., seq_len-1]
        _, seq_len, _ = x.shape  # 输入形状是 (batch_size, seq_len, d_model)
        token_positions = torch.arange(seq_len, device=x.device, dtype=torch.long)

        # 执行包含两个子层的前向传播过程
        sub_layer_1 = x + self.attn(self.ln1(x), token_positions)
        sub_layer_2 = sub_layer_1 + self.ffn(self.ln2(sub_layer_1))
        return sub_layer_2

完整的 Transformer 语言模型

终于到最后一步了。我们将会把之前实现的所有模块拼装起来,形成一个完整的 Transformer 语言模型。

通过本文开头的图片可以看到 Transformer 语言模型的计算步骤:

  1. 进行 Token Embedding,将分词器产生的离散 Token 序列转换成连续的嵌入向量。
  2. 嵌入向量依次通过所有 Transformer Block 形成的层
  3. 进行一次 RMSNorm
  4. 进行一次输出映射,算是 Token Embedding 的逆操作,将连续的向量映射到 logits。

我们的 Transformer 语言模型得到 logits 后就直接返回了。如果对 logits 再进行一次 softmax 操作,就能得到模型对下一个 token 在词表中的位置预测概率。

完整代码如下:

class TransformerLM(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        vocab_size: int,
        context_length: int,
        num_layers: int,
        rope_theta: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Args:
            d_model (int) Dimensionality of the Transformer block inputs
            num_heads (int) Number of heads to use in multi-head self-attention
            d_ff (int) Dimensionality of the position-wise feed-forward inner layer
            vocab_size (int) The size of the vocabulary
            context_length (int) The maximum context length
            num_layers (int) The number of Transformer blocks to use
            rope_theta (int) theta for RoPE
            device (torch.device | None = None) Device to store the parameters on
            dtype (torch.dtype | None = None): Data type of the parameters
        """
        super().__init__()
        self.token_embeddings = Embedding(num_embeddings=vocab_size, embedding_dim=d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    rope_max_seq_len=context_length,
                    rope_theta=rope_theta,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        )
        # layer norm final
        self.ln_final = RMSNorm(d_model=d_model, device=device, dtype=dtype)
        # output embedding (a linear layer)
        self.lm_head = Linear(d_model, vocab_size, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor):
        token_embeddings = self.token_embeddings(x)  # (batch_size, seq_len, d_model)

        hidden_states = token_embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states)  # (batch_size, seq_len, d_model)

        post_norm = self.ln_final(hidden_states)  # (batch_size, seq_len, d_model)
        logits = self.lm_head(post_norm)  # (batch_size, seq_len, vocab_size)
        return logits
posted @ 2025-08-03 18:00  Eslzzyl  阅读(127)  评论(0)    收藏  举报