Deep Learning (ONNX Quantization)

Dynamic and static quantization in ONNX follow the same core ideas as their PyTorch counterparts, but the tooling, workflow, and specific APIs differ.

ONNX quantization typically relies on onnxruntime to quantize and execute the model, and uses the onnx library for model conversion.
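As a rough sketch of the workflow (entry-point names from onnxruntime 1.x; verify them against your installed version), both quantization modes are driven by a single function call:

from onnxruntime.quantization import QuantType, quantize_dynamic, quantize_static

# Dynamic quantization: weights are quantized offline,
# activations are quantized on the fly at inference time.
# quantize_dynamic("model.onnx", "model_int8.onnx", weight_type=QuantType.QInt8)

# Static quantization: weights and activations are both quantized offline,
# which additionally requires calibration data (see the second example below).
# quantize_static("model.onnx", "model_int8.onnx", calibration_data_reader=reader)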

Beyond PyTorch and ONNX quantization, in practice chips from vendors such as NVIDIA, Horizon Robotics, and Huawei Ascend each ship their own toolchains and accelerated operators; for those, simply follow the official tutorials.

Two examples follow for verification; they can be compared against the code in the previous post.

Dynamic quantization:

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic

class SimpleLSTM(nn.Module):
    """A simple LSTM model, well suited to dynamic quantization"""
    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # Fully connected output layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)

        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))

        # Keep only the output of the last time step
        out = self.fc(out[:, -1, :])
        return out

def quantize_onnx_model(input_model_path, output_model_path):
    quantize_dynamic(
        input_model_path,
        output_model_path,
        weight_type=QuantType.QInt8
    )
    print(f"ONNX model dynamically quantized and saved to: {output_model_path}")


if __name__ == '__main__':
    model = SimpleLSTM()
    x = torch.randn(1, 10, 10)  # example input

    torch.onnx.export(
        model,                  # model being run
        x,                      # model input (or a tuple for multiple inputs)
        "simple_lstm.onnx",     # where to save the model
        export_params=True,     # store the trained parameter weights inside the model file
        opset_version=12,       # the ONNX version to export the model to
        input_names=['input'],  # the model's input names
        output_names=['output'])

    quantize_onnx_model("simple_lstm.onnx", "simple_lstm_quantized.onnx")
    
    # Test the original and the quantized ONNX models
    x = np.random.randn(1, 10, 10).astype(np.float32)  # example input
    ort_session = ort.InferenceSession("simple_lstm.onnx")
    ort_session_quantized = ort.InferenceSession("simple_lstm_quantized.onnx")

    inputs = {ort_session.get_inputs()[0].name: x}
    outputs = ort_session.run(None, inputs)
    print("ONNX model output:\n", outputs[0])

    inputs = {ort_session_quantized.get_inputs()[0].name: x}
    outputs_quantized = ort_session_quantized.run(None, inputs)
    print("Output of dynamically quantized ONNX model:\n", outputs_quantized[0])

Static quantization:
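Unlike dynamic quantization, static quantization fixes the activation scales and zero points offline, so it needs a small set of representative inputs, fed through a CalibrationDataReader, to estimate the activation ranges: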

import torch
import numpy as np
import warnings
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static, CalibrationDataReader
warnings.filterwarnings("ignore")

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub/DeQuantStub convert between floating point and quantized
        # tensors; they are only needed for PyTorch eager-mode quantization and
        # are kept here commented out for comparison with the previous post.
        # self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)

        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        # x = self.dequant(x)
        return x


# 1. Calibration data reader class
class CustomCalibrationDataReader(CalibrationDataReader):
    def __init__(self, calibration_data_path, input_name):
        """
        Initialize the calibration data reader.

        Args:
            calibration_data_path: path to the calibration data (.npz file)
            input_name: name of the model input
        """
        self.data = np.load(calibration_data_path)
        self.input_name = input_name
        # Number of calibration samples in the first stored array
        self.datasize = len(self.data[self.data.files[0]])
        self.enum_data = iter(self.data[self.data.files[0]])
    
    def get_next(self):
        """
        Return the next batch of calibration data, or None when exhausted.
        """
        try:
            batch = next(self.enum_data)
            return {self.input_name: np.expand_dims(batch, axis=0)}
        except StopIteration:
            return None
    
    def rewind(self):
        """
        Reset the data iterator.
        """
        self.enum_data = iter(self.data[self.data.files[0]])

# 2. Main quantization function
def quantize_onnx_model_static(original_model_path, quantized_model_path, calibration_data_path):
    """
    Perform static quantization of an ONNX model.

    Args:
        original_model_path: path to the original FP32 model
        quantized_model_path: path where the quantized model will be saved
        calibration_data_path: path to the calibration dataset (.npz format)
    """
    # Load the original model to look up its input name
    model = onnx.load(original_model_path)
    input_name = model.graph.input[0].name

    # Create the calibration data reader
    calibration_data_reader = CustomCalibrationDataReader(
        calibration_data_path,
        input_name
    )

    quantize_static(
        model_input=original_model_path,
        model_output=quantized_model_path,
        calibration_data_reader=calibration_data_reader,
        quant_format=QuantFormat.QDQ,     # QDQ or QOperator
        per_channel=True,                 # per-channel quantization
        reduce_range=True,                # reduced quantization range (needed on some CPUs)
        activation_type=QuantType.QInt8,  # activation quantization type
        weight_type=QuantType.QInt8,      # weight quantization type
    )

    print(f"Quantization done! Quantized model saved to: {quantized_model_path}")

# 3. Helper: generate a calibration dataset
def generate_calibration_data(output_path, num_samples=100):
    """
    Generate a calibration dataset.

    Args:
        output_path: path where the calibration data is saved (.npz)
        num_samples: number of samples to generate
    """
    # Create random input data (adjust the shape to match your actual model)
    calibration_data = []
    for _ in range(num_samples):
        data = np.random.randn(1, 4, 4).astype(np.float32)  # random sample
        calibration_data.append(data)

    # Save as an .npz file
    np.savez(output_path, calibration_data=np.array(calibration_data))
    print(f"Generated {num_samples} calibration samples at: {output_path}")

# 4. Usage example
if __name__ == "__main__":

    MODEL_FP32 = 'model_fp32.onnx'
    MODEL_INT8 = 'model_int8.onnx'

    model_fp32 = Model()
    x = torch.randn(1, 1, 4, 4)  # example input
    torch.onnx.export(
        model_fp32,
        x,
        MODEL_FP32,
        input_names=['input'],
        output_names=['output'])

    # Step 1: generate calibration data (skip if you already have data)
    generate_calibration_data("calibration_data.npz", num_samples=100)

    # Step 2: run static quantization
    quantize_onnx_model_static(
        original_model_path=MODEL_FP32,
        quantized_model_path=MODEL_INT8,
        calibration_data_path="calibration_data.npz"
    )
    
    # Step 3: verify the quantized model (optional)
    x = np.random.randn(1, 1, 4, 4).astype(np.float32)  # example input
    ort_session = ort.InferenceSession(MODEL_FP32)
    ort_session_quantized = ort.InferenceSession(MODEL_INT8)

    inputs = {ort_session.get_inputs()[0].name: x}
    outputs = ort_session.run(None, inputs)
    print("ONNX model output:\n", outputs[0])

    inputs = {ort_session_quantized.get_inputs()[0].name: x}
    outputs_quantized = ort_session_quantized.run(None, inputs)
    print("Output of statically quantized ONNX model:\n", outputs_quantized[0])