代码层面上学习Gemma模型

总览

本文留下调试 Gemma 模型的记录。很乱,但我想不出更好的组织方式了。

gemma-2b 模型被封装在 GemmaForCausalLM 类中,这个类继承于 GemmaPreTrainedModel

而模型的本体是 GemmaModel 类(这个对象实例包含在 GemmaForCausalLM 实例中)。 也继承于 GemmaPreTrainedModel

GemmaPreTrainedModel 继承于 transformers.PreTrainedModel,重写了 _init_weights() 负责 nn.Linearnn.Embedding 的权重初始化。额外增加了 _setup_cache()_reset_cache() 两个方法,处理缓存 past_key_value

GemmaModel 中有个 self.layers 对象,由 18 层的 GemmaDecoderLayer 构成。

总之,

  • 继承关系:GemmaForCausalLM -> GemmaModel -> GemmaPreTrainedModel -> PreTrainedModel
  • 自注意力的 18 层所在位置:GemmaModel 实例内

GemmaForCausalLM 步骤

PreTrainedModel 类继承了 GenerationMixin,使得 GemmaForCausalLM 拥有 generate() 方法。用 model.generate() 进行文本生成,这就进入了 GemmaForCausalLM 步骤。

在进入该步骤之前,tokenizer 使用左填充的方式将多序列填充到相同长度。

首先使用 self.modelGemmaModel)让 input_ids 经过一系列的自注意力机制,输出 hidden_states

使用 self.lm_headnn.Linear)映射到 256000 维度,得到 logits。转换到 float32。

GemmaModel 步骤

这是 Gemma 的核心。

经过 Embedding,转换为嵌入后,乘上 hidden_size**0.5 进行标准化。

接下来是 18 层 GemmaDecoderLayer

  • 存储一个 residual,开始自注意力机制
    • self.input_layernormGemmaRMSNorm),首先 \(x·\frac{1}{\sqrt{x^2+\epsilon}}\) 对每个 embedding 归一化,然后使用可训练权重做乘法。全程在 float32 下运算
    • self.self_attnGemmaSdpaAttention),自注意力。是对 F.scaled_dot_product_attention() 的封装。使用 MultiQueryAttention,完成注意力后进行线性变换
    • 使用 residual 进行残差连接
  • 存储一个 residual,开始全连接层
    • self.post_attention_layernormGemmaRMSNorm),再次归一化
    • self.mlpGemmaMLP),多层感知器。从 2048 维到 16348 维映射出 \(x_1\)\(x_2\),进行 \(\text{gelu}(x_1)·x_2\),再映射回 2048 维
    • 使用 residual 进行残差连接
  • 重复 18 次

最后再经过 self.normGemmaRMSNorm)。

各步骤中值得注意的地方

GemmaDecoderLayer 之前

attention_mask,避免对填充标记执行注意力操作。多用于 batch_size 大于 1、对输入序列进行 padding 的情况,避免模型对 padding 施加 attention(因为无意义)。

causal_mask 是用 _update_causal_mask()attention_mask 转换而得。具体来说是从 [0,1] 转换为 [0, -inf](通过 torch.finfo(dtype).min 获得负无穷)。

causal_maskattention_mask 在代码中的角色很混乱。函数嵌套过程中两者会相互转换。

_update_causal_mask() 中,还会使用 AttentionMaskConverter._unmask_unattended() 取消填充部分的 mask,以适应 SDPA 的节省内存注意力方法。

GemmaSdpaAttention 之中

使用到了 past_key_valuetransformers.DynamicCache 类型) 存储每一层的 kv 中间结果,总共存储 18 对 key_statesvalue_states。这是 KV Cache 机制,能够显著提升速度。

经过调试可得知,从生成第二个新词开始,输入到模型的 hidden_states 就只有一个 token 的长度了。多亏了 KV Cache,不需要额外计算前面已有词的 Key 和 Value。

本节参考:

本文参考

谷歌的 Gemma 开源模型和代码,以及 HuggingFace 的 Transformers。

posted @ 2024-04-23 22:06  倒地  阅读(33)  评论(0编辑  收藏  举报