使用JAX从零构建Transformer模型全流程解析
使用JAX从零构建Transformer模型全流程解析
在本教程中,我们将探讨如何使用JAX开发神经网络。而Transformer模型无疑是一个绝佳的选择。随着JAX日益流行,越来越多的开发团队开始尝试并将其纳入项目。尽管它尚未达到Tensorflow或PyTorch的成熟度,但它为构建和训练深度学习模型提供了一些强大的特性。
为了扎实理解JAX的基础知识,建议先阅读我之前的相关文章。完整代码可在我们的GitHub仓库中找到。
许多人在开始使用JAX时面临的常见问题是框架的选择。某机构的团队似乎非常忙碌,已经在JAX之上发布了大量框架。以下是最著名的一些框架列表:
- Haiku:Haiku是用于深度学习的首选框架,被许多某中心和某机构的内部团队使用。它为机器学习研究提供了简单、可组合的抽象,以及现成的模块和层。
- Optax:Optax是一个梯度处理和优化库,包含开箱即用的优化器和相关数学运算。
- RLax:RLax是一个强化学习框架,包含许多RL子组件和操作。
- Chex:Chex是一个用于测试和调试JAX代码的实用程序库。
- Jraph:Jraph是JAX中的图神经网络库。
- Flax:Flax是另一个神经网络库,提供各种现成的模块、优化器和实用程序。它很可能最接近我们理想中的一体化JAX框架。
- Objax:Objax是第三个ML库,专注于面向对象编程和代码可读性。同样,它包含了最流行的模块、激活函数、损失函数、优化器以及一些预训练模型。
- Trax:Trax是一个端到端的深度学习库,专注于Transformer模型。
- JAXline:JAXline是一个监督学习库,用于分布式JAX训练和评估。
- ACME:ACME是另一个强化学习研究框架。
- JAX-MD:JAX-MD是一个处理分子动力学的专业框架。
- Jaxchem:JAXChem是另一个强调化学建模的专业库。
当然,问题是我该选择哪一个?
老实说,我也不确定。
但如果我是你,并且想学习JAX,我会从最流行的开始。Haiku和Flax似乎在某中心和某机构内部被大量使用,并且拥有最活跃的GitHub社区。在本文中,我将从第一个开始,看看后续是否需要其他框架。
那么,你准备好用JAX和Haiku构建一个Transformer了吗?顺便说一下,我假设你对Transformer有扎实的理解。如果没有,请参考我们关于注意力和Transformer的文章。
让我们从自注意力块开始。
自注意力块
首先,我们需要导入JAX和Haiku:
import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np
幸运的是,Haiku有一个内置的MultiHeadAttention块,可以扩展以构建掩码自注意力块。我们的块接收查询、键、值以及掩码,并返回一个JAX数组作为输出。你可以看到代码与标准的PyTorch或Tensorflow代码非常相似。我们所做的就是使用np.trill()(该函数将数组中对角线第k个元素以上的所有元素置零)构建因果掩码,与我们的掩码相乘,然后将所有内容传递给hk.MultiHeadAttention模块。
class SelfAttention(hk.MultiHeadAttention):
"""应用了因果掩码的自注意力。"""
def __call__(
self,
query: jnp.ndarray,
key: Optional[jnp.ndarray] = None,
value: Optional[jnp.ndarray] = None,
mask: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
key = key if key is not None else query
value = value if value is not None else query
seq_len = query.shape[1]
causal_mask = np.tril(np.ones((seq_len, seq_len)))
mask = mask * causal_mask if mask is not None else causal_mask
return super().__call__(query, key, value, mask)
这段代码片段允许我介绍Haiku的第一个关键原则。所有模块都应该是hk.Module的子类。这意味着它们应该实现__init__和__call__方法,以及其他任何方法。从某种意义上说,它与PyTorch模块的架构相同,我们在那里实现__init__和一个forward方法。
为了更清楚地说明这一点,让我们构建一个简单的2层多层感知机作为hk.Module,它将在下面的Transformer中方便地使用。
线性层
一个简单的2层MLP看起来像这样。再一次,你可以注意到它看起来多么熟悉。
class DenseBlock(hk.Module):
"""一个2层的MLP"""
def __init__(self,
init_scale: float,
widening_factor: int = 4,
name: Optional[str] = None):
super().__init__(name=name)
self._init_scale = init_scale
self._widening_factor = widening_factor
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hiddens = x.shape[-1]
initializer = hk.initializers.VarianceScaling(self._init_scale)
x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
x = jax.nn.gelu(x)
return hk.Linear(hiddens, w_init=initializer)(x)
这里需要注意几点:
- Haiku在
hk.initializers下为我们提供了一组权重初始化器,我们可以在这里找到最常见的方法。 - 它还有内置的许多流行层和模块,例如
hk.Linear。完整列表,请查看官方文档。 - 不提供激活函数,因为JAX已经有一个名为
jax.nn的子包,我们可以在那里找到relu或softmax等激活函数。
归一化层
层归一化是Transformer架构的另一个组成部分,我们也可以在Haiku的公共模块中找到。
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
"""使用默认设置对x应用唯一的LayerNorm。"""
return hk.LayerNorm(axis=-1,
create_scale=True,
create_offset=True,
name=name)(x)
Transformer
现在是重点。下面你可以看到一个非常简化的Transformer,它使用了我们预定义的模块。在__init__中,我们定义了基本变量,如层数、注意力头数和dropout率。在__call__中,我们使用for循环组合了一系列块。
如你所见,每个块包括:
- 一个归一化层
- 一个自注意力块
- 两个dropout层
- 两个归一化层
- 两个跳跃连接 (
h = h + h_attn和h = h + h_dense) - 一个2层的密集块
最后,我们还添加了最终的归一化层。
class Transformer(hk.Module):
"""一个Transformer堆栈。"""
def __init__(self,
num_heads: int,
num_layers: int,
dropout_rate: float,
name: Optional[str] = None):
super().__init__(name=name)
self._num_layers = num_layers
self._num_heads = num_heads
self._dropout_rate = dropout_rate
def __call__(self,
h: jnp.ndarray,
mask: Optional[jnp.ndarray],
is_training: bool) -> jnp.ndarray:
"""连接transformer。
Args:
h: 输入, [B, T, H].
mask: 填充掩码, [B, T].
is_training: 是否处于训练模式。
Returns:
形状为[B, T, H]的数组。
"""
init_scale = 2. / self._num_layers
dropout_rate = self._dropout_rate if is_training else 0.
if mask is not None:
mask = mask[:, None, None, :]
for i in range(self._num_layers):
h_norm = layer_norm(h, name=f'h{i}_ln_1')
h_attn = SelfAttention(
num_heads=self._num_heads,
key_size=64,
w_init_scale=init_scale,
name=f'h{i}_attn')(h_norm, mask=mask)
h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
h = h + h_attn
h_norm = layer_norm(h, name=f'h{i}_ln_2')
h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
h = h + h_dense
h = layer_norm(h, name='ln_f')
return h
我想现在你已经意识到,用JAX构建神经网络非常简单。
嵌入层
为了完整起见,我们也加入嵌入层。需要知道的是,Haiku也提供了一个嵌入层,它将从我们的输入句子中创建标记。然后将标记添加到位置嵌入中,产生最终的输入。
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int) :
tokens = data['obs']
input_mask = jnp.greater(tokens, 0)
seq_length = tokens.shape[1]
# 嵌入输入标记和位置。
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
token_embs = token_embedding_map(tokens)
positional_embeddings = hk.get_parameter(
'pos_embs', [seq_length, d_model], init=embed_init)
input_embeddings = token_embs + positional_embeddings
return input_embeddings, input_mask
hk.get_parameter(param_name, ...)用于访问模块的可训练参数。但你可能会问,为什么不直接使用像在PyTorch中那样的对象属性。这就是Haiku的第二个关键原则发挥作用的地方。我们使用这个API,以便可以使用hk.transform将代码转换为纯函数。这理解起来并不简单,但我会尽量让它清晰明了。
为什么需要纯函数?
JAX的强大之处在于其函数变换:用vmap向量化函数的能力,用pmap自动并行化,用jit即时编译。这里需要注意的是,为了变换一个函数,它必须是纯的。
纯函数是具有以下属性的函数:
- 对于相同的参数,函数返回值是相同的(不随局部静态变量、非局部变量、可变引用参数或输入流的变化而变化)。
- 函数应用没有副作用(不改变局部静态变量、非局部变量、可变引用参数或输入/输出流)。
来源:O'Reily的Scala纯函数
这实际上意味着一个纯函数总是会:
- 如果使用相同的输入调用,则返回相同的结果
- 所有输入数据都通过函数参数传递,所有结果都通过函数结果输出
Haiku提供了一个名为hk.transform的函数转换,它将具有面向对象、功能上“不纯”的模块的函数转换为可以与JAX一起使用的纯函数。为了在实践中看到这一点,让我们继续训练我们的Transformer模型。
前向传播
一个典型的前向传播包括:
- 获取输入并计算输入嵌入
- 通过Transformer的块运行
- 返回输出
上述步骤可以很容易地用JAX组合如下:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
num_layers: int, dropout_rate: float):
"""创建模型的前向传播。"""
def forward_fn(data: Mapping[str, jnp.ndarray],
is_training: bool = True) -> jnp.ndarray:
"""前向传播。"""
input_embeddings, input_mask = embeddings(data, vocab_size)
# 在输入上运行transformer。
transformer = Transformer(
num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
output_embeddings = transformer(input_embeddings, input_mask, is_training)
# 反向嵌入(未绑定)。
return hk.Linear(vocab_size)(output_embeddings)
return forward_fn
虽然代码很简单,但其结构可能看起来有点奇怪。实际的前向传播是通过forward_fn函数执行的。然而,我们用build_forward_fn函数包装了这个函数,并返回forward_fn。这是怎么回事?
接下来,我们将需要使用hk.transform将forward_fn函数转换为纯函数,以便我们可以利用自动微分、并行化等。
这将通过以下方式完成:
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
这就是为什么我们不是简单地定义一个函数,而是包装并返回函数本身,或者更准确地说,是一个可调用对象。然后可以将这个可调用对象传递给hk.transform并成为一个纯函数。如果清楚了这一点,让我们继续看损失函数。
损失函数
损失函数是我们熟知的交叉熵函数,不同之处在于我们也考虑了掩码。同样,JAX提供了one_hot和log_softmax功能。
def lm_loss_fn(forward_fn,
vocab_size: int,
params,
rng,
data: Mapping[str, jnp.ndarray],
is_training: bool = True) -> jnp.ndarray:
"""计算数据相对于参数的损失。"""
logits = forward_fn(params, rng, data, is_training)
targets = jax.nn.one_hot(data['target'], vocab_size)
assert logits.shape == targets.shape
mask = jnp.greater(data['obs'], 0)
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.sum(loss * mask) / jnp.sum(mask)
return loss
如果你还在坚持,喝一口咖啡,因为从现在开始事情会变得严肃起来。是时候构建我们的训练循环了。
训练循环
因为Jax和Haiku都没有内置优化功能,所以我们将使用另一个名为Optax的框架。如开头所述,Optax是用于梯度处理的包。
首先,关于Optax需要了解的一些事项:
- Optax的关键变换是
GradientTransformation。该变换由两个函数定义,__init__和__update__。__init__初始化状态,__update__根据状态和参数的当前值转换梯度:state = init(params) grads, state = update(grads, state, params=None)
在看代码之前,还需要了解Python内置的functools.partial函数。functools包处理高阶函数和可调用对象的操作。
如果一个函数包含其他函数作为参数或返回一个函数作为输出,则称为高阶函数。
partial(也可以用作注解)返回一个基于原始函数的新函数,但具有更少或固定的参数。例如,如果f将两个值x,y相乘,则partial将创建一个新函数,其中x将被固定为2:
from functools import partial
def f(x,y):
return x * y
# 创建一个乘以2的新函数(x将被固定为2)
g = partial(f,2)
print(g(4))
#返回 8
在这个简短插曲之后,让我们继续。为了简化我们的主函数,我们将把梯度更新提取到它自己的类中。
首先,GradientUpdater接受模型、损失函数和优化器。
- 模型将是通过
hk.transform转换的纯forward_fn函数forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers, dropout_rate) forward_fn = hk.transform(forward_fn) - 损失函数将是具有固定
forward_fn和vocab_size的partial的结果loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size) - 优化器是一系列将按顺序运行的优化变换(可以使用
optax.chain组合操作)optimizer = optax.chain( optax.clip_by_global_norm(grad_clip_value), optax.adam(learning_rate, b1=0.9, b2=0.99))
梯度更新器将初始化如下:
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
并将如下所示:
class GradientUpdater:
"""围绕 init_fn/update_fn 对的无状态抽象。
这从训练循环中提取了一些常见的样板代码。
"""
def __init__(self, net_init, loss_fn,
optimizer: optax.GradientTransformation):
self._net_init = net_init
self._loss_fn = loss_fn
self._opt = optimizer
@functools.partial(jax.jit, static_argnums=0)
def init(self, master_rng, data):
"""初始化更新器的状态。"""
out_rng, init_rng = jax.random.split(master_rng)
params = self._net_init(init_rng, data)
opt_state = self._opt.init(params)
out = dict(
step=np.array(0),
rng=out_rng,
opt_state=opt_state,
params=params,
)
return out
@functools.partial(jax.jit, static_argnums=0)
def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
"""使用一些数据更新状态并返回指标。"""
rng, new_rng = jax.random.split(state['rng'])
params = state['params']
loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
new_state = {
'step': state['step'] + 1,
'rng': new_rng,
'opt_state': opt_state,
'params': params,
}
metrics = {
'step': state['step'],
'loss': loss,
}
return new_state, metrics
在__init__中,我们使用self._opt.init(params)初始化优化器,并声明优化状态。状态将是一个包含以下内容的字典:
- 当前步骤
- 优化器状态
- 可训练参数
- (一个随机生成器密钥,用于传递给
jax.random.split)
update函数将更新优化器的状态以及可训练参数。最后,它将返回新状态。
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
这里还有两件事需要注意:
jax.value_and_grad()是一个特殊的函数,它返回一个带有梯度的可微函数。__init__和__update__都用@functools.partial(jax.jit, static_argnums=0)注解,这将触发即时编译器并在运行时将其编译为XLA。请注意,如果我们没有将forward_fn转换为纯函数,这是不可能的。
最后,我们准备构建整个训练循环,它结合了迄今为止提到的所有想法和代码。
def main():
# 创建数据集。
train_dataset, vocab_size = load(batch_size,
sequence_length)
# 设置模型、损失和更新器。
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
# 初始化参数。
logging.info('Initializing parameters...')
rng = jax.random.PRNGKey(428)
data = next(train_dataset)
state = updater.init(rng, data)
logging.info('Starting train loop...')
prev_time = time.time()
for step in range(MAX_STEPS):
data = next(train_dataset)
state, metrics = updater.update(state, data)
注意我们是如何整合GradientUpdate的。只需要两行代码:
state = updater.init(rng, data)
state, metrics = updater.update(state, data)
就是这样。我希望现在你对JAX及其功能有了更清晰的理解。
致谢
本文呈现的代码深受Haiku框架官方示例的启发。为适应本文需求进行了修改。完整示例列表,请查看官方仓库。
结论
在本文中,我们看到了如何使用JAX和Haiku开发和训练一个普通的Transformer。虽然代码不一定难以理解,但它仍然缺乏PyTorch或Tensorflow的可读性。强烈建议你尝试使用它,发现JAX的优势和劣势,看看它是否适合你的下一个项目。根据我的经验,JAX对于需要高性能的研究应用非常强大,但对于现实项目来说还相当不成熟。在我们的discord频道中告诉我们你的想法。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)
公众号二维码

公众号二维码


浙公网安备 33010602011771号