Deep Learning (PyTorch Quantization)
Dynamic quantization and static quantization are the two main model quantization techniques in PyTorch. Both aim to reduce model size and speed up inference by replacing high-precision data types (such as float32) with low-precision data types (such as int8).
Dynamic quantization: the quantization parameters (scale and zero_point) of the activations are computed on the fly at runtime (during inference). The weights are usually quantized when the model is loaded or before the first run.
Static quantization: before deployment, a representative calibration dataset is used to determine the quantization parameters (scale and zero_point) of all weights and all activations in the network ahead of time. These parameters stay fixed (static) during inference.
For deployment, static quantization is usually the more common choice. Two examples are given below so you can verify the behavior yourself.
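Before the two full examples, here is a minimal sketch (not part of the examples below; the scale and values are chosen purely for illustration) that makes scale and zero_point concrete. torch.quantize_per_tensor applies the affine mapping q = round(x / scale) + zero_point and stores the result as 8-bit integers:

import torch

# Illustration only: scale and input values are made up for this sketch.
x = torch.tensor([0.0, 0.5, 1.0, 1.5])

# Affine quantization: q = round(x / scale) + zero_point, stored as quint8
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

print(xq)               # quantized tensor (shows values plus scale / zero_point)
print(xq.int_repr())    # raw uint8 storage: tensor([ 0,  5, 10, 15], dtype=torch.uint8)
print(xq.dequantize())  # mapped back to float32 (with quantization error, if any)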
Dynamic quantization:
import torch
import torch.nn as nn
import warnings

warnings.filterwarnings("ignore")
# torch.serialization.add_safe_globals([torch.ScriptObject])


class SimpleLSTM(nn.Module):
    """A simple LSTM model, suitable for dynamic quantization."""
    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the output of the last time step
        out = self.fc(out[:, -1, :])
        return out


def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')


def load_fp32_model(x):
    model_fp32 = SimpleLSTM()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32, x):
    # Dynamically quantize the LSTM and Linear layers to int8
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')


def load_int8_model(x):
    # Rebuild the quantized module structure first, then load the saved state dict
    model_fp32 = SimpleLSTM()
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    # The quantized state dict contains packed parameters, hence weights_only=False
    model_int8.load_state_dict(torch.load('model_int8.pth', weights_only=False))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8


if __name__ == '__main__':
    x = torch.randn(1, 10, 10)
    model_fp32 = SimpleLSTM()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)
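As a quick verification of the dynamic example, a sketch like the following can be appended after running the script above (it assumes model_fp32.pth and model_int8.pth exist and that SimpleLSTM, load_fp32_model and load_int8_model are defined in the same file). It compares the saved file sizes and the largest output difference between the FP32 and INT8 models:

import os
import torch

# Assumes the dynamic-quantization script above has already been run in this file,
# so model_fp32.pth / model_int8.pth exist on disk.
x = torch.randn(1, 10, 10)
m_fp32 = load_fp32_model(x)
m_int8 = load_int8_model(x)

# File-size comparison: LSTM/Linear weights are stored as int8 in the quantized model
print("fp32 file: %.1f KB" % (os.path.getsize('model_fp32.pth') / 1024))
print("int8 file: %.1f KB" % (os.path.getsize('model_int8.pth') / 1024))

# Output comparison: dynamic quantization introduces a small numerical error
with torch.no_grad():
    diff = (m_fp32(x) - m_int8(x)).abs().max()
print("max |fp32 - int8| output difference:", diff.item())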
Static quantization:
import torch
import numpy as np
import warnings

warnings.filterwarnings("ignore")


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        return x


def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')
    torch.onnx.export(
        model_fp32, x, 'model_fp32.onnx',
        input_names=['input'], output_names=['output'])


def load_fp32_model(x):
    model_fp32 = Model()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32, x):
    model_fp32.eval()
    # Select the quantization configuration for the x86 backend
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    # Fuse Conv+ReLU pairs so each pair is quantized as a single module
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    # Insert observers that will record activation statistics
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    # Calibration: run representative data through the model to collect statistics
    with torch.no_grad():
        for i in range(100):
            input_data = torch.randn(1, 1, 4, 4)
            model_fp32_prepared(input_data)
    # Convert to the quantized model using the collected scale / zero_point values
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')
    torch.onnx.export(
        model_int8, x, 'model_int8.onnx',
        input_names=['input'], output_names=['output'])


def load_int8_model(x):
    # Rebuild the same quantized module structure before loading the state dict
    model_fp32 = Model()
    model_fp32.eval()  # eval mode before fusion, consistent with save_int8_model
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.load_state_dict(torch.load('model_int8.pth'))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8


if __name__ == '__main__':
    x = np.array([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.6, 0.7, 0.8],
                  [0.9, 0.1, 0.2, 0.3],
                  [0.4, 0.5, 0.6, 0.7]], dtype=np.float32)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
    model_fp32 = Model()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)
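After convert(), the conv layers are replaced by quantized modules whose int8 weights carry their own scale and zero_point. As a rough sanity check, the following sketch (assuming the static-quantization script above has just been run in the same session, so its functions and x are still defined) inspects those parameters and compares FP32 and INT8 outputs:

# Sketch: run after the static-quantization script above, in the same session.
model_int8 = load_int8_model(x)

# self.conv is now a quantized Conv2d; its weight is an int8 tensor, and the module
# stores the scale / zero_point used for its output activation.
print(type(model_int8.conv))          # should show a quantized Conv2d class
w = model_int8.conv.weight()          # quantized weight tensor
print(w.int_repr().dtype)             # torch.int8
print("activation scale / zero_point:", model_int8.conv.scale, model_int8.conv.zero_point)

# Compare against the FP32 model: the INT8 output should be close but not identical
model_fp32 = load_fp32_model(x)
with torch.no_grad():
    diff = (model_fp32(x) - model_int8(x)).abs().max()
print("max |fp32 - int8| output difference:", diff.item())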