征程 6 | 平台 QAT 精度一致性问题分析流程

QAT 训练完成后,从 torch qat 伪量化模型到 征程 6 板端部署 hbm 模型之间,有模型 export 导出、convert 转定点、插入前处理节点以及 compile 编译等步骤,在这些步骤中,如果出现精度不一致的情况,说明存在一致性问题。一致性问题分为两类:

  1. 用户侧问题。例如:前后处理不一致,代码误用导致训练部署图不一致的问题等。
  2. 工具侧问题。例如:查表算子转定点(非线性函数使用多项式近似或分段线性近似来代替精确计算)、不同硬件对于浮点/定点实现不一致、rgb/yuv444 转 nv12 存在信息损失等,由于神经网络具有一定的鲁棒性,若不存在代码误用以及工具 bug 的情况下,板端 hbm 模型精度 与 torch qat 伪量化模型之间的误差很小。

不论哪类一致性问题,您都可以参考本文进行排查。

1.基础定义

一致性问题从 API 分割看,主要包括 export 前后、convert 前后、compile 前后,在分析过程中,可能还会引入查表算子转定点(pre_export)、插入 nv12 节点前后(insert_nv12)、删除首尾节点前后(remove_op)的一致性问题,在深入分析之前,大家先统一各阶段模型的概念:

d2hpdGVib2FyZF9leHBvcnRlZF9pbWFnZSAoNik=.png

image.png

2. 一致性问题定位流程

当出现一致性问题时,大家先确认自己的 horizon-plugin-pytorch、horizon-plugin-profiler、hbdk4-compiler 已升级到最新版本(本文发布时为 OE3.5.0,最新版本获取可见地平线算法工具链官网​,然后按照如下流程确认一致性问题发生阶段,参考下文介绍的每个阶段一致性定位方法进行排查。

d2hpdGVib2FyZF9leHBvcnRlZF9pbWFnZSAoNSk=.png

3.export 一致性分析

3.1 分析前提

  1. 分析 export 一致性时,请​先确认 qat_model eval 精度与单帧可视化符合预期​;
  2. qat.bc 与 qat_model eval 共用一套前后处理,保证不存在前后处理差异导致的一致性问题;
  3. qat.bc 多帧数据可视化均不符合预期;

3.2 分析思路

3.2.1 仅查表转定点

export 出现一致性问题时,通常需要先判断是否为 查表转定点导致的。具体方式为:将 qat_model 通过 pre_export 接口仅转查表,验证 pre_export_pt 可视化。

from horizon_plugin_pytorch.quantization.hbdk4 import pre_export
pre_export_pt = pre_export(qat_pt)
pre_export_ret = qat_export_pt(example_input) # 查表转定点后模型的推理结果,可以验证此时精度/可视化是否损失
  1. 若 pre_export_pt 多帧可视化 or 验证集精度指标 符合预期:说明查表算子没问题,跳过该章节
  2. 若 pre_export_pt 多帧可视化 or 验证集精度指标 不符合预期:说明是查表算子转定点引起的问题,需要排查具体是哪个查表造成的。

参考如下代码,运行 QAT debug 工具来分析查表算子的误差 qat_pt_vs_pre_export_pt(QAT debug 工具详细用法可见 《工具链在线手册-量化感知训练-开发指南-精度调优工具使用指南》)

from horizon_plugin_profiler import QuantAnalysis
from horizon_plugin_pytorch.quantization.hbdk4 import pre_export

# qat.pt和qat.export.pt跑一致性敏感度和逐层对比
qa = QuantAnalysis(qat_pt, pre_export_pt, "pre_export", out_dir="./qatpt_vs_qatexportpt")
qa.set_bad_case(bad_example_input)
qa.run()
qa.compare_per_layer()
qa.sensitivity()
  1. 【​定位具体查表 ​op​】若从 debug 工具产出物中未分析出是哪个(些)查表算子造成的一致性问题,可根据 plugin debug 工具的敏感度排序,设置敏感度高的部分 查表 op 取消转定点,缩小问题 op 范围。如果将部分 查表 op 取消转定点后,pre_export_pt 精度上升/可视化正常,则说明确实是这些 查表 op 导致。
# 此接口需要在 load qat.ckpt后添加
from horizon_plugin_pytorch.quantization.hbdk4 import pre_export
pre_export_pt = pre_export(qat_pt)
# output_xxx_sensitive_ops.txt top1
pre_export_pt.get_submodule("model.pts_bbox_head_pvb._generated_sin_0.sin").quantized_forward = False

# 取消多个查表转定点时
# op_fallback_list = set()
# op_fallback_list.add("header.cls_header.type_encoder.1.var_mean.mean")
# op_fallback_list.add("backbone.traj_encoder.mlp2.nn.2.lut")
# for op_name in op_fallback_list:
#     module = pre_export_pt.get_submodule(op_name)
#     module.quantized_forward = False
  1. 【查表转定点常见解决方案】常见有一致性问题的查表 op:rsqrt、reciprocal、sin/cos 等,可尝试增大 num_tables ​的数值来优化查表算子的一致性,用于拟合非线性函数的表项 num_tables 需配置为 6 的倍数,不同查表 op 默认 num_tables 不同,经验看,num_tables 超出 126 后对查表一致性几乎不再有收益。在 qat_model 加载权重后,在 pre_export 前配置 num_tables,配置示例如下:
qat_model._generated_rsqrt_0.rsqrt.num_tables = 108

常见有一致性问题的查表 op:sin/cos 算子,发现输入范围较大(超出-pi~pi 一个周期),可以将 sin/cos 替换为 plugin 的自定义算子,并配置 single_period=True,然后​重新 calib/qat(替换后,性能会差一点点,因此未工具层面自动替换)。

import horizon_plugin_pytorch.nn as hnn
class modelnet(nn.module):
    def __init__(self,):
        ...
        self.sin=hnn.Sin(single_period=True)
        self.cos=hnn.Cos(single_period=True)

也可以自行处理 sin/cos 输入,按照周期性将输入处理到[-pi, pi)之间,并​重新 calib/qat

x = x - 2 * torch.floor(x * ( 0.5 / torch.pi) + 0.5) * torch.pi

若上述方案无法解决查表阶段的问题,请准备好​​ qatpt_vs_qatexportpt 产出物中的 txt 文件​,在地平线开发者社区-工具链板块上提问。

3.2.2 图一致性

在确认仅查表转定点 pre_export_pt 模型的精度/多帧可视化符合预期后,若 qat.bc 依旧存在精度问题,请优先检查 export 通路代码中是否存在 if 部署逻辑(只有部署才走的通路),若存在,先尝试不走部署逻辑 export 生成 qat_bc,验证此时 qat_bc 可视化是否符合预期。

  1. 若符合预期​:说明 if 逻辑造成图不一致影响了权重加载或代码有误。

对于图不一致的排查方法,还可以查看 fx_graph.txt,从中获取到模型中 op/module 的上下游调用关系,排查导出计算图是否发生改变。例如当存在算子 called times 为 0 未被调用的情况,可以通过 Graph 定位到上下文算子从而定位未被调用的原因(通常因为存在逻辑判断或循环次数变化);

# 模型Graph图结构信息
Graph:
opcode         name                                           target                                                                    args                                                                                           kwargs
-------------  ---------------------------------------------  ------------------------------------------------------------------------  ---------------------------------------------------------------------------------------------  -----------------------------
placeholder    input_0                                        input_0                                                                   ()                                                                                             {}
call_module    quant                                          quant                                                                     (input_0,)                                                                                     {}
call_module    traj_decoder_src_proj_0_0                      traj_decoder_src_proj.0.0                                                 (quant,)                                                                                       {}
call_function  __getitem__                                    <slot wrapper '__getitem__' of 'torch.Size' objects>                      (__get__, 0)                                                                                   {}
call_function  __getitem___1                                  <slot wrapper '__getitem__' of 'torch.Size' objects>                      (__get__, 1)                                                                                   {}
call_function  __getitem___2                                  <slot wrapper '__getitem__' of 'torch.Size' objects>                      (__get__, 2)                                                                                   {}
...

重点关注的 Graph 信息:

  • opcode 为算子调用类型
  • name 为当前算子名称,需注意和 model_check_result.txt 中的 module.submodule 名称区别
  • target 为算子输出
  • args 为算子输入
  1. 若不符合预期​:往下尝试 3.2.3 plugin debug 工具

3.2.3 plugin debug 工具

当 qat_export.pt 指标正常,qat.bc 精度指标不符合预期,且不存在图不一致问题时,需要运行 plugin debug 工具来分析“export”阶段一致性问题,

from horizon_plugin_profiler import QuantAnalysis

qa = QuantAnalysis(pre_export_pt, qat_bc, "export", out_dir="./pre_export_pt_vs_qatbc")
# torch 与 bc 可接受同一格式输入时,一起跑统计量
qa.set_bad_case(badcase)
qa.run()

# torch 与 bc 不可接受同一格式输入时,分开跑统计量,pt_badcase 与 bc_badcase 除格式外全部相同。
qa.set_bad_case(pt_badcase)
qa.run(run_baseline_model=True, run_analysis_model=False)
qa.set_bad_case(bc_badcase)
qa.run(run_baseline_model=False, run_analysis_model=True)

# 逐层对比
qa.compare_per_layer()

# qat.export.pt 跑一致性敏感度,qat_bc起到占位作用
qa = QuantAnalysis(pre_export_pt, qat_bc, "export", out_dir="./pre_export_pt_vs_qatbc")
qa.set_bad_case(pt_badcase)
qa.sensitivity()

判断正确运行 plugin debug 工具方法:

  1. compare_per_layer_out.txt:存在对比结果
  2. output_xxx_sensitive_ops.txt:敏感度有高有低,且最后几个算子的量化敏感度接近于 0

分析 pre_export_pt_vs_qatbc 阶段的 debug 工具产出物,若未发现问题所在或不知如何修改,请准备好​​ pre_export_pt_vs_qatbc 产出物中的 txt 文件 +qat.bc、qat.onnx​,在地平线开发者社区-工具链板块上提问。

4. convert 一致性分析

4.1 分析前提

  1. 分析 convert 一致性时,说明 qat.bc 精度/可视化符合预期,quantized.bc 多帧数据可视化均不符合预期;
  2. qat.bc 与 quantized.bc 使用相同的输入和后处理,避免非模型部分引起的差异;

4.2 分析思路

4.2.1 征程 6EM 高一致性策略【OE3.5.0 为 beta 功能】

注意​:

  1. 高一致性策略对查表转定点无影响,主要影响 convert 前后的一致性
  2. level0 全局开启会对 latency 有负面影响,大约 10~20%,甚至出现过 40% 的情况
  3. level2 对 latency 有正面收益,推荐优先使用 level2
  4. 高一致性策略仅适用于 征程 6EM
  5. 实现方式未来会进行优化,请大家使用时关注用户手册《QAT-训练部署一致性-高一致性 QAT 策略》章节

高一致性策略封装在 horizon_plugin_pytorch.qat_mode.ConsistencyStrategy 下,可以使用 set_consistency_level 接口设置策略。

当前支持五个等级( 0 - 4 )的策略,等级越高,一致性越好,但 QAT 精度可能受到轻微影响。推荐直接使用 level 2,在绝大多数情况下对 QAT 精度无影响,甚至可以改善因截断误差引起的精度问题,对性能和一致性有正收益。

对于未使用高一致性策略得到的 QAT 模型,如果希望不重训获得一致性更高的定点模型,可以在 prepare export 模型前设置一致性策略等级为 0(不重训的情况下只有 level 0 有效,level 1 - 4 需要设置等级后重训模型)。

from horizon_plugin_pytorch.qat_mode import ConsistencyStrategy

# 必须在 prepare 之前设置一致性策略
ConsistencyStrategy.set_consistency_level(2)
...
qat_pt = prepare(float_model)
...
qat_bc = export(qat_pt, example_inputs)
# 如果在prepare前设置 ConsistencyStrategy.set_consistency_level(0), 可以做如下检查
# print(qat_bc._high_precision_qpp)    # 需要是 true,不要用assert检查
# print(qat_bc._fuse_requantize)       # 需要是 false, 不要用assert检查

quantized_bc = convert(qat_bc, march)

level2 在 convert 阶段,linear 与 conv 会有一个 scale 的误差,其它 op 是对齐的

level4 在 convert 阶段,linear 与 conv 也会有一个 scale 的误差,但概率会降低到万分之几

linear 与 conv 将 bias 去掉,level4 在 convert 阶段将没有误差

4.2.2 plugin debug 工具

当采用高一致性策略未解决 convert 前后的一致性问题时,需要运行 plugin debug 工具来分析“convert”前后一致性问题,建议使用高一致性策略后的模型来对比分析,示例如下

from horizon_plugin_profiler import QuantAnalysis
from horizon_plugin_pytorch.quantization.hbdk4 import pre_export

# qat.bc 和 quantized.bc 跑逐层对比
qa = QuantAnalysis(qat_bc, quantized_bc, "convert", out_dir="./qatbc_vs_quantizedbc")
qa.set_bad_case(bad_example_input)
qa.run()
qa.compare_per_layer()

# qat.export.pt 跑一致性敏感度,quantzed_bc起到占位作用
qa = QuantAnalysis(pre_export_pt, quantized_bc, "convert", out_dir="./qatbc_vs_quantizedbc")
qa.set_bad_case(bad_example_input)    # 注意,此处bad_example_input与跑逐层的一致
qa.sensitivity()

判断正确运行 plugin debug 工具方法:

  1. compare_per_layer_out.txt:存在对比结果
  2. output_xxx_sensitive_ops.txt:敏感度有高有低,且最后几个算子的量化敏感度接近于 0

分析 qatbc_vs_quantizedbc 阶段的 debug 工具产出物,若未发现问题所在或不知如何修改,请准备好​​ qatbc_vs_quantizedbc 产出物中的 txt 文件 +qat.bc+qat.onnx+quantized.bc+quantized.onnx​,在地平线开发者社区-工具链板块上提问。

4.2.3 分段转浮点

绝大部分情况下,plugin debug 工具都可以分析解决 convert 前后一致性问题,若您发现 plugin debug 工具失效或不想适配使用 plugin debug 工具,工具链还支持分段转浮点的方法来分析 convert 前后一致性,具体做法是将 qat.bc 中 某 op 或 一定范围的 op 配置为 CPU 算子,从而定位出引起 convert 定点化中掉点的 op。

在 qat.bc 模型中,每个节点都有一个 id,根据 id 将某些伪量化删除可以使得模型的一部分变成 cpu 算子,下图为 qat.onnx 的可视化图。

aW1hZ2U=.png

bc 编辑工具在 horizon_plugin_profiler/bc_editor/bc_editor.py,使用方式如下:

python bc_editor.py --bc_path qat.bc --new_bc_path new_qat.bc --config_path config.json

config.json 内容可以参考 horizon_plugin_profiler/bc_editor/config_template.json,指定需要删除的伪量化 op id,可以是一个区间 id,也可以是单个 op id,通过该方案,可很容易实现分段浮点。

{
    "remove_fake_quant": [[1, 100], 102]
}

问题确认后,若不知如何修改,请记录分析过程,在地平线开发者社区-工具链板块上提问。

5. nv12 节点插入一致性分析

板端视频通路传输给模型的数据格式为 nv12,通常算法同学会使用 RGB/YUV444 训练模型,由于 nv12 数据量是 RGB/YUV444 等格式的一半,因此必然存在信息损失,通常情况下,神经网络的鲁棒性是可以接受这种误差的。征程 6 工具链支持在模型前端插入一个前处理节点,以实现颜色空间转换(如 NV12 -> BGR),可由 BPU 进行加速,具体实现示例可见《J6 计算平台部署指南 -6.3 模型修改》。

5.1 分析前提

  1. 分析 nv12 节点插入一致性时,说明 quantized.bc 精度/可视化符合预期,nv12_quantized.bc 多帧数据可视化均不符合预期;
  2. quantized.bc 与 nv12_quantized.bc 使用相同的后处理,避免因后处理差异引入一致性问题;

5.2 分析思路

nv12 输入理论上对于模型输出影响很小,可以按照如下三个思路来挨个验证:

  1. nv12 节点插入代码误用
  2. nv12 输入数据准备差异
  3. 确实是 nv12 引入的误差(非 bug 类)

5.2.1 nv12 节点插入代码误用

nv12 节点插入具体细节请参考工具链用户手册 或 配套的迁移文档,常见的误用在 insert_image_preprocess 中的 mode 参数,具体示例如下,详见代码注释:

from hbdk4.compiler import save, convert, visualize, compile, load
    
    qat_model = load("qat.bc")
    quantized_hbir_model = convert(qat_model, march)
    save(quantized_hbir_model, "quantized_no_insert.bc")

    qat_model = load("qat.bc")
    func = qat_model.functions[0]
    for input in func.inputs[::-1]:
        # pyramid&resizer 只支持 NHWC 的 input layout,若原始输入layout为NHWC,则无需插入transpose
        node = input.insert_transpose(permutes=[0, 3, 1, 2])
        # 插入前处理节点,mode=None适用于使用YUV444训练的模型
        # node = node.insert_image_preprocess(mode=None, divisor=1, mean=[128, 128, 128], std=[128, 128, 128])
        # 插入前处理节点,mode="yuvbt601full2rgb"适用于使用RGB训练的模型
        node = node.insert_image_preprocess(mode="yuvbt601full2rgb", divisor=1, mean=[128, 128, 128], std=[128, 128, 128])
        node.insert_image_convert("nv12")
        
    quantized_insert = convert(qat_model, march)
    save(quantized_insert, "nv12_quantized.bc")

5.2.2 nv12 输入数据准备差异

推荐采用如下代码准备 nv12 数据

from hbdk4.compiler import load, visualize
import numpy as np
from PIL import Image

def generate_nv12(img):
    w,h = img.size
    # Convert images to YUV format
    yuv_img = img.convert('YCbCr')
    y_data, u_data, v_data = yuv_img.split()

    # Convert Y, U, and V channel data to byte streams
    y_data_bytes = y_data.tobytes()
    u_data_bytes = u_data.resize((u_data.width // 2, u_data.height // 2)).tobytes()
    v_data_bytes = v_data.resize((v_data.width // 2, v_data.height // 2)).tobytes()

    # Arrange the UV data in the form of UVUVUVUV... 
    uvuvuv_data = bytearray()
    for u_byte, v_byte in zip(u_data_bytes, v_data_bytes):
        uvuvuv_data.extend([u_byte, v_byte])

    # Input for the hbir model
    y = np.frombuffer(y_data_bytes, dtype=np.uint8).reshape(1, h, w, 1).astype(np.uint8)
    # np.save("y_data.npy", y)
    uv = np.frombuffer(uvuvuv_data, dtype=np.uint8).reshape(1, h//2, w//2, 2).astype(np.uint8)
    # np.save("uv_data.npy", uv)
    return y, uv

# Generate random RGB values in the range 0-255
# image_data = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

# 建议读取使用场景中的真实图片
image = Image.open("test.jpg").convert("RGB")  # 转为RGB三通道
# 转成numpy数组,形状为 [H, W, 3]
image_data = np.array(image, dtype=np.uint8)

# Convert the numpy array to a PIL image
img = Image.fromarray(image_data)
y, uv = generate_nv12(img)
quantized_insert_inputs = {"_input_0_y": y, "_input_0_uv": uv}

5.2.3 非 bug 类 nv12 引入的误差

如果你的网络对 nv12 节点插入造成误差特别敏感,则需要将该误差带入到模型训练中,可参考如下代码:

import horizon_plugin_pytorch.nn.bgr_to_yuv444 as b2y
class BgrToYuv444(object):
    """
    BgrToYuv444 is used for color format convert.
    .. note::
        Affected keys: 'img'.
    Args:
        rgb_input (bool): The input is rgb input or not.
    """
    def __init__(self, affect_key: str = "img", rgb_input: bool = False):
        self.affect_key = affect_key
        self.rgb_input = rgb_input
    def __call__(self, data):
        if isinstance(data, dict) and self.affect_key not in data:
            return data
        image = data[self.affect_key] if isinstance(data, dict) else data
        ndim = image.ndim
        if ndim == 3:
            image = torch.unsqueeze(image, 0)
        if image.dtype is not torch.uint8:
            image = image.to(dtype=torch.uint8)
        if image.shape[1] == 6:
            image1 = b2y.bgr_to_yuv444(image[:, :3], self.rgb_input).float()
            image2 = b2y.bgr_to_yuv444(image[:, 3:], self.rgb_input).float()
            image = torch.cat((image1, image2), dim=1)
        else:
            image = b2y.bgr_to_yuv444(image, self.rgb_input)
            image = image.float()
        if ndim == 3:
            image = image[0]
        if isinstance(data, dict):
            data[self.affect_key] = image
            return data
        else:
            return image

其中,b2y 内部实现了 bgr->nv12->yuv444 的转换。

6.compile 一致性分析

6.1 分析前提

  1. 分析 compile 一致性时,说明 quantized.bc 或 nv12_quantized.bc 精度/可视化没问题。
  2. 模型中没有浮点算子时,可以做到小数点后 4 位一致,如果有浮点算子,由于不同硬件平台对浮点算子的 实现方式、支持精度(FP32/FP16)、底层数学库 等存在差异,存在差异是普遍存在的,不一定能做到小数点后 4 位对齐。
  3. bc 与 hbm 使用的前后处理一致。

6.2 分析思路

为了方便不同编码习惯的客户快速比对 compile 前后 bc 与 hbm 的一致性,工具链提供了三种分析方法:

  1. 使用命令行工具 hb_verifier 快速比对
  2. 使用​​ python ​API​:hbdk 接口快速比对(推理速度相对较慢)
  3. 使用​​ python ​API​:hbm_infer 接口快速比对(推理速度相对较快)

6.2.1 hb_verifier 工具

hb_verifier 比对 bc 与 hbm 一致性时,需要关注的信息如下:

bc 与 hbm 一致性比对时,输出信息如下:

aW1hZ2U= (1).png

比对示例如下:hbm 推理支持板端与 x86 仿真两种运行方式,二者结果是一样的,板端推理速度会更快一些。

hb_verifier -m quantized_nv12.bc,quantized_nv12.hbm -i y_data.npy,uv_data.npy --ip None,xx.xx.xx.xx
  1. 若一致:则一致性问题出现在前后处理没对齐。
  2. 若不一致:请准备好​​ quantized.bc 与 hbm​,在地平线开发者社区-工具链板块上提问。

6.2.2 hbdk 接口推理

使用 hbdk 提供的 API 接口 hbm[0]。feed,在相同输入的情况下(可以是算法侧提供,也可以是软件侧提供),推理 quantized.bc 与 hbm(hbm 推理支持板端与 x86 仿真两种运行方式,二者结果是一样的,板端推理速度会更快一些),验证他们的输出一致性/可视化,带 nv12 节点的验证示例代码如下:

from hbdk4.compiler import load, Hbm
import numpy as np
from PIL import Image

def generate_nv12(img):
    w,h = img.size
    # Convert images to YUV format
    yuv_img = img.convert('YCbCr')
    y_data, u_data, v_data = yuv_img.split()

    # Convert Y, U, and V channel data to byte streams
    y_data_bytes = y_data.tobytes()
    u_data_bytes = u_data.resize((u_data.width // 2, u_data.height // 2)).tobytes()
    v_data_bytes = v_data.resize((v_data.width // 2, v_data.height // 2)).tobytes()

    # Arrange the UV data in the form of UVUVUVUV... 
    uvuvuv_data = bytearray()
    for u_byte, v_byte in zip(u_data_bytes, v_data_bytes):
        uvuvuv_data.extend([u_byte, v_byte])

    # Input for the hbir model
    y = np.frombuffer(y_data_bytes, dtype=np.uint8).reshape(1, h, w, 1).astype(np.uint8)
    # np.save("y_data.npy", y)
    uv = np.frombuffer(uvuvuv_data, dtype=np.uint8).reshape(1, h//2, w//2, 2).astype(np.uint8)
    # np.save("uv_data.npy", uv)
    return y, uv

def compare_arrays(array1, array2, decimal_places=2):
    """
    Compare two arrays for consistency up to a specified number of decimal places.

    Parameters:
    - array1: First numpy array.
    - array2: Second numpy array.
    - decimal_places: Number of decimal places to consider for alignment.

    Returns:
    - are_equal: True if arrays are consistent up to the specified decimal places, False otherwise.
    - max_difference: Maximum difference (absolute value) if arrays are not consistent, else 0.
    """
    # Round the arrays to the specified decimal places
    rounded1 = np.round(array1, decimals=decimal_places)
    rounded2 = np.round(array2, decimals=decimal_places)
    
    # Check equality
    are_equal = np.array_equal(rounded1, rounded2)
    
    # Calculate maximum difference if not equal
    max_difference = 0
    if not are_equal:
        max_difference = np.max(np.abs(array1 - array2))
    
    return are_equal, max_difference

hbir = load("./quantized_nv12_remove_stage3.bc")
hbm = Hbm("./quantized_nv12_remove_stage3.hbm")

# Create a random image with the shape (1, 512, 960, 3)
# Generate random RGB values in the range 0-255
image_data = np.random.randint(0, 256, (512, 960, 3), dtype=np.uint8)
# Convert the numpy array to a PIL image
img = Image.fromarray(image_data)
y, uv = generate_nv12(img)

inputs = {"input_0_y": y, "input_0_uv": uv}

# 分别进行hbir和Hbm推理
hbir_outputs = hbir[0].feed(inputs)
# print("hbir_outputs:", hbir_outputs)
hbm_x86_outputs = hbm[0].feed(inputs)        # x86推理
# print("hbm_x86_outputs:", hbm_x86_outputs)

# # 远程连接BPU,实现板端Hbm推理
# # 运行前需要安装 `hbdk4_runtime_aarch64`的wheel包,根据需要选择nash。
hbm_arrch64_outputs = hbm[0].feed(inputs, remote_ip="10.64.60.165", remote_port="22", remote_work_root="/map/xxx/")
# print("hbm_arrch64_outputs:", hbm_arrch64_outputs)

# 比较Hbir和hbm输出
for idx, v in enumerate(hbir[0].flatten_outputs):
    hbir_data = hbir_outputs[v.name]
    hbm_arrch64_data1 = hbm_x86_outputs[v.name]
    are_equal, max_difference = compare_arrays(hbir_data, hbm_arrch64_data1, decimal_places=4)
    if not are_equal:
        print("Maximum difference:", max_difference)
    else:
        print(f"{v.name} is equal!")

若不一致:请准备好​​ quantized.bc+hbm+ 复现脚本​,在地平线开发者社区-工具链板块上提问。

6.2.3 hbm_infer 接口推理

使用 python 推理 quantized.bc,使用 hbm_infer 工具 推理 hbm(hbm_infer 工具详细介绍可参考用户手册《UCP-模型推理开发-模型推理工具介绍-hbm_infer 工具介绍》)。

输入数据的读取代码需要用户根据实际的目录和文件格式进行修改,如下示例是以。bin 文件为例,经过量化然后介入 bc 与 hbm 模型。如果是 numpy 或者 pkl 文件,需要根据实际情况进行读取和处理。

from hbdk4.compiler import load, Hbm
import numpy as np
from PIL import Image
import os
import pickle
import numpy as np
from hbm_infer.hbm_rpc_session_flexible import HbmRpcSession, init_server, deinit_server, init_hbm, deinit_hbm
    
if __name__ =="__main__":
    data_path="inputs"
    #删除
    hbir = load("./model_quantized_removequant.bc")
    hbm_path1="./modelp_remove_quan.hbm"
    hbm_rpc_server1 = init_server(host="xx.xx.xx.xx")  # 确保有root权限
    hbm_handle1 = init_hbm(hbm_rpc_server=hbm_rpc_server1, local_hbm_path=hbm_path1)
    hbm_model1 = HbmRpcSession(
        hbm_handle=hbm_handle1,
        hbm_rpc_server=hbm_rpc_server1,
    )
    # hbm.show_input_output_info()
    print("========= BEGIN test_validate ! =========")
    inputs=hbir[0].flatten_inputs
    input_data={}
    for i,input in enumerate(inputs):
        path=os.path.join(data_path,input.name,"0.bin")
        data=np.fromfile(path, dtype=np.float32).reshape(input.type.shape)
        scale=input.quant_info.scales[0]
        if input.type.torch_dtype=="torch.int16":
            dtype_=np.int16
            min_=-32768
            max_=32767
        if input.type.torch_dtype=="torch.int8":
            dtype_=np.int8
            min_=-128
            max_=127
        data = data / scale
        data = np.round(data )
        data= np.clip(data, min_, max_)
        data= data.astype(dtype_)
        np.save(f"{i}_quan.npy",data) 
        input_data[input.name]=data
    
    hbir_outputs = hbir[0].feed(input_data)
    hbm_arrch64_outputs1 = hbm_model1(input_data)
    
    for idx, v in enumerate(hbir[0].flatten_outputs):
        hbir_data = hbir_outputs[v.name]
        hbm_arrch64_data1 = hbm_arrch64_outputs1[v.name]
        diff = np.abs(hbm_arrch64_data - hbm_arrch64_data1).reshape(np.prod(hbm_arrch64_data.shape))
        print(f"{v.name} max error is {max(diff)}")
    hbm_model.close_server()        # 删除log
    deinit_server(hbm_rpc_server)   # 删除板端 server 文件,避免资源占用
    deinit_hbm(hbm_handle)          # 删除板端 hbm 文件,避免资源占用

若不一致:请准备好​​ quantized.bc+hbm+ 复现脚本​,在地平线开发者社区-工具链板块上提问。

posted @ 2026-01-13 13:32  地平线智能驾驶开发者  阅读(0)  评论(0)    收藏  举报