昇腾 msmodelslim w8a8量化代码解析


最近有很多朋友都在部署deepseek模型,而且都用到了模型量化这个功能,目的是减少显存占用、提升推理速度。

上图是w8a8量化算法流程,主要包含4步:

①,使用昇腾 msmodelslim 仓库提供的量化接口对原始模型权重进行量化,生成int8格式的权重文件,以及后续在推理的时候要用到的激活值的量化参数和 matmul 结果的反量化参数;

②,推理执行过程中,把Matmul的激活值(也就是输入X)进行int8量化;

③,执行int8格式的Matmul计算;

④,把int8的乘法结果进行反量化。

这篇文章讲解第①步的内容。msmodelslim提供的deepseek模型量化的参考脚本的链接如下:

Ascend/msit​gitee.com/ascend/msit/tree/br_noncom_MindStudio_8.0.0_POC_20251231/msmodelslim/example/DeepSeek​编辑

入口脚本 quant_deepseek_w8a8.py 的代码内容如下(br_noncom_MindStudio_8.0.0_POC_20251231分支,commit 06a6e8):

#Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import argparse
import functools
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from msmodelslim.tools.convert_fp8_to_bf16 import auto_convert_model_fp8_to_bf16, OpsType
from msmodelslim.tools.copy_config_files import copy_config_files, modify_config_json
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim.tools.logger import set_logger_level

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, help="model and tokenizer path"),
    parser.add_argument('--save_path', type=str, help="save path"),
    parser.add_argument('--layer_count', type=int, default=0)
    parser.add_argument('--anti_dataset', type=str, default="./anti_prompt.json")
    parser.add_argument('--calib_dataset', type=str, default="./calib_prompt.json")
    parser.add_argument('--fp8', action='store_true')
    parser.add_argument('--bf16', action='store_true')
    return parser.parse_args()

def custom_hook(model_config):
    model_config["mla_quantize"] = "w8a8"

args = parse_args()
set_logger_level("warning")
pbar = tqdm(total=4, position=0, desc="Total Process")
model_path = args.model_path
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_path, trust_remote_code=True)
config.num_hidden_layers = args.layer_count if args.layer_count != 0 else config.num_hidden_layers

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_path,
                                          config=config,
                                          trust_remote_code=True,
                                          use_fast=True,
                                          add_eos_token=True)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_path,
                                             config=config,
                                             trust_remote_code=True,
                                             device_map="auto",
                                             torch_dtype="auto",
                                             max_memory={
                                                 0: "50GiB",
                                                 "cpu": "1500GiB"
                                             },
                                             attn_implementation='eager')

auto_convert_model_fp8_to_bf16(model, model_path, OpsType.get_ops_type(args.bf16, args.fp8))

pbar.update(1)

def get_anti_dataset(tokenizer, calib_list, device="npu"):
    calib_dataset = []
    max_len = 0
    for calib_data in calib_list:
        inputs = tokenizer(calib_data, return_tensors='pt')
        calib_dataset.append(inputs.data['input_ids'].to(device))
        max_len = max(max_len, inputs.data['input_ids'].size(1))
    for i in range(len(calib_dataset)):
        calib_dataset[i] = F.pad(calib_dataset[i], (0, max_len - calib_dataset[i].size(1)), value=0)
    return torch.cat(calib_dataset)

def get_calib_dataset(tokenizer, calib_list, device="npu"):
    calib_dataset = []
    for calib_data in calib_list:
        inputs = tokenizer(calib_data, return_tensors='pt').to(device)
        calib_dataset.append([inputs.data['input_ids']])
    return calib_dataset

with open(args.anti_dataset, "r") as file:
    anti_prompt = json.load(file)
with open(args.calib_dataset, "r") as file:
    calib_prompt = json.load(file)

anti_data = []
for i in range(len(anti_prompt)):
    tmp = get_anti_dataset(tokenizer, anti_prompt[i])
    anti_data.append(tmp)

anti_dataset = []
for data in anti_data:
    anti_dataset.append([data])

dataset_calib = []
for i in range(len(calib_prompt)):
    tmp = get_calib_dataset(tokenizer,calib_prompt[i])
    dataset_calib += (tmp)

with torch.no_grad():
    anti_config = AntiOutlierConfig(w_bit=8,
                                    a_bit=8,
                                    anti_method='m4',
                                    dev_type='npu',
                                    dev_id=model.device.index)
    anti_outlier = AntiOutlier(model, calib_data=anti_dataset, cfg=anti_config)
    anti_outlier.process()
pbar.update(1)

disable_names = []
for ids in range(config.num_hidden_layers):
    disable_names.append("model.layers." + str(ids) + ".self_attn.kv_b_proj")

quant_config = QuantConfig(
    a_bit=8,
    w_bit=8,
    disable_names=disable_names,
    dev_type='npu',
    dev_id=model.device.index,
    act_method=1,
    pr=1.0,
    w_sym=True,
    mm_tensor=False,
    is_dynamic=True
)

calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level="L0")
calibrator.run()
pbar.update(1)
calibrator.save(args.save_path, save_type=["safe_tensor"], part_file_size=4)

custom_hooks = {
    'config.json': functools.partial(modify_config_json, custom_hook=custom_hook)
}
copy_config_files(input_path=args.model_path, output_path=args.save_path, quant_config=quant_config, custom_hooks=custom_hooks)
pbar.update(1)

这篇文章会从入口脚本出发,对w8a8量化的技术原理和代码进行解析。

1. 算法原理

上面的代码涉及到2个类:AntiOutlier 和 Calibrator,AntiOutlier 代表的是激活异常值抑制,Calibrator 是激活值和权重量化。

1.1 激活异常值抑制

对于int8量化算法,浮点数量化后的取值是有限的(-128、-127、...、127),所以浮点数的分布范围越广的话,量化步长就越大,那么就有更多的浮点数会被量化成同一个数值,也就会引入更大的误差。而且大家发现,对于大模型里面的Matmul,激活值X和权重W的浮点数分布很不相同,X的分布范围更大,这样就会导致激活值X的量化误差较大。

为了解决激活值的量化误差问题,有人提出了 smoothQuant 算法。这个算法的原理很简单:让X除以一个值s,W乘以s,这样的话,X/s和W/s的分布就会更加平滑,同时(X/s)(Ws)=X*W,保证乘积不变。s的计算公式如下:

其中 Xj 代表 X 的第 j 列, Wj 代表 W 的第 j 行, α 一般取0.5。示例如下:

在代码实现中,因为X是norm层的输出,所以会把X除以s的操作转移到norm层,让norm层的权重除以s,这样就不用在推理的过程中,再做一个除法。

1.2 w8a8量化

这部分没什么好说的,quant_deepseek_w8a8.py 中用到的就是 min-max 量化算法。对于一个权重tensor或者激活X的tensor来说,假如它的最大、最小值分别为max、min,那么首先可以求出 scale=(max-min)/255,然后得出量化公式为

x是tensor的每个元素。当然,除了以tensor粒度求解scale,还有以tensor的每个通道值分布求scale的,我们称作per-channel。

需要注意的是,权重在推理之前就是已知的,所以不需要做模型推理、直接对权重文件的数据进行量化即可;但是激活值X是在推理的时候才能获取的,所以我们还需要准备一些“校准数据集”,让模型做一些前向推理,以此确定激活值的量化参数。

算法理论部分到这里就结束了,比较简单,接下来我们看看代码层面是如何实现的。

2. 代码解析

上图是deepseek w8a8量化入口脚本,主要包含3个部分:异常值抑制、w8a8量化、保存量化权重和相关参数。

2.1 异常值抑制

anti_config包含的参数如下:


    anti_config = AntiOutlierConfig(w_bit=8,
                                    a_bit=8,
                                    anti_method='m4',
                                    dev_type='npu',
                                    dev_id=model.device.index)

其中w_bit和a_bit代表权重量化位数和激活值量化位数,anti_method代表抑制算法,'m4'是 smooth_quant(m1) 的改进方法,相比于 smooth_quant 增加了量化层。dev_type和dev_id代表运行异常值抑制使用的设备。

anti_outlier的核心代码在 msmodelslim\msmodelslim\pytorch\llm_ptq\anti_outlier\anti_outlier.py。init()函数和process()函数的核心逻辑流程图如下:

步骤1 是在AntiOutlier类的init()函数中完成的,后续步骤是在process()函数中完成的。

初始化有向无环图是在init()函数的这个部分执行的:

        try:
            self.init_dag()
        except Exception as e:
            raise Exception("Please check your config, model and input!", e) from e

对于attention模型,构建DAG图的过程就是找出RMSNorm算子和它们连接的linear层。对于抑制算法"m1",我们做w8a8量化的目标层是qkv乘法和up、gate的全连接层;对于抑制算法"m4",还包含了O层和down层(O层的激活值scale转移到V层,down层的激活值scale转移到up层)。

init_dag()函数核心代码如下:

        if self.norm_class_name is not None:  # 可以手动指定norm层
            norm_class = list(OrderedDict.fromkeys([m.__class__ for m in self.model.modules() if
                                                    self.norm_class_name.lower() == m.__class__.__name__.lower()]))
        else:
            # 查找包含“norm”字段的层
            norm_class = list(
                OrderedDict.fromkeys(
                    [m.__class__ for m in self.model.modules() if "norm" in m.__class__.__name__.lower()]))
            norm_class = [norm_class[0]]
            self.norm_class_name = norm_class[0].__name__.lower()
        if ProcessHook.GET_NORM_LINEAR_SUBGRAPH not in self.hooks or self.hooks[
            ProcessHook.GET_NORM_LINEAR_SUBGRAPH] is None:
            # 调用extract_dag()获取DAG图
            dag = extract_dag(self.model, dummy_input,
                            hook_nodes=norm_class, anti_method=self.cfg.anti_method)
            self.norm_linear_subgraph = dag.get_norm_linear_subgraph()
            if self.cfg.anti_method == 'm4':
                self.linear_linear_subgraph = dag.get_linear_linear_subgraph()
                self.norm_linear_subgraph.update(self.linear_linear_subgraph)
            del dag

上面的代码中主要调用了 extract_dag() 函数获取DAG图,然后得到norm_linear_subgraph。extract_dag调用的又是TorchDAGAdapter类,这篇文章不做详细分析。norm_linear_subgraph 的格式如下所示:
norm_linear_subgraph{'model.layers0.input_layernorm': ['model.layers0.attn.q_proj', 'model.layers0.attn.k_proj', 'model.layers0.attn.j_proj'], 'model.layers0.post_attention_layernorm': ['model.layers0.mlp.gate_proj', 'model.layers0.mlp.up_proj'], ...}

anti_outlier的process的核心代码如下:

    def _process(self):
        ...
        # 给模型层注册hook,执行推理,记录每层的输入输出
        act_stats = self.os_stats()
        ....
        # 遍历需要做量化的层
        for norm_name_group in tqdm(iterable=self.norm_linear_subgraph.keys(), desc="AntiOutlier Process", position=1):
            linear_names = self.norm_linear_subgraph[norm_name_group]
            if isinstance(norm_name_group, str):
                norm_module = PatternProcess.get_module_by_name(self.model, norm_name_group)
            ...
            stats = act_stats[linear_name]

            is_expert = any("expert" in name.lower() for name in linear_names)
            if (is_expert):
                continue

            self.logger.debug(f"smooth {norm_name_group} -> {linear_names}")

            for name in linear_names:
                mod = PatternProcess.get_module_by_name(self.model, name)
                linear_modules.append(mod)

            ...
            # 对权重进行smooth
            if Multiplier is not None and norm_module is None:
                norm_module = Multiplier(
                    torch.ones_like(stats[STAT_KEY_SMOOTH_SCALE]).to(linear_modules[0].weight.device)
                )

            prepare_list = [PrepareWeight(norm_module, post_force=True, post_recurse=True)]
            prepare_list += [PrepareWeight(mod, post_force=True) for mod in linear_modules]
            # 对norm层权重进行smooth
            with ResListToRelease(*prepare_list):
                if self.cfg.anti_method == 'm1' or self.cfg.anti_method == 'm5':
                    smooth_ln_fcs(self.cfg, norm_module, linear_modules, stats, alpha=self.cfg.alpha)
                elif self.cfg.anti_method == 'm2':
                    os_ln_fcs(self.cfg, norm_module, linear_modules, stats, os_k=self.cfg.os_k)
                elif self.cfg.anti_method == 'm3':
                    weight_aware(self.cfg, norm_module, linear_modules, stats)
                elif self.cfg.anti_method == 'm4':
                    if 'scale_min' in inspect.signature(iter_smooth).parameters:
                        fusion_kwargs.update({"scale_min": scale_min})
                    if 'check_group_fusions' not in inspect.signature(iter_smooth).parameters:
                        fusion_kwargs.pop("check_group_fusions", None)
                    self.logger.debug(f"fusion_kwargs is {fusion_kwargs}")
                    iter_smooth(
                        self.cfg, norm_module, linear_modules, stats, num_attention_heads, **fusion_kwargs
                        )
                    if attach_op is not None and Multiplier is not None and isinstance(norm_module, Multiplier):
                        attach_op(self.model, norm_module, linear_modules, linear_names)

上面的代码,首先执行self.os_stats(),这个函数的功能是在模型层上注册hook,然后使用校准数据进行推理,收集每层的输入输出。

然后遍历 norm_linear_subgraph 的值,把norm层和对应的linear层找出来,先把权重乘以s,然后使用CANN包路径 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/msmodelslim 下的so里面的方法做激活层的smooth,也就是对norm层的权重进行处理。

以上就是异常值处理的主要逻辑,完成异常值处理后,model里面的norm层和linear层的权重已经发生了变化,model会继续传给后续的calibrator处理。

2.2 w8a8量化代码

首先需要设置量化方法的参数:

quant_config = QuantConfig(
    a_bit=8,
    w_bit=8,
    disable_names=disable_names,
    dev_type='npu',
    dev_id=model.device.index,
    act_method=1,
    pr=1.0,
    w_sym=True,
    mm_tensor=False,
    is_dynamic=True
)

a_bit和w_bit代表量化bit数;disable_names代表不做量化层的名称;act_method代表激活值的量化方法,“1”代表min-max;pr是概率参数,非1时量化生成的参数带有随机性;w_sym是指权重是否做对称量化;mm_tensor=False代表权重是per-channel量化;is_dynamic=True代表激活量化使用动态量化,也就是量化参数是在推理的时候生成,is_dynamic=False代表静态量化,在调用calibrator量化权重的时候就把激活值的量化参数计算好,动态量化精度更高,但是性能更差。

再来看一下实例化calibrator:

calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level="L0")

传入了异常值抑制后的model、quant_config、校准数据集和disable_level。disable_level='“Ln”代表模型结构从最后一层往前数的n层不做量化。校准数据集一般采用模型实际应用场景的数据,而且在调试精度的时候,如果发现量化模型在某条数据上精度较差,可以把该条数据加入校准数据集,再进行校准量化。

再来看一下init()函数做了哪些事情:

    def __init__(self, model,
                 cfg: QuantConfig,
                 calib_data=None,
                 disable_level='L0',
                 all_tensors=None):
        ...
        # 获取校准数据集
        self.calib_data = self.get_calib_data([]) if calib_data is None else self.get_calib_data(calib_data)
        self.use_kvcache_quant = cfg.use_kvcache_quant  # false
        self.norm_class_name = cfg.norm_class_name

        ...
        # 创建字典记录量化参数
        self.quant_param_dict = AutoSaveDict(self.cfg, max_gb_size=1)
        # 记录被量化module名称,相关的scale、offset等参数名称 key:weight的名称, value:scale、offset等参数的名称
        self.quantized_module_param_dict = defaultdict(list)
        self.fa_module_param_dict = defaultdict(list)
        ...
        # 初始化模型权重json描述
        self.quant_model_json_description = QuantModelJsonDescription(self.cfg.model_quant_type,
                                                                      self.cfg.use_kvcache_quant,
                                                                      self.cfg.use_fa_quant)
        if not re.match(r'^L((?!0)\d+|0)$', disable_level):
            raise ValueError('Please check the `disable_level` configuration.')
        self.disable_level = disable_level

        model = self.init_model_device(model)
        self.last_layer_name = None
        self.rollback_names = None
        self.quant_linear_names = None
        # 记录激活值量化相关的参数
        self.act_states = None
        # 确认不参与量化的层
        self.rollback_names_process(model)
        ...
        # 对模型做量化的层进行替换,替换成可以计算量化参数的“quant modules”
        try:
            self.model = self.quantize(model)
            if self.calib_data:
                self.enable_quant()
        except Exception as e:
            raise Exception("Please check the model and configuration.", e) from e

        self.named_module_count = len(list(self.model.named_modules()))
        self.logger.info("Quantizer initialized successful!")

首先创建了一个字典 self.quant_param_dict 用来保存量化参数(scale、offset等值),然后创建了一个文件quant_model_json_description,这个其实就是量化完成之后,文件夹下面生成的quant_model_description_w8a8.json。接着 rollback_names_process() 就是量化目标层回滚。如果不回滚的话,就是模型所有层都做量化。这个函数是根据disable_names和disable_levels把不量化的层剔除掉。做量化精度调试的时候,一个主要的方式就是尝试回滚不同的层数。

接着就到了self.quantize(model),这是一个重要的步骤。它调用了quantize_model()函数,主要内容如下:

        for name, mod in model.named_modules():
            with PrepareWeight(mod):
                # 跳过不做量化的层
                if name in self.rollback_names:
                    continue
                if isinstance(mod, nn.Linear) or isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear):
                    ...
                    elif self.cfg.model_quant_type is not QuantType.W8A8S:
                        is_dynamic = self.cfg.is_dynamic
                        if "mlp" in name and self.is_deepseek_v2:
                            if self.cfg.model_quant_type is QuantType.W8A8:
                                is_dynamic = True
                        # 生成“quant modules”
                        quant_mod = LinearQuantizer(cfg=self.cfg, logger=self.logger, is_dynamic=is_dynamic)
                    else:
                        quant_mod = LinearSparseQuantizer(cfg=self.cfg, logger=self.logger)
                    quant_mod.set_param(mod)
                    move_update_weight_hook_if_need(mod, quant_mod)
                    # 把需要做量化的线性层替换成具备量化功能的quant modules
                    _set_module(model, name, quant_mod)
        ...

这段代码的功能主要是遍历模型的所有模块,把需要做量化的层找到,然后把它们替换成 LinearQuantizer 类,这个类的forward函数在做前向推理的过程中会计算量化参数。

calibrator.run()调用了self._run,self._run主要调用的是self.run_calib_mode(),run_calib_mode()函数的核心代码如下:

        for data in tqdm(iterable=self.calib_data, position=1, desc="Calibrator Process"):
            if not amp_done and self.cfg.fa_amp:
                enable_fa_quantizer_record(self.model)

            if isinstance(data, tuple) or isinstance(data, list):
                self.model(*data)
            elif isinstance(data, dict):
                self.model(**data)

            ...

上面的代码主要是遍历校准数据,把每条数据传给模型做一次前向推理。在前向推理的过程中,LinearQuantizer 的forward()函数会进行关键的量化操作。我们可以看一下相关代码:

首先看一下 LinearQuantizer 的初始化函数:

    def __init__(self, cfg=None, logger=None, is_dynamic=False):
        """
        cfg: quantizaton configuration
        """
        super(LinearQuantizer, self).__init__()
        self.in_features = None
        self.out_features = None
        self.weight = None
        self.bias = None
        # 激活值的Tensor量化器
        self.quant_input = TensorQuantizer(
            bit=cfg.a_bit, is_signed=cfg.a_signed, is_enable=True,
            is_input=True, cfg=cfg, logger=logger, is_dynamic=is_dynamic
        )
        # 权重的Tensor量化器
        self.quant_weight = TensorQuantizer(
            bit=cfg.w_bit, is_signed=cfg.w_signed, is_enable=True,
            is_input=False, cfg=cfg, logger=logger
        )

可以看到,代码中初始化了激活值量化器和权重量化器,我们继续看一下它们的forward函数逻辑:

    def forward(self, x):
        if self.quant_weight.int_infer and (not self.quant_weight.is_calib):
            return self._int_infer_forward(x)
        else:
            if self.quant_input.w_hessian:  # gptq
                weight = self.quant_weight(self.weight, y=x.clone())
            else:
                weight = self.quant_weight(self.weight)
            if self.quant_input.bit <= 8:
                x = self.quant_input(x)
            return F.linear(x, weight, self.bias)

首先是调用self.quant_weight获得了量化后的权重,然后调用self.quant_input获得了量化后的激活值,再执行linear操作。self.quant_weight 和 self.quant_input 都是 TensorQuantizer 类,TensorQuantizer的forward函数代码如下:

    def tensor_forward(self, tensor, y=None):
        ...
        # weight quantization
        with torch.no_grad():
            # 对权重进行量化
            if not self.is_input:
                return self._quant_weight_forward(tensor, y)

            # activation quantization
            if self.is_dynamic:
                self._stat_dynamic_input(tensor)
            # 对输入进行量化
            return self._quant_activation_forward(tensor)

_quant_weight_forward 和 _quant_activation_forward 调用的 _init_weight_quant_normal、fake_quantize 和 linear_quantization_params 都是在CANN包中的so里面实现的,这里不做解析,其原理就是int8量化。

上面的代码在执行过程中会生成权重量化和激活值量化的参数,并作为TensorQuantizer的属性保存在内存中,后续调用calibrator.save()的收获收集保存。

执行完calibrator.run()后,就可以执行 calibrator.save() 保存权重和量化参数了,save() 函数中主要调用 self.get_quant_params() 来获取量化权重和相关参数,代码比较简单,在这里不做讲解。

以上就是这篇文章的所有内容,下面内容会解析mindie做量化推理的代码逻辑,敬请期待!

本文由博客一文多发平台 OpenWrite 发布!

posted @ 2025-03-31 10:19  AI布道Mr-Jin  阅读(176)  评论(0)    收藏  举报