pytorch动态量化函数
PyTorch 动态量化 API
PyTorch 提供了丰富的动态量化 API,可以帮助开发者轻松地将模型转换为动态量化模型。主要 API 包括:
torch.quantization.quantize_dynamic:将模型转换为动态量化模型。
torch.quantization.QuantStub:观察模型层的输入和输出分布。
torch.quantization.Observer:收集模型层的统计信息。
torch.quantization.DeQuantStub:将定点结果转换回浮点数。
PyTorch torch.quantization.quantize_dynamic 函数详解
torch.quantization.quantize_dynamic 函数是 PyTorch 提供的用于动态量化模型的主要 API。该函数可以将浮点模型转换为动态量化模型,从而显著降低模型大小和提高推理速度。
函数定义
torch.quantization.quantize_dynamic(
model: torch.nn.Module,
qconfig: Dict[Type[torch.nn.Module], Dict],
dtype: torch.qscheme = torch.qint8
) -> torch.nn.Module
参数说明
model: 要转换的浮点模型。qconfig: 指定要量化的模块类型和量化配置。dtype: 指定量化的定点数据类型,可以是torch.qint8或torch.float16。
函数返回值
quantize_dynamic 函数返回一个新的动态量化模型,该模型与原始模型具有相同的架构和功能。
函数功能
quantize_dynamic 函数主要执行以下操作:
- 遍历模型中的每个模块。
- 对于每个模块,检查其类型是否在
qconfig中定义。 - 如果模块类型在
qconfig中定义,则根据qconfig中的配置对该模块进行动态量化。 - 将量化的模块替换到新的模型中。
动态量化配置
qconfig 参数用于指定要量化的模块类型和量化配置。qconfig 是一个字典,其中键是模块类型,值是量化配置字典。量化配置字典可以包含以下键:
- ``activation`: 指定激活的量化配置。
- ``weight`: 指定权重的量化配置。
- ``qscheme
: 指定量化方案,可以是torch.per_tensor或torch.per_channel`。 - ``dynamic`: 指定是否动态量化。
示例
以下是一个简单的示例,演示如何使用 quantize_dynamic 函数将模型转换为动态量化模型:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 1)
)
# 定义量化配置
qconfig = {
nn.Linear: {
'activation': {'dtype': torch.qint8},
'weight': {'dtype': torch.qint8}
}
}
# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
model,
qconfig,
dtype=torch.qint8
)
# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)
在这个示例中,我们定义了一个简单的模型,并使用 qconfig 参数指定了量化配置。qconfig 参数指示 quantize_dynamic 函数对模型中的所有 nn.Linear 模块进行动态量化,并将激活和权重量化为 torch.qint8 格式。
注意事项
在使用 torch.quantization.quantize_dynamic 函数时,需要注意以下几点:
- 动态量化可能会导致模型精度下降,需要根据具体情况权衡性能和精度。
- 动态量化目前还不支持所有模型类型和操作。
- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.QuantStub 模块详解
torch.quantization.QuantStub 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。
模块定义
class QuantStub(nn.Module):
r"""Quantize stub module, before calibration, this is same as an observer,.
It will be swapped as nnq.Quantize in convert .
Parameters:
qconfig(Dict): quantization configuration for the tensor, if qconfig is not
provided, we will use the global qconfig
"""
def __init__(self, qconfig=None):
super(QuantStub, self).__init__()
self.qconfig = qconfig
def forward(self, x):
return x
模块属性
qconfig: 量化配置字典。
模块方法
forward(x): 该方法只是简单地返回输入x,不做任何处理。
模块功能
QuantStub 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,QuantStub 模块会被替换为 nnq.Quantize 模块,nnq.Quantize 模块会使用收集的统计信息对输入进行量化。
示例
以下是一个简单的示例,演示如何使用 QuantStub 模块观察模型层的输入和输出分布:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(10, 20),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.ReLU(),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(20, 1)
)
# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)
在这个示例中,我们为模型中的每个层都添加了 QuantStub 模块。QuantStub 模块会观察每个层的输入和输出分布,并收集统计信息。
注意事项
在使用 torch.quantization.QuantStub 模块时,需要注意以下几点:
QuantStub模块只用于观察模型层的输入和输出分布,不进行任何量化操作。QuantStub模块必须与torch.quantization.DeQuantStub模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.Observer 模块详解
torch.quantization.Observer 模块是 PyTorch 提供的用于动态量化模型的观察模块。该模块可以观察模型层的输入和输出分布,并收集统计信息,为动态量化提供必要的数据支持。
模块定义
class Observer(nn.Module):
r"""
Observer module, which observes tensor quantization ranges for dynamic quantization.
It attaches to the downstream module to observe the output of the module
and records the min/max values for quantization.
Parameters:
dtype(torch.qscheme): quantization dtype, e.g torch.qint8
quant_scheme(torch.qscheme): quantization scheme, e.g torch.per_tensor or
torch.per_channel
"""
def __init__(self, dtype=torch.qint8, quant_scheme=torch.per_tensor):
super(Observer, self).__init__()
assert dtype in [
torch.qint8, torch.quint8, torch.bfloat16
], 'Only support torch.qint8, torch.quint8, torch.bfloat16 for now'
self.dtype = dtype
self.quant_scheme = quant_scheme
self.qmin = None
self.qmax = None
self._called_once = False
def forward(self, x):
r"""Calculates the min/max values for quantization.
Args:
x(torch.Tensor): The input tensor to observe.
Returns:
torch.Tensor: The input tensor.
"""
if not self._called_once:
self._called_once = True
if self.quant_scheme == torch.per_tensor:
self.qmin = x.min()
self.qmax = x.max()
elif self.quant_scheme == torch.per_channel:
self.qmin = x.data.min(dim=1)[0]
self.qmax = x.data.max(dim=1)[0]
else:
raise NotImplementedError
return x
模块属性
dtype: 量化数据类型,可以是torch.qint8、torch.quint8或torch.bfloat16。quant_scheme: 量化方案,可以是torch.per_tensor或torch.per_channel。qmin: 最小值。qmax: 最大值。
模块方法
forward(x): 该方法计算输入x的最小值和最大值,并将其存储在qmin和qmax属性中。
模块功能
Observer 模块主要用于观察模型层的输入和输出分布,并收集统计信息。在动态量化过程中,Observer 模块收集的统计信息将被用于计算量化参数,例如量化尺度和零点。
示例
以下是一个简单的示例,演示如何使用 Observer 模块观察模型层的输入和输出分布:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.Linear(10, 20),
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.ReLU(),
Observer(dtype=torch.qint8, quant_scheme=torch.per_tensor),
nn.Linear(20, 1)
)
# 测试模型
input = torch.randn(1, 10)
output = model(input)
print(output)
在这个示例中,我们为模型中的每个层都添加了 Observer 模块。Observer 模块会观察每个层的输入和输出分布,并收集统计信息。
注意事项
在使用 torch.quantization.Observer 模块时,需要注意以下几点:
Observer模块只用于观察模型层的输入和输出分布,不进行任何量化操作。Observer模块必须与torch.quantization.DeQuantStub模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
PyTorch torch.quantization.DeQuantStub 模块详解
torch.quantization.DeQuantStub 模块是 PyTorch 提供的用于动态量化模型的反量化模块。该模块可以将定点张量转换为浮点张量,从而恢复模型的精度。
模块定义
class DeQuantStub(nn.Module):
r"""Dequantize stub module, before calibration, this is same as identity,.
It will be swapped as nnq.DeQuantize in convert .
Parameters:
qconfig(Dict): quantization configuration for the tensor, if qconfig is not
provided, we will use the global qconfig
"""
def __init__(self, qconfig=None):
super(DeQuantStub, self).__init__()
self.qconfig = qconfig
def forward(self, x):
return x
模块属性
qconfig: 量化配置字典。
模块方法
forward(x): 该方法只是简单地返回输入x,不做任何处理。
模块功能
DeQuantStub 模块主要用于将定点张量转换为浮点张量。在动态量化过程中,DeQuantStub 模块会被替换为 nnq.DeQuantize 模块,nnq.DeQuantize 模块会将定点张量转换为浮点张量,从而恢复模型的精度。
示例
以下是一个简单的示例,演示如何使用 DeQuantStub 模块将定点张量转换为浮点张量:
import torch
import torch.nn as nn
import torch.quantization
# 定义模型
model = nn.Sequential(
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(10, 20),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.ReLU(),
QuantStub(qconfig={'dtype': torch.qint8}),
nn.Linear(20, 1),
DeQuantStub(qconfig={'dtype': torch.qint8})
)
# 将模型转换为动态量化模型
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear: torch.quantization.QuantStub, nn.ReLU: torch.quantization.QuantStub},
dtype=torch.qint8
)
# 测试模型
input = torch.randn(1, 10)
output = quantized_model(input)
print(output)
在这个示例中,我们在模型的最后添加了一个 DeQuantStub 模块。DeQuantStub 模块会将模型输出的定点张量转换为浮点张量,从而恢复模型的精度。
注意事项
在使用 torch.quantization.DeQuantStub 模块时,需要注意以下几点:
DeQuantStub模块只用于将定点张量转换为浮点张量,不进行任何量化操作。DeQuantStub模块必须与torch.quantization.QuantStub模块搭配使用,才能完成动态量化。- 建议使用最新版本的 PyTorch 和 torchvision,以获得最佳性能和支持。
更多资源
- PyTorch 动态量化文档:https://pytorch.org/
- 动态量化教程:https://blog.csdn.net/lk142500/article/details/138860037
- PyTorch 量化感知训练示例:https://github.com/leimao/PyTorch-Quantization-Aware-Training

浙公网安备 33010602011771号