大模型剪枝流程总结

在一个预训练好的大模型中,通常会有一部分权重,他在大部分的问答中都是处于低激活,甚至几乎不激活的状态,这显然会浪费一部分显存和算力,在模型每次加载和传递的过程中。

为此,我们采用了大模型剪枝的方法,核心组件有两个,钩子(Hooks)和一套我们自己定义的目标剪枝(Targeted Pruning)策略

大型神经网络常常像一个“黑箱”,我们知道输入和输出,但中间发生了什么却难以捉摸。当我们发现模型表现不稳定时,首要任务就是搞清楚内部到底哪里出了问题。这就是钩子(Hooks)起作用的时候

  • 什么是钩子(hooks)

简单来说,钩子是PyTorch提供的一个强大机制,它允许我们在不改变任何模型源代码的情况下,在模型的前向或反向传播过程中挂载一个“监听器”。这个监听器(一个自定义函数)可以在数据流经特定层时被触发,让我们能够实时查看或修改该层的输入、输出和梯度。

在我们的代码中,诊断过程的核心是 detect_super_weights 函数。它通过以下步骤实现了对模型内部激活情况的精确探测

1.首先要确定把钩子(hooks)这个监听器安放在哪里

我们的目标是ViT模型中每个Transformer Block内部的第二个全连接层 (fc2),因为它一般被认为是特征转换的关键环节

# 遍历模型的所有Block (默认为40个)
for i, block in enumerate(model.blocks[:limit]):
    # 精准定位到 fc2 线性层
    if hasattr(block, 'mlp') and hasattr(block.mlp, 'fc2'):
        # 在这里,我们将挂载我们的钩子函数
        handle = block.mlp.fc2.register_forward_hook(...)

2.然后,我们开始定义我们的钩子(hooks)要研究模型的什么

activations = {} # 用于存储结果的字典

def hook_fn(module, input_tensor, output_tensor, layer_idx):
    # 1. 获取该层的输出张量 (output_tensor)
    # 2. 计算输出张量的绝对值的均值,这代表了这一层输出的整体激活强度
    mean_output_act = torch.abs(output_tensor.mean(dim=0))
    #在强度图中找到峰值,即最大的激活值及其对应的通道索引
    output_max_val, output_max_ch = torch.max(mean_output_act.mean(dim=(0, 1)), dim=0)
    #将层索引、最大激活值等信息存入全局的 activations 字典
    activations[layer_idx] = {
        'output_max_ch': output_max_ch.item(),
        'output_max_val': output_max_val.item(),
    }

 

3.怎么安装钩子(hooks)进我们的全连接层

我们通过 register_forward_hookhook_fn 安装到每个目标层上。然后,我们只需像往常一样向模型输入一批数据,这个过程就会自动触发所有安装好的钩子

#安装钩子
handle = block.mlp.fc2.register_forward_hook(
    lambda m, inp, out, idx=i: hook_fn(m, inp, out, idx)
)
#向模型输入数据,触发所有已安装的钩子
model(batch.to(device))

 

通过这行代码,我们可以告诉PyTorch,当数据前向传播完成fc2层的计算后,立刻调用我指定的hook_fn函数,并把该层的模块本身,输入和输出作为参数传给它

4.触发探针,即前向传播

到目前为止,这个hooks已经安装好了,只需要待命,等数据流进他,即可产生作用

model.eval()
with torch.no_grad():
    for batch in tqdm(sample_loader, desc="Analyzing activations"):
        # 当这句代码执行以后,数据开始在模型中流动,并经过钩子
        model(batch.to(device))

model(batch.to(device)) 执行时,输入的一批图像数据开始从模型的第一层走到最后一层。每当数据流经一个我们安插了“探针”的 fc2 层,该层对应的 hook_fn 就会被自动调用一次

5.后续在activations 字典里已经存满了40个层在这次前向传播中的最大激活值,对这个里面的结果进行可视化分析即可,找出中位数和异常值。

6.最后重点

for h in handles:
    h.remove() #移除所有已注册的钩子

钩子在完成任务后一定要移除,以避免不必要的计算开销和潜在的内存泄漏。这就是之前保存 handle 的原因。

第二部分:对模型进行剪枝

在之前的研究里面我们发现,有些激活层过度活跃了,过度活跃的层通常学会的是一种学习捷径(Shortcut Learning)。它们在训练数据中发现了一些非常表面化但又很管用的简单规律,并死死地抓住不放。因此,我们不能只使用这些活跃的层,因为它们很可能学到了错误或片面的东西。

为此在我们的案例中,问题是过度活跃。简单地移除小权重起不到作用,我们需要直接削弱那些过度活跃的层。

为此,我们重写了 prune_and_quantize 函数,赋予它两大新能力

1.可以指定是剪掉数值最小的权重,还是剪掉数值最大的权重

2.可以通过一个配置字典,为不同的层指定不同的剪枝率和剪枝方法

def prune_and_quantize(model: nn.Module, 
                       super_weights: List[Tuple[int, int, int]], 
                       pruning_config: Dict[int, Dict]) -> nn.Module:
    """
    功能标注: 对模型进行目标明确的剪枝。
              - pruning_config: 一个配置字典,例如 {0: {'ratio': 0.8, 'method': 'prune_largest'}}
                允许为不同层指定不同的剪枝率和方法(剪最大/最小)。
    """
    if not super_weights:
        logging.warning("No super weights detected, skipping pruning.")
        return model
    
    super_weight_layers = {sw[0] for sw in super_weights} # 获取所有超级权重层的索引

    for layer_idx in super_weight_layers:
        if layer_idx not in pruning_config:
            logging.info(f"Layer {layer_idx} is a super-weight layer but not in pruning_config. Skipping.")
            continue

        config = pruning_config[layer_idx]
        prune_ratio = config.get('ratio', 0.2)
        method = config.get('method', 'prune_smallest')
        
        logging.info(f"Pruning Layer {layer_idx} with ratio={prune_ratio} and method='{method}'")

        layer_mlp = model.blocks[layer_idx].mlp.fc2
        weight = layer_mlp.weight.data.clone()
        
        # (保护超级权重的逻辑)
        
        abs_weights = torch.abs(weight)
        #(根据 method 决定保护值为无穷大还是负无穷大)

        flat_abs_weights = abs_weights.flatten()
        
        if method == 'prune_smallest':
            threshold = torch.quantile(flat_abs_weights, prune_ratio)
            prune_mask = (abs_weights < threshold) & (~preserve_mask)
        elif method == 'prune_largest':
            # 关键逻辑:剪掉最大的权重,就是保留(1-prune_ratio)的小权重
            threshold = torch.quantile(flat_abs_weights, 1.0 - prune_ratio)
            prune_mask = (abs_weights > threshold) & (~preserve_mask)
        else:
            continue
            
        weight[prune_mask] = 0.0
        layer_mlp.weight.data = weight
        
    return model

至此,我们完成了对模型的激活检测和剪枝全流程

posted @ 2025-09-26 15:53  liujunxi  阅读(83)  评论(0)    收藏  举报