代码层面上学习Gemma模型
总览
本文留下调试 Gemma 模型的记录。很乱,但我想不出更好的组织方式了。
gemma-2b 模型被封装在 GemmaForCausalLM 类中,这个类继承于 GemmaPreTrainedModel。
而模型的本体是 GemmaModel 类(这个对象实例包含在 GemmaForCausalLM 实例中)。 也继承于 GemmaPreTrainedModel。
GemmaPreTrainedModel 继承于 transformers.PreTrainedModel,重写了 _init_weights() 负责 nn.Linear 和 nn.Embedding 的权重初始化。额外增加了 _setup_cache() 和 _reset_cache() 两个方法,处理缓存 past_key_value。
GemmaModel 中有个 self.layers 对象,由 18 层的 GemmaDecoderLayer 构成。
总之,
- 继承关系:
GemmaForCausalLM->GemmaModel->GemmaPreTrainedModel->PreTrainedModel - 自注意力的 18 层所在位置:
GemmaModel实例内
GemmaForCausalLM 步骤
PreTrainedModel 类继承了 GenerationMixin,使得 GemmaForCausalLM 拥有 generate() 方法。用 model.generate() 进行文本生成,这就进入了 GemmaForCausalLM 步骤。
在进入该步骤之前,tokenizer 使用左填充的方式将多序列填充到相同长度。
首先使用 self.model(GemmaModel)让 input_ids 经过一系列的自注意力机制,输出 hidden_states。
使用 self.lm_head(nn.Linear)映射到 256000 维度,得到 logits。转换到 float32。
GemmaModel 步骤
这是 Gemma 的核心。
经过 Embedding,转换为嵌入后,乘上 hidden_size**0.5 进行标准化。
接下来是 18 层 GemmaDecoderLayer。
- 存储一个
residual,开始自注意力机制self.input_layernorm(GemmaRMSNorm),首先 \(x·\frac{1}{\sqrt{x^2+\epsilon}}\) 对每个 embedding 归一化,然后使用可训练权重做乘法。全程在 float32 下运算self.self_attn(GemmaSdpaAttention),自注意力。是对F.scaled_dot_product_attention()的封装。使用 MultiQueryAttention,完成注意力后进行线性变换- 使用
residual进行残差连接
- 存储一个
residual,开始全连接层self.post_attention_layernorm(GemmaRMSNorm),再次归一化self.mlp(GemmaMLP),多层感知器。从 2048 维到 16348 维映射出 \(x_1\) 和 \(x_2\),进行 \(\text{gelu}(x_1)·x_2\),再映射回 2048 维- 使用
residual进行残差连接
- 重复 18 次
最后再经过 self.norm(GemmaRMSNorm)。
各步骤中值得注意的地方
GemmaDecoderLayer 之前
attention_mask,避免对填充标记执行注意力操作。多用于 batch_size 大于 1、对输入序列进行 padding 的情况,避免模型对 padding 施加 attention(因为无意义)。
causal_mask 是用 _update_causal_mask() 从 attention_mask 转换而得。具体来说是从 [0,1] 转换为 [0, -inf](通过 torch.finfo(dtype).min 获得负无穷)。
causal_mask与attention_mask在代码中的角色很混乱。函数嵌套过程中两者会相互转换。
在 _update_causal_mask() 中,还会使用 AttentionMaskConverter._unmask_unattended() 取消填充部分的 mask,以适应 SDPA 的节省内存注意力方法。
GemmaSdpaAttention 之中
使用到了 past_key_value(transformers.DynamicCache 类型) 存储每一层的 kv 中间结果,总共存储 18 对 key_states 和 value_states。这是 KV Cache 机制,能够显著提升速度。
经过调试可得知,从生成第二个新词开始,输入到模型的 hidden_states 就只有一个 token 的长度了。多亏了 KV Cache,不需要额外计算前面已有词的 Key 和 Value。
本节参考:
- Young,“大模型推理性能优化之KV Cache解读”,https://zhuanlan.zhihu.com/p/630832593
本文参考
谷歌的 Gemma 开源模型和代码,以及 HuggingFace 的 Transformers。

浙公网安备 33010602011771号