hook 工具随笔

hook 是 pytorch 中的一个工具,主要作用是 在模型前向/反向传播过程之前/之后执行一些自定义操作

比如说: 打印查看一些模型参数;修改梯度等等

基本形态

基本的形式是:

# 自定义操作
def forward_hook(module, input, output):
    print(f"Module: {module}")
    print(f"Input: {input}")
    print(f"Output: {output}")
  
# 注册 hook
hook = model.fc1.register_forward_hook(forward_hook)

# 传递
output = model(input_data)

# 移除 hook
hook.remove()

其中 forward_hook 是自定义的函数,其 输入参数 随 register_xx_hook 的不同而有所不同,常见的 register 类型如下表所示:

Hook 类型 注册函数 函数签名 调用时机
forward hook register_forward_hook(hook) hook(module, input, output) 模块执行完前向传播后调用
forward pre-hook register_forward_pre_hook(hook) hook(module, input) 模块执行前向传播前调用
full backward hook register_full_backward_hook(hook) hook(module, grad_input, grad_output) 模块反向传播时调用,兼容复杂的 autograd 结构
Tensor 级别 hook tensor.register_hook(hook) hook(grad) 注册在 Tensor 上,在该 Tensor 的梯度被计算时调用

实现机制

hook 实际上是基于 Autograd 引擎的回调机制 实现的。大致流程如下:

  1. 每个 nn.Module 在执行 forward() 时,都会在 Autograd 图中注册节点(Function)。
  2. 当你注册 hook 时,PyTorch 会把你的函数加入到这些节点的 回调列表。
  3. 前向传播:如果是 forward_pre_hook → 在执行 module.forward() 前调用;如果是 forward_hook → 在 forward() 结束后调用。
  4. 反向传播:当 Autograd 计算梯度经过该节点时,会触发该节点的 backward hook。

内部实现类似这样:

# 伪代码示例
for pre_hook in module._forward_pre_hooks:
    x = pre_hook(module, x)
    
out = module.forward(x)

for hook in module._forward_hooks:
    out = hook(module, x, out)

实例介绍

这段代码来自论文中的一个实验,其核心思想可以理解为:针对两个不同的 query,它们在语义上应当得到相同的 answer。研究者通过将第一个 query 在前向传播过程中生成的中间隐藏状态,替换到第二个 query 的对应层位置中,然后再让模型继续生成输出。若此时模型仍然能够产生目标 answer,就说明模型的推理结果更多地依赖于其内部的表征与推导能力 😎 ,而非仅仅依靠上下文记忆或表面模式匹配

这个过程正是通过 forward_hook 实现的——在前向传播结束后,动态地修改指定层的输出,从而实现对模型中间表示的干预与验证。

def cross_query_semantic_patching(model, tokenizer, device, queries, position, layer):
    # initialize counters
    success_counts = 0
    total_counts = 0

    for source_prompt, target_prompt, expected_e3 in tqdm(queries):

        # get the source hidden states
        decoder_temp = tokenizer([source_prompt], return_tensors="pt", padding=True)
        decoder_input_ids, decoder_attention_mask = decoder_temp["input_ids"], decoder_temp["attention_mask"]
        decoder_input_ids, decoder_attention_mask = decoder_input_ids.to(device), decoder_attention_mask.to(device)

        with torch.no_grad():
            outputs1 = model(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                output_hidden_states=True
            )

        hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]

        # replace the hidden states of the target position with the source hidden states
        def hook_fn(module, input, output):
            # output ([batch_size, seq_len, hidden_size], ...)
            main_output = output[0].clone()
            main_output[0, position, :] = hidden_states_batch[layer][0, position, :]
            return (main_output,) + output[1:]

        # 注册前向钩子
        handle = model.transformer.h[layer - 1].register_forward_hook(hook_fn) # [num_layers, batch_size, seq_len, hidden_size]

        # target prompt
        decoder_temp = tokenizer([target_prompt], return_tensors="pt", padding=True)
        decoder_input_ids, decoder_attention_mask = decoder_temp["input_ids"], decoder_temp["attention_mask"]
        target_decoder_input_ids, target_decoder_attention_mask = decoder_input_ids.to(device), decoder_attention_mask.to(device)

        with torch.no_grad():
            outputs2 = model(
                input_ids=target_decoder_input_ids,
                attention_mask=target_decoder_attention_mask,
                # output_hidden_states=True
            )

        # 移除钩子,避免影响后续推理
        handle.remove()

        # decode the predicted token
        logits = outputs2.logits  # [batch_size, seq_len, vocab_size]
        predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch_size, seq_len]
        decoded_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
        decoded_token = decoded_text[0].split()[-1]         

        # check if the decoded token is the expected token
        total_counts += 1
        if decoded_token == expected_e3:
            success_counts += 1
    return success_counts/total_counts

模型结构介绍

主要聚焦于这一部分:

hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]

# replace the hidden states of the target position with the source hidden states
def hook_fn(module, input, output):
    # output ([batch_size, seq_len, hidden_size], ...)
    main_output = output[0].clone()
    main_output[0, position, :] = hidden_states_batch[layer][0, position, :]
    return (main_output,) + output[1:]

# 注册前向钩子
handle = model.transformer.h[layer - 1].register_forward_hook(hook_fn) # [num_layers, batch_size, seq_len, hidden_size]

这里主要是回顾一下模型结构:

hidden_states_batch 是模型在前向传播过程中保存的各层隐藏状态,其维度为 [1 + num_layers, batch_size, seq_len, hidden_size],其中第 0 个元素对应 embedding 层的输出,而后续的每个元素 hidden_states_batch[layer] 分别对应第 layer 个 Transformer 层的隐藏状态;model.transformer.h 是模型中所有 Transformer 层的列表,因此 model.transformer.h[layer - 1] 就表示第 layer 层,索引从 0 开始。

在 hook_fn 中,input 表示传入该层的张量,output 表示该层前向传播后的输出结果。由于这里的目标是修改该层在前向传播后的输出隐藏状态,而不是输入,因此选择对 output 进行处理;output 通常是一个元组 (hidden_states, other_outputs),其中 output[0] 是该层的主要输出张量,即 [batch_size, seq_len, hidden_size] 的隐藏状态,而其余部分可能包含注意力缓存或其他附加信息,之后的替换就很好理解。

posted @ 2025-10-30 19:43  亦可九天揽月  阅读(7)  评论(0)    收藏  举报