使用JAX从零构建Transformer模型全流程解析

使用JAX从零构建Transformer模型全流程解析

在本教程中,我们将探讨如何使用JAX开发神经网络。而Transformer模型无疑是一个绝佳的选择。随着JAX日益流行,越来越多的开发团队开始尝试并将其纳入项目。尽管它尚未达到Tensorflow或PyTorch的成熟度,但它为构建和训练深度学习模型提供了一些强大的特性。

为了扎实理解JAX的基础知识,建议先阅读我之前的相关文章。完整代码可在我们的GitHub仓库中找到。

许多人在开始使用JAX时面临的常见问题是框架的选择。某机构的团队似乎非常忙碌,已经在JAX之上发布了大量框架。以下是最著名的一些框架列表:

  • Haiku:Haiku是用于深度学习的首选框架,被许多某中心和某机构的内部团队使用。它为机器学习研究提供了简单、可组合的抽象,以及现成的模块和层。
  • Optax:Optax是一个梯度处理和优化库,包含开箱即用的优化器和相关数学运算。
  • RLax:RLax是一个强化学习框架,包含许多RL子组件和操作。
  • Chex:Chex是一个用于测试和调试JAX代码的实用程序库。
  • Jraph:Jraph是JAX中的图神经网络库。
  • Flax:Flax是另一个神经网络库,提供各种现成的模块、优化器和实用程序。它很可能最接近我们理想中的一体化JAX框架。
  • Objax:Objax是第三个ML库,专注于面向对象编程和代码可读性。同样,它包含了最流行的模块、激活函数、损失函数、优化器以及一些预训练模型。
  • Trax:Trax是一个端到端的深度学习库,专注于Transformer模型。
  • JAXline:JAXline是一个监督学习库,用于分布式JAX训练和评估。
  • ACME:ACME是另一个强化学习研究框架。
  • JAX-MD:JAX-MD是一个处理分子动力学的专业框架。
  • Jaxchem:JAXChem是另一个强调化学建模的专业库。

当然,问题是我该选择哪一个?
老实说,我也不确定。
但如果我是你,并且想学习JAX,我会从最流行的开始。Haiku和Flax似乎在某中心和某机构内部被大量使用,并且拥有最活跃的GitHub社区。在本文中,我将从第一个开始,看看后续是否需要其他框架。

那么,你准备好用JAX和Haiku构建一个Transformer了吗?顺便说一下,我假设你对Transformer有扎实的理解。如果没有,请参考我们关于注意力和Transformer的文章。

让我们从自注意力块开始。

自注意力块

首先,我们需要导入JAX和Haiku:

import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np

幸运的是,Haiku有一个内置的MultiHeadAttention块,可以扩展以构建掩码自注意力块。我们的块接收查询、键、值以及掩码,并返回一个JAX数组作为输出。你可以看到代码与标准的PyTorch或Tensorflow代码非常相似。我们所做的就是使用np.trill()(该函数将数组中对角线第k个元素以上的所有元素置零)构建因果掩码,与我们的掩码相乘,然后将所有内容传递给hk.MultiHeadAttention模块。

class SelfAttention(hk.MultiHeadAttention):
    """应用了因果掩码的自注意力。"""
    def __call__(
            self,
            query: jnp.ndarray,
            key: Optional[jnp.ndarray] = None,
            value: Optional[jnp.ndarray] = None,
            mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        key = key if key is not None else query
        value = value if value is not None else query

        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)

这段代码片段允许我介绍Haiku的第一个关键原则。所有模块都应该是hk.Module的子类。这意味着它们应该实现__init____call__方法,以及其他任何方法。从某种意义上说,它与PyTorch模块的架构相同,我们在那里实现__init__和一个forward方法。

为了更清楚地说明这一点,让我们构建一个简单的2层多层感知机作为hk.Module,它将在下面的Transformer中方便地使用。

线性层

一个简单的2层MLP看起来像这样。再一次,你可以注意到它看起来多么熟悉。

class DenseBlock(hk.Module):
    """一个2层的MLP"""
    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)

这里需要注意几点:

  • Haiku在hk.initializers下为我们提供了一组权重初始化器,我们可以在这里找到最常见的方法。
  • 它还有内置的许多流行层和模块,例如hk.Linear。完整列表,请查看官方文档。
  • 不提供激活函数,因为JAX已经有一个名为jax.nn的子包,我们可以在那里找到relusoftmax等激活函数。

归一化层

层归一化是Transformer架构的另一个组成部分,我们也可以在Haiku的公共模块中找到。

def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """使用默认设置对x应用唯一的LayerNorm。"""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)

Transformer

现在是重点。下面你可以看到一个非常简化的Transformer,它使用了我们预定义的模块。在__init__中,我们定义了基本变量,如层数、注意力头数和dropout率。在__call__中,我们使用for循环组合了一系列块。

如你所见,每个块包括:

  • 一个归一化层
  • 一个自注意力块
  • 两个dropout层
  • 两个归一化层
  • 两个跳跃连接 (h = h + h_attnh = h + h_dense)
  • 一个2层的密集块

最后,我们还添加了最终的归一化层。

class Transformer(hk.Module):
    """一个Transformer堆栈。"""
    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """连接transformer。
        Args:
          h: 输入, [B, T, H].
          mask: 填充掩码, [B, T].
          is_training: 是否处于训练模式。
        Returns:
          形状为[B, T, H]的数组。
        """

        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h

我想现在你已经意识到,用JAX构建神经网络非常简单。

嵌入层

为了完整起见,我们也加入嵌入层。需要知道的是,Haiku也提供了一个嵌入层,它将从我们的输入句子中创建标记。然后将标记添加到位置嵌入中,产生最终的输入。

def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int) :
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # 嵌入输入标记和位置。
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask

hk.get_parameter(param_name, ...)用于访问模块的可训练参数。但你可能会问,为什么不直接使用像在PyTorch中那样的对象属性。这就是Haiku的第二个关键原则发挥作用的地方。我们使用这个API,以便可以使用hk.transform将代码转换为纯函数。这理解起来并不简单,但我会尽量让它清晰明了。

为什么需要纯函数?

JAX的强大之处在于其函数变换:用vmap向量化函数的能力,用pmap自动并行化,用jit即时编译。这里需要注意的是,为了变换一个函数,它必须是纯的。

纯函数是具有以下属性的函数:

  • 对于相同的参数,函数返回值是相同的(不随局部静态变量、非局部变量、可变引用参数或输入流的变化而变化)。
  • 函数应用没有副作用(不改变局部静态变量、非局部变量、可变引用参数或输入/输出流)。

来源:O'Reily的Scala纯函数

这实际上意味着一个纯函数总是会:

  • 如果使用相同的输入调用,则返回相同的结果
  • 所有输入数据都通过函数参数传递,所有结果都通过函数结果输出

Haiku提供了一个名为hk.transform的函数转换,它将具有面向对象、功能上“不纯”的模块的函数转换为可以与JAX一起使用的纯函数。为了在实践中看到这一点,让我们继续训练我们的Transformer模型。

前向传播

一个典型的前向传播包括:

  • 获取输入并计算输入嵌入
  • 通过Transformer的块运行
  • 返回输出

上述步骤可以很容易地用JAX组合如下:

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """创建模型的前向传播。"""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """前向传播。"""
        input_embeddings, input_mask = embeddings(data, vocab_size)

        # 在输入上运行transformer。
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # 反向嵌入(未绑定)。
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn

虽然代码很简单,但其结构可能看起来有点奇怪。实际的前向传播是通过forward_fn函数执行的。然而,我们用build_forward_fn函数包装了这个函数,并返回forward_fn。这是怎么回事?

接下来,我们将需要使用hk.transformforward_fn函数转换为纯函数,以便我们可以利用自动微分、并行化等。

这将通过以下方式完成:

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)

这就是为什么我们不是简单地定义一个函数,而是包装并返回函数本身,或者更准确地说,是一个可调用对象。然后可以将这个可调用对象传递给hk.transform并成为一个纯函数。如果清楚了这一点,让我们继续看损失函数。

损失函数

损失函数是我们熟知的交叉熵函数,不同之处在于我们也考虑了掩码。同样,JAX提供了one_hotlog_softmax功能。

def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """计算数据相对于参数的损失。"""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss

如果你还在坚持,喝一口咖啡,因为从现在开始事情会变得严肃起来。是时候构建我们的训练循环了。

训练循环

因为Jax和Haiku都没有内置优化功能,所以我们将使用另一个名为Optax的框架。如开头所述,Optax是用于梯度处理的包。

首先,关于Optax需要了解的一些事项:

  • Optax的关键变换是GradientTransformation。该变换由两个函数定义,__init____update____init__初始化状态,__update__根据状态和参数的当前值转换梯度:
    state = init(params)
    grads, state = update(grads, state, params=None)
    

在看代码之前,还需要了解Python内置的functools.partial函数。functools包处理高阶函数和可调用对象的操作。

如果一个函数包含其他函数作为参数或返回一个函数作为输出,则称为高阶函数。

partial(也可以用作注解)返回一个基于原始函数的新函数,但具有更少或固定的参数。例如,如果f将两个值x,y相乘,则partial将创建一个新函数,其中x将被固定为2:

from functools import partial
def f(x,y):
        return x * y

# 创建一个乘以2的新函数(x将被固定为2)
g = partial(f,2)

print(g(4))
#返回 8

在这个简短插曲之后,让我们继续。为了简化我们的主函数,我们将把梯度更新提取到它自己的类中。

首先,GradientUpdater接受模型、损失函数和优化器。

  • 模型将是通过hk.transform转换的纯forward_fn函数
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    
  • 损失函数将是具有固定forward_fnvocab_sizepartial的结果
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
    
  • 优化器是一系列将按顺序运行的优化变换(可以使用optax.chain组合操作)
    optimizer = optax.chain(
            optax.clip_by_global_norm(grad_clip_value),
            optax.adam(learning_rate, b1=0.9, b2=0.99))
    

梯度更新器将初始化如下:

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

并将如下所示:

class GradientUpdater:
    """围绕 init_fn/update_fn 对的无状态抽象。
    这从训练循环中提取了一些常见的样板代码。
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """初始化更新器的状态。"""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """使用一些数据更新状态并返回指标。"""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics

__init__中,我们使用self._opt.init(params)初始化优化器,并声明优化状态。状态将是一个包含以下内容的字典:

  • 当前步骤
  • 优化器状态
  • 可训练参数
  • (一个随机生成器密钥,用于传递给jax.random.split)

update函数将更新优化器的状态以及可训练参数。最后,它将返回新状态。

updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)

这里还有两件事需要注意:

  • jax.value_and_grad()是一个特殊的函数,它返回一个带有梯度的可微函数。
  • __init____update__都用@functools.partial(jax.jit, static_argnums=0)注解,这将触发即时编译器并在运行时将其编译为XLA。请注意,如果我们没有将forward_fn转换为纯函数,这是不可能的。

最后,我们准备构建整个训练循环,它结合了迄今为止提到的所有想法和代码。

def main():
    # 创建数据集。
    train_dataset, vocab_size = load(batch_size,
                                     sequence_length)
    # 设置模型、损失和更新器。
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # 初始化参数。
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

注意我们是如何整合GradientUpdate的。只需要两行代码:

state = updater.init(rng, data)
state, metrics = updater.update(state, data)

就是这样。我希望现在你对JAX及其功能有了更清晰的理解。

致谢

本文呈现的代码深受Haiku框架官方示例的启发。为适应本文需求进行了修改。完整示例列表,请查看官方仓库。

结论

在本文中,我们看到了如何使用JAX和Haiku开发和训练一个普通的Transformer。虽然代码不一定难以理解,但它仍然缺乏PyTorch或Tensorflow的可读性。强烈建议你尝试使用它,发现JAX的优势和劣势,看看它是否适合你的下一个项目。根据我的经验,JAX对于需要高性能的研究应用非常强大,但对于现实项目来说还相当不成熟。在我们的discord频道中告诉我们你的想法。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

公众号二维码

公众号二维码

posted @ 2025-12-12 12:10  CodeShare  阅读(0)  评论(0)    收藏  举报