Deep Learning (PyTorch Quantization)

Dynamic quantization and static quantization are the two main model quantization techniques in PyTorch. Both aim to shrink model size and speed up inference by replacing high-precision data types (such as float32) with low-precision ones (such as int8).

Dynamic quantization: the quantization parameters (scale and zero_point) of the activations are computed on the fly at inference time. The weights are usually quantized ahead of time, when the model is loaded or converted, before its first run.
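
As a rough sketch of what "computing the quantization parameters at runtime" means, the snippet below shows min/max-based affine quantization of an activation tensor. It is only an illustration, not the exact observer logic PyTorch uses, and compute_qparams is a made-up helper name:

import torch

def compute_qparams(t, qmin=-128, qmax=127):
    # The quantized range must cover 0.0 so that zero maps exactly to an integer
    t_min = min(t.min().item(), 0.0)
    t_max = max(t.max().item(), 0.0)
    scale = max((t_max - t_min) / (qmax - qmin), 1e-8)
    zero_point = int(round(qmin - t_min / scale))
    return scale, zero_point

act = torch.randn(4, 8)                        # an activation tensor seen at inference time
scale, zp = compute_qparams(act)
q_act = torch.quantize_per_tensor(act, scale, zp, torch.qint8)
print(scale, zp)
print((q_act.dequantize() - act).abs().max())  # quantization error, on the order of scale/2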

Static quantization: before deployment, a representative calibration dataset is run through the model to determine the quantization parameters (scale and zero_point) of all weights and all activations in the network ahead of time. These parameters then stay fixed (static) during inference.
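
To make the role of a fixed (scale, zero_point) pair concrete, the standard affine mapping applied at inference time looks like the following; the numbers are made up and simply stand in for values a calibration run would produce:

import torch

scale, zero_point = 0.02, 10            # pretend these came from calibration
qmin, qmax = -128, 127                  # int8 value range

x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)   # quantize
x_hat = (q - zero_point) * scale                                   # dequantize
print(q)       # -40, 10, 35, 110
print(x_hat)   # -1.0, 0.0, 0.5, 2.0 (recovered up to rounding error)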

For deployment, static quantization is usually the more common choice. Two examples are given below that you can run to verify the behavior.
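
One simple way to do that check, after either script below has run, is to compare the saved file sizes and the largest difference between the FP32 and INT8 outputs. compare_models here is a hypothetical helper, not part of the scripts:

import os
import torch

def compare_models(model_fp32, model_int8, x,
                   fp32_path='model_fp32.pth', int8_path='model_int8.pth'):
    # Size on disk of the two saved state_dicts
    print("fp32: %.1f KB" % (os.path.getsize(fp32_path) / 1024))
    print("int8: %.1f KB" % (os.path.getsize(int8_path) / 1024))
    # Largest absolute difference between the two model outputs
    with torch.no_grad():
        diff = (model_fp32(x) - model_int8(x)).abs().max().item()
    print("max |fp32 - int8| =", diff)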

Dynamic quantization:

import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
# torch.serialization.add_safe_globals([torch.ScriptObject])

class SimpleLSTM(nn.Module):
    """简单的LSTM模型,适合动态量化"""
    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Initialize the hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        
        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))
        
        # Take only the output of the last time step
        out = self.fc(out[:, -1, :])
        return out


def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32模型输出:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')


def load_fp32_model(x):
    model_fp32 = SimpleLSTM()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("加载的FP32模型输出:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32, x):
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {nn.LSTM,nn.Linear},
        dtype=torch.qint8
    )
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8模型输出:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')


def load_int8_model(x):
    model_fp32 = SimpleLSTM()
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,
        {nn.LSTM,nn.Linear},
        dtype=torch.qint8
    )

    # weights_only=False is needed: the quantized LSTM's state_dict contains packed
    # parameters stored as ScriptObjects, which weights_only=True (the default in
    # recent PyTorch versions) refuses to load
    model_int8.load_state_dict(torch.load('model_int8.pth', weights_only=False))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("加载的INT8模型输出:", y_int8)
    return model_int8


if __name__ == '__main__':
    x = torch.randn(1, 10, 10)
    model_fp32 = SimpleLSTM()

    save_fp32_model(model_fp32,x)
    save_int8_model(model_fp32,x)

    load_fp32_model(x)
    load_int8_model(x)
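
A quick sanity check on the dynamic example (an optional addition that reuses the SimpleLSTM class above, not part of the original script): printing the converted model shows that only nn.LSTM and nn.Linear were replaced by their dynamically quantized counterparts, while everything else stays in FP32.

import torch
import torch.nn as nn

model_int8 = torch.quantization.quantize_dynamic(
    SimpleLSTM(), {nn.LSTM, nn.Linear}, dtype=torch.qint8)
print(model_int8)  # expect DynamicQuantizedLSTM and DynamicQuantizedLinear in the printout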

Static quantization:

import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore")

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)

        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        return x


def save_fp32_model(model_fp32,x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32模型输出:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')
    torch.onnx.export(
        model_fp32,
        x,
        'model_fp32.onnx',
        input_names=['input'],
        output_names=['output'])


def load_fp32_model(x):
    model_fp32 = Model()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("加载的FP32模型输出:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32,x):
    model_fp32.eval()  
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

    # Calibration: run representative data through the prepared model so the observers record activation ranges
    with torch.no_grad():  
        for i in range(100):  
            input_data = torch.randn(1, 1, 4, 4)          
            model_fp32_prepared(input_data)

    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

    model_int8.eval()  
    y_int8 = model_int8(x)
    print("INT8模型输出:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')
    torch.onnx.export(
        model_int8,
        x,
        'model_int8.onnx',
        input_names=['input'],
        output_names=['output'])

def load_int8_model(x):
    # Rebuild the same quantized architecture (qconfig -> fuse -> prepare -> convert),
    # then load the saved int8 state_dict into it
    model_fp32 = Model()
    model_fp32.eval()  # fuse_modules expects the model in eval mode
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.load_state_dict(torch.load('model_int8.pth'))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8


if __name__ == '__main__':

    x = np.array([[0.1,0.2,0.3,0.4],
                  [0.5,0.6,0.7,0.8],
                  [0.9,0.1,0.2,0.3],
                  [0.4,0.5,0.6,0.7]], dtype=np.float32)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  
        
    model_fp32 = Model()

    save_fp32_model(model_fp32,x)
    save_int8_model(model_fp32,x)

    load_fp32_model(x)
    load_int8_model(x)
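
Similarly, for the static example it can be instructive to print the converted model and look at one quantized weight. This is an optional check that assumes x and the load_int8_model function from the script above are in scope; the fused layers should show up as QuantizedConvReLU2d and the remaining convolutions as QuantizedConv2d, each carrying its own scale and zero_point.

model_int8 = load_int8_model(x)
print(model_int8)                      # expect QuantizedConv2d / QuantizedConvReLU2d modules
print(model_int8.conv.weight().dtype)  # torch.qint8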